From f46b4a6fa263d7cf51bc8f3ceb2a69d2c1e83fdd Mon Sep 17 00:00:00 2001 From: pufferffish Date: Fri, 14 Jun 2024 19:56:35 +0100 Subject: [PATCH 001/172] implement the vulkan C backend --- gpu/gpu_vulkan.c | 126 +++++++++++++++++++++++++++++++++++++++++++++++ gpu/gpu_vulkan.h | 17 +++++++ 2 files changed, 143 insertions(+) create mode 100644 gpu/gpu_vulkan.c create mode 100644 gpu/gpu_vulkan.h diff --git a/gpu/gpu_vulkan.c b/gpu/gpu_vulkan.c new file mode 100644 index 000000000..39058cd75 --- /dev/null +++ b/gpu/gpu_vulkan.c @@ -0,0 +1,126 @@ +#include "gpu_vulkan.h" + +#include + +int check_perfmon() { +#ifdef __linux__ + cap_t caps; + const cap_value_t cap_list[2] = {CAP_PERFMON}; + + if (!CAP_IS_SUPPORTED(CAP_SETFCAP)) + return -1; + + caps = cap_get_proc(); + if (caps == NULL) + return -1; + + if (cap_set_flag(caps, CAP_EFFECTIVE, 2, cap_list, CAP_SET) == -1) + return -1; + + if (cap_set_proc(caps) == -1) + return -1; + + if (cap_free(caps) == -1) + return -1; + + return 0; +#else + return 0; +#endif +} + +void vk_init(vk_init_resp_t *resp) { + if (check_perfmon() != 0) { + resp->err = "Performance monitoring is not allowed. Please enable CAP_PERFMON or run as root to use Vulkan."; + return; + } + + VkInstance instance; + VkApplicationInfo appInfo = {}; + appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + appInfo.pNext = NULL; + appInfo.pApplicationName = "Ollama"; + appInfo.applicationVersion = VK_MAKE_VERSION(1, 0, 0); + appInfo.pEngineName = "No Engine"; + appInfo.engineVersion = VK_MAKE_VERSION(1, 0, 0); + appInfo.apiVersion = VK_API_VERSION_1_2; + VkInstanceCreateInfo createInfo = {}; + createInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + createInfo.pNext = NULL; + createInfo.flags = 0; + createInfo.enabledExtensionCount = 1; + const char* extensions[] = { VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME }; + createInfo.ppEnabledExtensionNames = extensions; + createInfo.pApplicationInfo = &appInfo; + VkResult result = vkCreateInstance(&createInfo, NULL, &instance); + if (result != VK_SUCCESS) { + resp.err = sprintf("Failed to create instance: %d", result); + return; + } + + uint32_t deviceCount; + result = vkEnumeratePhysicalDevices(instance, &deviceCount, NULL); + if (result != VK_SUCCESS) { + resp.err = sprintf("Failed to enumerate physical devices: %d", result); + return; + } + + resp.err = NULL; + resp.oh = instance; + resp.num_devices = deviceCount; +} + +void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { + uint32_t deviceCount = rh->num_devices; + VkInstance instance = rh->oh; + + VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); + result = vkEnumeratePhysicalDevices(instance, &deviceCount, devices); + if (result != VK_SUCCESS) { + resp.err = sprintf("Failed to enumerate physical devices: %d", result); + return; + } + + VkPhysicalDeviceProperties properties; + vkGetPhysicalDeviceProperties(devices[i], &properties); + LOG(h.verbose, "Vulkan device %d: %s\n", i, properties.deviceName); + int supports_budget = support_memory_budget(devices[i]); + if (!supports_budget) { + resp.err = sprintf("Device %d does not support memory budget\n", i); + return; + } + if (properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_CPU) { + resp.err = sprintf("Device %d is a CPU, skipped\n", i); + return; + } + + VkPhysicalDeviceMemoryBudgetPropertiesEXT physical_device_memory_budget_properties; + physical_device_memory_budget_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT; + 
physical_device_memory_budget_properties.pNext = NULL; + + VkPhysicalDeviceMemoryProperties2 device_memory_properties; + device_memory_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2; + device_memory_properties.pNext = &physical_device_memory_budget_properties; + + vkGetPhysicalDeviceMemoryProperties2(devices[i], &device_memory_properties); + + VkDeviceSize device_memory_total_usage = 0; + VkDeviceSize device_memory_heap_budget = 0; + + for (uint32_t j = 0; j < device_memory_properties.memoryProperties.memoryHeapCount; j++) { + VkMemoryHeap heap = device_memory_properties.memoryProperties.memoryHeaps[j]; + if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) { + device_memory_total_usage += physical_device_memory_budget_properties.heapUsage[j]; + device_memory_heap_budget += physical_device_memory_budget_properties.heapBudget[j]; + } + } + + resp->err = NULL; + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); + snprintf(&resp->gpu_name[0], GPU_NAME_LEN, "%s", properties.deviceName); + resp->total = (uint64_t) device_memory_total_usage; + resp->free = (uint64_t) device_memory_total_usage; + resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); + resp->minor = VK_API_VERSION_MINOR(properties.apiVersion); + resp->patch = VK_API_VERSION_PATCH(properties.apiVersion); +} diff --git a/gpu/gpu_vulkan.h b/gpu/gpu_vulkan.h new file mode 100644 index 000000000..61ebb1a57 --- /dev/null +++ b/gpu/gpu_vulkan.h @@ -0,0 +1,17 @@ +#include "gpu_info.h" + +#ifdef __linux__ +#include +#endif + +typedef VkInstance vk_handle_t; + +typedef struct vk_init_resp +{ + char *err; // If err is non-null handle is invalid + int num_devices; + vk_handle_t oh; +} vk_init_resp_t; + +void vk_init(vk_init_resp_t *resp); +void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp); From 9c6b0495678f66f5b6b50fdb05c7efd99f5a208f Mon Sep 17 00:00:00 2001 From: pufferffish Date: Sat, 15 Jun 2024 05:27:14 +0100 Subject: [PATCH 002/172] add support in gpu.go --- gpu/gpu.go | 93 +++++++++++++++++++++++++++++++++++++++++++- gpu/gpu_vulkan.c | 24 ++++++++++-- gpu/gpu_vulkan.h | 3 +- gpu/vulkan_common.go | 19 +++++++++ 4 files changed, 133 insertions(+), 6 deletions(-) create mode 100644 gpu/vulkan_common.go diff --git a/gpu/gpu.go b/gpu/gpu.go index a55903c51..359c6b5a7 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -29,6 +29,7 @@ type handles struct { cudart *C.cudart_handle_t nvcuda *C.nvcuda_handle_t oneapi *C.oneapi_handle_t + vulkan *C.vk_handle_t } const ( @@ -90,6 +91,16 @@ var OneapiLinuxGlobs = []string{ "/usr/lib*/libze_intel_gpu.so*", } +var VulkanLinuxGlobs = []string{ + "/usr/lib/x86_64-linux-gnu/libvulkan.so*", + "/usr/lib*/libvulkan.so*", +} + +var CapLinuxGlobs = []string{ + "/usr/lib/x86_64-linux-gnu/libcap.so*", + "/usr/lib*/libcap.so*", +} + // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed. // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. 
var CudaTegra string = os.Getenv("JETSON_JETPACK") @@ -104,6 +115,10 @@ func initGPUHandles() *handles { var cudartMgmtPatterns []string var nvcudaMgmtName string var nvcudaMgmtPatterns []string + var vulkanMgmtName string + var vulkanMgmtPatterns []string + var libcapMgmtName string + var libcapMgmtPatterns []string tmpDir, _ := PayloadsDir() switch runtime.GOOS { @@ -125,6 +140,12 @@ func initGPUHandles() *handles { // Aligned with driver, we can't carry as payloads nvcudaMgmtName = "libcuda.so*" nvcudaMgmtPatterns = NvcudaLinuxGlobs + + // Vulkan also needs libcap + vulkanMgmtName = "libvulkan.so*" + vulkanMgmtPatterns = VulkanLinuxGlobs + libcapMgmtName = "libcap.so*" + libcapMgmtPatterns = CapLinuxGlobs default: return gpuHandles } @@ -152,6 +173,25 @@ func initGPUHandles() *handles { } } + vulkanLibPaths := FindGPULibs(vulkanMgmtName, vulkanMgmtPatterns) + + var libcapLibPaths []string + if runtime.GOOS == "linux" { + libcapLibPaths = FindGPULibs(libcapMgmtName, libcapMgmtPatterns) + } else { + libcapLibPaths = []string{""} + } + + if len(vulkanLibPaths) > 0 && len(libcapLibPaths) > 0 { + deviceCount, vulkan, vkLibPath, capLibPath := LoadVulkanMgmt(vulkanLibPaths, libcapLibPaths) + if vulkan != nil { + slog.Debug("detected GPUs", "library", vkLibPath, capLibPath, "count", deviceCount) + gpuHandles.vulkan = vulkan + gpuHandles.deviceCount = deviceCount + return gpuHandles + } + } + return gpuHandles } @@ -186,7 +226,7 @@ func GetGPUInfo() GpuInfoList { var memInfo C.mem_info_t resp := []GpuInfo{} - // NVIDIA first + // NVIDIA and Vulkan first for i := range gpuHandles.deviceCount { // TODO once we support CPU compilation variants of GPU libraries refine this... if cpuVariant == "" && runtime.GOARCH == "amd64" { @@ -227,6 +267,32 @@ func GetGPUInfo() GpuInfoList { // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... resp = append(resp, gpuInfo) } + + if gpuHandles.vulkan != nil { + gpuInfo := GpuInfo{ + Library: "vulkan", + } + + C.vk_check_vram(*gpuHandles.vulkan, C.int(i), &memInfo) + if memInfo.err != nil { + slog.Info("error looking up vulkan GPU memory", "error", C.GoString(memInfo.err)) + C.free(unsafe.Pointer(memInfo.err)) + continue + } + + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) + gpuInfo.MinimumMemory = 0 + gpuInfo.DependencyPath = depPath + gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) + gpuInfo.DriverMajor = int(memInfo.major) + gpuInfo.DriverMinor = int(memInfo.minor) + + // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... 
+ resp = append(resp, gpuInfo) + } } // Then AMD @@ -379,6 +445,29 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) { return 0, nil, "" } +func LoadVulkanMgmt(vulkanLibPaths []string, capLibPaths []string) (int, *C.vk_handle_t, string, string) { + var resp C.vk_init_resp_t + for _, vkLibPath := range vulkanLibPaths { + for _, capLibPath := range capLibPaths { + vkLib := C.CString(vkLibPath) + capLib := C.CString(capLibPath) + defer C.free(unsafe.Pointer(vkLib)) + defer C.free(unsafe.Pointer(capLib)) + + C.vk_init(vkLib, capLib, &resp) + if resp.err != nil { + slog.Debug("Unable to load vulkan", "library", vkLibPath, "error", C.GoString(resp.err)) + slog.Debug("Unable to load libcap", "library", capLibPath, "error", C.GoString(resp.err)) + C.free(unsafe.Pointer(resp.err)) + } else { + return int(resp.num_devices), &resp.vk, vkLibPath, capLibPath + } + } + } + + return 0, nil, "", "" +} + func getVerboseState() C.uint16_t { if envconfig.Debug { return C.uint16_t(1) @@ -401,6 +490,8 @@ func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) { return rocmGetVisibleDevicesEnv(l) case "oneapi": return oneapiGetVisibleDevicesEnv(l) + case "vulkan": + return vkGetVisibleDevicesEnv(l) default: slog.Debug("no filter required for library " + l[0].Library) return "", "" diff --git a/gpu/gpu_vulkan.c b/gpu/gpu_vulkan.c index 39058cd75..bb45bdf21 100644 --- a/gpu/gpu_vulkan.c +++ b/gpu/gpu_vulkan.c @@ -22,18 +22,28 @@ int check_perfmon() { if (cap_free(caps) == -1) return -1; +#endif return 0; -#else - return 0; -#endif } -void vk_init(vk_init_resp_t *resp) { +void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { + if (!LOAD_LIBRARY(vk_lib_path, RTLD_LAZY)) { + resp->err = "Failed to load Vulkan library"; + return; + } + +#ifdef __linux__ + if (!LOAD_LIBRARY(cap_lib_path, RTLD_LAZY)) { + resp->err = "Failed to load libcap library"; + return; + } + if (check_perfmon() != 0) { resp->err = "Performance monitoring is not allowed. Please enable CAP_PERFMON or run as root to use Vulkan."; return; } +#endif VkInstance instance; VkApplicationInfo appInfo = {}; @@ -123,4 +133,10 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); resp->minor = VK_API_VERSION_MINOR(properties.apiVersion); resp->patch = VK_API_VERSION_PATCH(properties.apiVersion); + } + +void vk_free(vk_handle_t rh) { + vkDestroyInstance(rh->oh, NULL); + free(rh); +} \ No newline at end of file diff --git a/gpu/gpu_vulkan.h b/gpu/gpu_vulkan.h index 61ebb1a57..e77ce554e 100644 --- a/gpu/gpu_vulkan.h +++ b/gpu/gpu_vulkan.h @@ -13,5 +13,6 @@ typedef struct vk_init_resp vk_handle_t oh; } vk_init_resp_t; -void vk_init(vk_init_resp_t *resp); +void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp); void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp); +void vk_free(vk_handle_t rh); diff --git a/gpu/vulkan_common.go b/gpu/vulkan_common.go new file mode 100644 index 000000000..8d3d15d06 --- /dev/null +++ b/gpu/vulkan_common.go @@ -0,0 +1,19 @@ +package gpu + +import ( + "log/slog" + "strings" +) + +func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "vulkan" { + // TODO shouldn't happen if things are wired correctly... 
+ slog.Debug("vkGetVisibleDevicesEnv skipping over non-vulkan device", "library", info.Library) + continue + } + ids = append(ids, info.ID) + } + return "GGML_VK_VISIBLE_DEVICES", strings.Join(ids, ",") +} From 93c4d69daa02be2c4407c73d30c8fe72961de61b Mon Sep 17 00:00:00 2001 From: pufferffish Date: Sat, 15 Jun 2024 05:42:59 +0100 Subject: [PATCH 003/172] add support in gen_linux.sh --- llm/generate/gen_linux.sh | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/llm/generate/gen_linux.sh b/llm/generate/gen_linux.sh index a9df6ff86..2190fb93e 100755 --- a/llm/generate/gen_linux.sh +++ b/llm/generate/gen_linux.sh @@ -206,6 +206,34 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then fi +if [ -z "${VULKAN_ROOT}" ]; then + # Try the default location in case it exists + VULKAN_ROOT=/usr/lib/ +fi + +if [ -z "${CAP_ROOT}" ]; then + # Try the default location in case it exists + CAP_ROOT=/usr/lib/ +fi + +if [ -z "${OLLAMA_SKIP_VULKAN_GENERATE}" -a -d "${VULKAN_ROOT}" ] && [ -z "${OLLAMA_SKIP_VULKAN_GENERATE}" -a -d "${CAP_ROOT}" ]; then + echo "Vulkan and capabilities libraries detected - building dynamic Vulkan library" + init_vars + + CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_VULKAN=1" + BUILD_DIR="../build/linux/${ARCH}/vulkan" + EXTRA_LIBS="-L${VULKAN_ROOT} -L${CAP_ROOT} -lvulkan -lcap" + build + + # copy oneAPI dependencies + for dep in $(ldd "${BUILD_DIR}/bin/ollama_llama_server" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e vulkan -e cap); do + cp "${dep}" "${BUILD_DIR}/bin/" + done + cp "${VULKAN_ROOT}/libvulkan.so" "${BUILD_DIR}/bin/" + cp "${CAP_ROOT}/libcap.so" "${BUILD_DIR}/bin/" + compress +fi + if [ -z "${ONEAPI_ROOT}" ]; then # Try the default location in case it exists ONEAPI_ROOT=/opt/intel/oneapi From 24c8840037a9edd48fafd31f113916cb4105c922 Mon Sep 17 00:00:00 2001 From: pufferffish Date: Sat, 15 Jun 2024 07:49:28 +0100 Subject: [PATCH 004/172] it builds --- gpu/gpu.go | 6 +- gpu/gpu_info.h | 1 + gpu/gpu_info_vulkan.c | 222 ++++++++++++++++++++++++++++++++++++++++++ gpu/gpu_info_vulkan.h | 66 +++++++++++++ gpu/gpu_vulkan.c | 142 --------------------------- gpu/gpu_vulkan.h | 18 ---- 6 files changed, 294 insertions(+), 161 deletions(-) create mode 100644 gpu/gpu_info_vulkan.c create mode 100644 gpu/gpu_info_vulkan.h delete mode 100644 gpu/gpu_vulkan.c delete mode 100644 gpu/gpu_vulkan.h diff --git a/gpu/gpu.go b/gpu/gpu.go index 359c6b5a7..0b19e0aba 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -209,6 +209,9 @@ func GetGPUInfo() GpuInfoList { if gpuHandles.nvcuda != nil { C.nvcuda_release(*gpuHandles.nvcuda) } + if gpuHandles.vulkan != nil { + C.vk_release(*gpuHandles.vulkan) + } }() // All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX @@ -447,6 +450,7 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) { func LoadVulkanMgmt(vulkanLibPaths []string, capLibPaths []string) (int, *C.vk_handle_t, string, string) { var resp C.vk_init_resp_t + resp.ch.verbose = getVerboseState() for _, vkLibPath := range vulkanLibPaths { for _, capLibPath := range capLibPaths { vkLib := C.CString(vkLibPath) @@ -460,7 +464,7 @@ func LoadVulkanMgmt(vulkanLibPaths []string, capLibPaths []string) (int, *C.vk_h slog.Debug("Unable to load libcap", "library", capLibPath, "error", C.GoString(resp.err)) C.free(unsafe.Pointer(resp.err)) } else { - return int(resp.num_devices), &resp.vk, vkLibPath, capLibPath + return int(resp.num_devices), &resp.ch, vkLibPath, capLibPath } 
} } diff --git a/gpu/gpu_info.h b/gpu/gpu_info.h index 482b81a6f..afc80dca4 100644 --- a/gpu/gpu_info.h +++ b/gpu/gpu_info.h @@ -63,6 +63,7 @@ void cpu_check_ram(mem_info_t *resp); #include "gpu_info_cudart.h" #include "gpu_info_nvcuda.h" #include "gpu_info_oneapi.h" +#include "gpu_info_vulkan.h" #endif // __GPU_INFO_H__ #endif // __APPLE__ \ No newline at end of file diff --git a/gpu/gpu_info_vulkan.c b/gpu/gpu_info_vulkan.c new file mode 100644 index 000000000..8b0370d2c --- /dev/null +++ b/gpu/gpu_info_vulkan.c @@ -0,0 +1,222 @@ +#include "gpu_info_vulkan.h" + +#include + +int check_perfmon(vk_handle_t* rh) { +#ifdef __linux__ + cap_t caps; + const cap_value_t cap_list[2] = {CAP_PERFMON}; + + if ((*rh->cap_get_bound)(CAP_SETFCAP) < 0) + return -1; + + caps = (*rh->cap_get_proc)(); + if (caps == NULL) + return -1; + + if ((*rh->cap_set_flag)(caps, CAP_EFFECTIVE, 2, cap_list, CAP_SET) == -1) + return -1; + + if ((*rh->cap_set_proc)(caps) == -1) + return -1; + + if ((*rh->cap_free)(caps) == -1) + return -1; +#endif + + return 0; +} + +int support_memory_budget(vk_handle_t* rh, VkPhysicalDevice device) { + VkPhysicalDeviceProperties properties; + (*rh->vkGetPhysicalDeviceProperties)(device, &properties); + uint32_t extensionCount; + (*rh->vkEnumerateDeviceExtensionProperties)(device, NULL, &extensionCount, NULL); + VkExtensionProperties* extensions = malloc(extensionCount * sizeof(VkExtensionProperties)); + (*rh->vkEnumerateDeviceExtensionProperties)(device, NULL, &extensionCount, extensions); + for (int j = 0; j < extensionCount; j++) { + if (strcmp(extensions[j].extensionName, VK_EXT_MEMORY_BUDGET_EXTENSION_NAME) == 0) { + return 1; + } + } + return 0; +} + +void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { + const int buflen = 256; + char buf[buflen + 1]; + int i; + + struct lookup { + int is_cap; + char *s; + void **p; + } l[] = { +#ifdef __linux__ + {1, "cap_get_proc", (void *)&resp->ch.cap_get_proc}, + {1, "cap_get_bound", (void *)&resp->ch.cap_get_bound}, + {1, "cap_set_flag", (void *)&resp->ch.cap_set_flag}, + {1, "cap_set_proc", (void *)&resp->ch.cap_set_proc}, + {1, "cap_free", (void *)&resp->ch.cap_free}, +#endif + {0, "vkGetPhysicalDeviceProperties", (void *)&resp->ch.vkGetPhysicalDeviceProperties}, + {0, "vkEnumerateDeviceExtensionProperties", (void *)&resp->ch.vkEnumerateDeviceExtensionProperties}, + {0, "vkCreateInstance", (void *)&resp->ch.vkCreateInstance}, + {0, "vkEnumeratePhysicalDevices", (void *)&resp->ch.vkEnumeratePhysicalDevices}, + {0, "vkGetPhysicalDeviceMemoryProperties2", (void *)&resp->ch.vkGetPhysicalDeviceMemoryProperties2}, + {0, "vkDestroyInstance", (void *)&resp->ch.vkDestroyInstance}, + {0, NULL, NULL}, + }; + + resp->ch.vk_handle = LOAD_LIBRARY(vk_lib_path, RTLD_LAZY); + if (!resp->ch.vk_handle) { + char *msg = LOAD_ERR(); + LOG(resp->ch.verbose, "library %s load err: %s\n", vk_lib_path, msg); + snprintf(buf, buflen, + "Unable to load %s library to query for Vulkan GPUs: %s", + vk_lib_path, msg); + free(msg); + resp->err = strdup(buf); + return; + } + +#ifdef __linux__ + resp->ch.cap_handle = LOAD_LIBRARY(cap_lib_path, RTLD_LAZY); + if (!resp->ch.cap_handle) { + char *msg = LOAD_ERR(); + LOG(resp->ch.verbose, "library %s load err: %s\n", cap_lib_path, msg); + snprintf(buf, buflen, + "Unable to load %s library to query for Vulkan GPUs: %s", + cap_lib_path, msg); + free(msg); + resp->err = strdup(buf); + return; + } +#endif + + for (i = 0; l[i].s != NULL; i++) { + if (l[i].is_cap) +#ifdef __linux__ + *l[i].p = 
LOAD_SYMBOL(resp->ch.cap_handle, l[i].s); +#else + continue; +#endif + else + *l[i].p = LOAD_SYMBOL(resp->ch.vk_handle, l[i].s); + if (!*l[i].p) { + char *msg = LOAD_ERR(); + LOG(resp->ch.verbose, "dlerr: %s\n", msg); + if (l[i].is_cap) { + UNLOAD_LIBRARY(resp->ch.cap_handle); + resp->ch.cap_handle = NULL; + } else { + UNLOAD_LIBRARY(resp->ch.vk_handle); + resp->ch.vk_handle = NULL; + } + snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, + msg); + free(msg); + resp->err = strdup(buf); + return; + } + } + + if (check_perfmon(&resp->ch) != 0) { + resp->err = "Performance monitoring is not allowed. Please enable CAP_PERFMON or run as root to use Vulkan."; + return; + } + + VkInstance instance; + VkApplicationInfo appInfo = {}; + appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + appInfo.pNext = NULL; + appInfo.pApplicationName = "Ollama"; + appInfo.applicationVersion = VK_MAKE_VERSION(1, 0, 0); + appInfo.pEngineName = "No Engine"; + appInfo.engineVersion = VK_MAKE_VERSION(1, 0, 0); + appInfo.apiVersion = VK_API_VERSION_1_2; + VkInstanceCreateInfo createInfo = {}; + createInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + createInfo.pNext = NULL; + createInfo.flags = 0; + createInfo.enabledExtensionCount = 1; + const char* extensions[] = { VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME }; + createInfo.ppEnabledExtensionNames = extensions; + createInfo.pApplicationInfo = &appInfo; + VkResult result = (*resp->ch.vkCreateInstance)(&createInfo, NULL, &instance); + if (result != VK_SUCCESS) { + resp->err = strdup("failed to create instance"); + return; + } + + uint32_t deviceCount; + result = (*resp->ch.vkEnumeratePhysicalDevices)(instance, &deviceCount, NULL); + if (result != VK_SUCCESS) { + resp->err = strdup("failed to enumerate physical devices"); + return; + } + + resp->err = NULL; + resp->ch.vk = instance; + resp->ch.num_devices = deviceCount; + resp->num_devices = deviceCount; +} + +void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { + VkInstance instance = rh.vk; + uint32_t deviceCount = rh.num_devices; + + VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); + VkResult result = (*rh.vkEnumeratePhysicalDevices)(instance, &deviceCount, devices); + if (result != VK_SUCCESS) { + resp->err = strdup("failed to enumerate physical devices"); + return; + } + + VkPhysicalDeviceProperties properties; + (*rh.vkGetPhysicalDeviceProperties)(devices[i], &properties); + int supports_budget = support_memory_budget(&rh, devices[i]); + if (!supports_budget) { + resp->err = strdup("device does not support memory budget"); + return; + } + if (properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_CPU) { + resp->err = strdup("device is a CPU"); + return; + } + + VkPhysicalDeviceMemoryBudgetPropertiesEXT physical_device_memory_budget_properties; + physical_device_memory_budget_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT; + physical_device_memory_budget_properties.pNext = NULL; + + VkPhysicalDeviceMemoryProperties2 device_memory_properties; + device_memory_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2; + device_memory_properties.pNext = &physical_device_memory_budget_properties; + + (*rh.vkGetPhysicalDeviceMemoryProperties2)(devices[i], &device_memory_properties); + + VkDeviceSize device_memory_total_usage = 0; + VkDeviceSize device_memory_heap_budget = 0; + + for (uint32_t j = 0; j < device_memory_properties.memoryProperties.memoryHeapCount; j++) { + VkMemoryHeap heap = 
device_memory_properties.memoryProperties.memoryHeaps[j]; + if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) { + device_memory_total_usage += physical_device_memory_budget_properties.heapUsage[j]; + device_memory_heap_budget += physical_device_memory_budget_properties.heapBudget[j]; + } + } + + resp->err = NULL; + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); + snprintf(&resp->gpu_name[0], GPU_NAME_LEN, "%s", properties.deviceName); + resp->total = (uint64_t) device_memory_total_usage; + resp->free = (uint64_t) device_memory_total_usage; + resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); + resp->minor = VK_API_VERSION_MINOR(properties.apiVersion); + resp->patch = VK_API_VERSION_PATCH(properties.apiVersion); + +} + +void vk_release(vk_handle_t rh) { + (*rh.vkDestroyInstance)(rh.vk, NULL); +} \ No newline at end of file diff --git a/gpu/gpu_info_vulkan.h b/gpu/gpu_info_vulkan.h new file mode 100644 index 000000000..6025f3e09 --- /dev/null +++ b/gpu/gpu_info_vulkan.h @@ -0,0 +1,66 @@ +#ifndef __APPLE__ +#ifndef __GPU_INFO_VULKAN_H__ +#define __GPU_INFO_VULKAN_H__ + +#include "gpu_info.h" + +#ifdef __linux__ +#include +#endif + +#include + +typedef struct { + void* vk_handle; + void* cap_handle; + uint16_t verbose; + + VkInstance vk; + int num_devices; + +#ifdef __linux__ + cap_t (*cap_get_proc)(void); + + int (*cap_get_bound)(cap_value_t); + int (*cap_set_flag)(cap_t, cap_flag_t, int, const cap_value_t *, cap_flag_value_t); + int (*cap_set_proc)(cap_t); + int (*cap_free)(cap_t); +#endif + + void (*vkGetPhysicalDeviceProperties)( + VkPhysicalDevice physicalDevice, + VkPhysicalDeviceProperties* pProperties); + VkResult (*vkEnumerateDeviceExtensionProperties)( + VkPhysicalDevice physicalDevice, + const char* pLayerName, + uint32_t* pPropertyCount, + VkExtensionProperties* pProperties); + VkResult (*vkCreateInstance)( + const VkInstanceCreateInfo* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkInstance* pInstance); + VkResult (*vkEnumeratePhysicalDevices)( + VkInstance instance, + uint32_t* pPhysicalDeviceCount, + VkPhysicalDevice* pPhysicalDevices); + void (*vkGetPhysicalDeviceMemoryProperties2)( + VkPhysicalDevice physicalDevice, + VkPhysicalDeviceMemoryProperties2* pMemoryProperties); + void (*vkDestroyInstance)( + VkInstance instance, + const VkAllocationCallbacks* pAllocator); +} vk_handle_t; + +typedef struct vk_init_resp +{ + char *err; // If err is non-null handle is invalid + int num_devices; + vk_handle_t ch; +} vk_init_resp_t; + +void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp); +void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp); +void vk_release(vk_handle_t rh); + +#endif +#endif \ No newline at end of file diff --git a/gpu/gpu_vulkan.c b/gpu/gpu_vulkan.c deleted file mode 100644 index bb45bdf21..000000000 --- a/gpu/gpu_vulkan.c +++ /dev/null @@ -1,142 +0,0 @@ -#include "gpu_vulkan.h" - -#include - -int check_perfmon() { -#ifdef __linux__ - cap_t caps; - const cap_value_t cap_list[2] = {CAP_PERFMON}; - - if (!CAP_IS_SUPPORTED(CAP_SETFCAP)) - return -1; - - caps = cap_get_proc(); - if (caps == NULL) - return -1; - - if (cap_set_flag(caps, CAP_EFFECTIVE, 2, cap_list, CAP_SET) == -1) - return -1; - - if (cap_set_proc(caps) == -1) - return -1; - - if (cap_free(caps) == -1) - return -1; -#endif - - return 0; -} - -void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { - if (!LOAD_LIBRARY(vk_lib_path, RTLD_LAZY)) { - resp->err = "Failed to load Vulkan library"; - return; - } - -#ifdef __linux__ - if 
(!LOAD_LIBRARY(cap_lib_path, RTLD_LAZY)) { - resp->err = "Failed to load libcap library"; - return; - } - - if (check_perfmon() != 0) { - resp->err = "Performance monitoring is not allowed. Please enable CAP_PERFMON or run as root to use Vulkan."; - return; - } -#endif - - VkInstance instance; - VkApplicationInfo appInfo = {}; - appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - appInfo.pNext = NULL; - appInfo.pApplicationName = "Ollama"; - appInfo.applicationVersion = VK_MAKE_VERSION(1, 0, 0); - appInfo.pEngineName = "No Engine"; - appInfo.engineVersion = VK_MAKE_VERSION(1, 0, 0); - appInfo.apiVersion = VK_API_VERSION_1_2; - VkInstanceCreateInfo createInfo = {}; - createInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - createInfo.pNext = NULL; - createInfo.flags = 0; - createInfo.enabledExtensionCount = 1; - const char* extensions[] = { VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME }; - createInfo.ppEnabledExtensionNames = extensions; - createInfo.pApplicationInfo = &appInfo; - VkResult result = vkCreateInstance(&createInfo, NULL, &instance); - if (result != VK_SUCCESS) { - resp.err = sprintf("Failed to create instance: %d", result); - return; - } - - uint32_t deviceCount; - result = vkEnumeratePhysicalDevices(instance, &deviceCount, NULL); - if (result != VK_SUCCESS) { - resp.err = sprintf("Failed to enumerate physical devices: %d", result); - return; - } - - resp.err = NULL; - resp.oh = instance; - resp.num_devices = deviceCount; -} - -void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { - uint32_t deviceCount = rh->num_devices; - VkInstance instance = rh->oh; - - VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); - result = vkEnumeratePhysicalDevices(instance, &deviceCount, devices); - if (result != VK_SUCCESS) { - resp.err = sprintf("Failed to enumerate physical devices: %d", result); - return; - } - - VkPhysicalDeviceProperties properties; - vkGetPhysicalDeviceProperties(devices[i], &properties); - LOG(h.verbose, "Vulkan device %d: %s\n", i, properties.deviceName); - int supports_budget = support_memory_budget(devices[i]); - if (!supports_budget) { - resp.err = sprintf("Device %d does not support memory budget\n", i); - return; - } - if (properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_CPU) { - resp.err = sprintf("Device %d is a CPU, skipped\n", i); - return; - } - - VkPhysicalDeviceMemoryBudgetPropertiesEXT physical_device_memory_budget_properties; - physical_device_memory_budget_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT; - physical_device_memory_budget_properties.pNext = NULL; - - VkPhysicalDeviceMemoryProperties2 device_memory_properties; - device_memory_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2; - device_memory_properties.pNext = &physical_device_memory_budget_properties; - - vkGetPhysicalDeviceMemoryProperties2(devices[i], &device_memory_properties); - - VkDeviceSize device_memory_total_usage = 0; - VkDeviceSize device_memory_heap_budget = 0; - - for (uint32_t j = 0; j < device_memory_properties.memoryProperties.memoryHeapCount; j++) { - VkMemoryHeap heap = device_memory_properties.memoryProperties.memoryHeaps[j]; - if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) { - device_memory_total_usage += physical_device_memory_budget_properties.heapUsage[j]; - device_memory_heap_budget += physical_device_memory_budget_properties.heapBudget[j]; - } - } - - resp->err = NULL; - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - snprintf(&resp->gpu_name[0], 
GPU_NAME_LEN, "%s", properties.deviceName); - resp->total = (uint64_t) device_memory_total_usage; - resp->free = (uint64_t) device_memory_total_usage; - resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); - resp->minor = VK_API_VERSION_MINOR(properties.apiVersion); - resp->patch = VK_API_VERSION_PATCH(properties.apiVersion); - -} - -void vk_free(vk_handle_t rh) { - vkDestroyInstance(rh->oh, NULL); - free(rh); -} \ No newline at end of file diff --git a/gpu/gpu_vulkan.h b/gpu/gpu_vulkan.h deleted file mode 100644 index e77ce554e..000000000 --- a/gpu/gpu_vulkan.h +++ /dev/null @@ -1,18 +0,0 @@ -#include "gpu_info.h" - -#ifdef __linux__ -#include -#endif - -typedef VkInstance vk_handle_t; - -typedef struct vk_init_resp -{ - char *err; // If err is non-null handle is invalid - int num_devices; - vk_handle_t oh; -} vk_init_resp_t; - -void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp); -void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp); -void vk_free(vk_handle_t rh); From 724fac470f0df86e8d0d24e209bea34f31a4ec84 Mon Sep 17 00:00:00 2001 From: pufferffish Date: Sat, 15 Jun 2024 08:05:48 +0100 Subject: [PATCH 005/172] fix segfault --- gpu/gpu.go | 3 +-- gpu/gpu_info_vulkan.c | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gpu/gpu.go b/gpu/gpu.go index 0b19e0aba..46359e340 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -460,8 +460,7 @@ func LoadVulkanMgmt(vulkanLibPaths []string, capLibPaths []string) (int, *C.vk_h C.vk_init(vkLib, capLib, &resp) if resp.err != nil { - slog.Debug("Unable to load vulkan", "library", vkLibPath, "error", C.GoString(resp.err)) - slog.Debug("Unable to load libcap", "library", capLibPath, "error", C.GoString(resp.err)) + slog.Debug("Unable to load vulkan", "library", vkLibPath, capLibPath, "error", C.GoString(resp.err)) C.free(unsafe.Pointer(resp.err)) } else { return int(resp.num_devices), &resp.ch, vkLibPath, capLibPath diff --git a/gpu/gpu_info_vulkan.c b/gpu/gpu_info_vulkan.c index 8b0370d2c..cb2e8f67e 100644 --- a/gpu/gpu_info_vulkan.c +++ b/gpu/gpu_info_vulkan.c @@ -122,7 +122,8 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { } if (check_perfmon(&resp->ch) != 0) { - resp->err = "Performance monitoring is not allowed. Please enable CAP_PERFMON or run as root to use Vulkan."; + resp->err = strdup("performance monitoring is not allowed. Please enable CAP_PERFMON or run as root to use Vulkan."); + LOG(resp->ch.verbose, resp->err); return; } From e4e8a5d25a375c9df03ad122211237798e4ca743 Mon Sep 17 00:00:00 2001 From: pufferffish Date: Sat, 15 Jun 2024 09:44:10 +0100 Subject: [PATCH 006/172] fix compilation --- gpu/gpu_info_vulkan.c | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gpu/gpu_info_vulkan.c b/gpu/gpu_info_vulkan.c index cb2e8f67e..9822a63f9 100644 --- a/gpu/gpu_info_vulkan.c +++ b/gpu/gpu_info_vulkan.c @@ -123,7 +123,7 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { if (check_perfmon(&resp->ch) != 0) { resp->err = strdup("performance monitoring is not allowed. 
Please enable CAP_PERFMON or run as root to use Vulkan."); - LOG(resp->ch.verbose, resp->err); + LOG(resp->ch.verbose, "vulkan: %s", resp->err); return; } @@ -209,7 +209,8 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { resp->err = NULL; snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - snprintf(&resp->gpu_name[0], GPU_NAME_LEN, "%s", properties.deviceName); + resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; + strncpy(&resp->gpu_name[0], properties.deviceName, GPU_NAME_LEN - 1); resp->total = (uint64_t) device_memory_total_usage; resp->free = (uint64_t) device_memory_total_usage; resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); From 257364cb3c47a5e392bfb1772ecf6709dc0a7c83 Mon Sep 17 00:00:00 2001 From: pufferffish Date: Sat, 15 Jun 2024 10:52:34 +0100 Subject: [PATCH 007/172] fix free memory monitor --- gpu/gpu_info_vulkan.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpu/gpu_info_vulkan.c b/gpu/gpu_info_vulkan.c index 9822a63f9..b4b7f26fd 100644 --- a/gpu/gpu_info_vulkan.c +++ b/gpu/gpu_info_vulkan.c @@ -212,7 +212,7 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; strncpy(&resp->gpu_name[0], properties.deviceName, GPU_NAME_LEN - 1); resp->total = (uint64_t) device_memory_total_usage; - resp->free = (uint64_t) device_memory_total_usage; + resp->free = (uint64_t) device_memory_heap_budget; resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); resp->minor = VK_API_VERSION_MINOR(properties.apiVersion); resp->patch = VK_API_VERSION_PATCH(properties.apiVersion); From 11c55fab8113a02fbd77968c99856c22fb89c880 Mon Sep 17 00:00:00 2001 From: pufferffish Date: Sat, 15 Jun 2024 10:58:12 +0100 Subject: [PATCH 008/172] fix total memory monitor --- gpu/gpu_info_vulkan.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gpu/gpu_info_vulkan.c b/gpu/gpu_info_vulkan.c index b4b7f26fd..fbe7a5885 100644 --- a/gpu/gpu_info_vulkan.c +++ b/gpu/gpu_info_vulkan.c @@ -196,13 +196,13 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { (*rh.vkGetPhysicalDeviceMemoryProperties2)(devices[i], &device_memory_properties); - VkDeviceSize device_memory_total_usage = 0; + VkDeviceSize device_memory_total_size = 0; VkDeviceSize device_memory_heap_budget = 0; for (uint32_t j = 0; j < device_memory_properties.memoryProperties.memoryHeapCount; j++) { VkMemoryHeap heap = device_memory_properties.memoryProperties.memoryHeaps[j]; if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) { - device_memory_total_usage += physical_device_memory_budget_properties.heapUsage[j]; + device_memory_total_size += heap.size; device_memory_heap_budget += physical_device_memory_budget_properties.heapBudget[j]; } } @@ -211,7 +211,7 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; strncpy(&resp->gpu_name[0], properties.deviceName, GPU_NAME_LEN - 1); - resp->total = (uint64_t) device_memory_total_usage; + resp->total = (uint64_t) device_memory_total_size; resp->free = (uint64_t) device_memory_heap_budget; resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); resp->minor = VK_API_VERSION_MINOR(properties.apiVersion); From 18f3f960b01e1dd18a43fbcddbc0dc9de1ae2cbd Mon Sep 17 00:00:00 2001 From: pufferffish Date: Sat, 15 Jun 2024 12:05:01 +0100 Subject: [PATCH 009/172] update gpu.go --- gpu/gpu.go | 642 +++++++++++++++++++++++++++++---------------- gpu/gpu_linux.go | 16 ++ gpu/gpu_windows.go | 4 + gpu/types.go | 
7 + 4 files changed, 441 insertions(+), 228 deletions(-) diff --git a/gpu/gpu.go b/gpu/gpu.go index 46359e340..11c72e151 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -24,20 +24,45 @@ import ( "github.com/ollama/ollama/format" ) -type handles struct { +type cudaHandles struct { deviceCount int cudart *C.cudart_handle_t nvcuda *C.nvcuda_handle_t + nvml *C.nvml_handle_t +} + +type oneapiHandles struct { oneapi *C.oneapi_handle_t - vulkan *C.vk_handle_t + deviceCount int +} + +type vulkanHandles struct { + vulkan *C.vulkan_handle_t + deviceCount int } const ( cudaMinimumMemory = 457 * format.MebiByte rocmMinimumMemory = 457 * format.MebiByte + // TODO OneAPI minimum memory ) -var gpuMutex sync.Mutex +var ( + gpuMutex sync.Mutex + bootstrapped bool + cpuCapability CPUCapability + cpus []CPUInfo + cudaGPUs []CudaGPUInfo + nvcudaLibPath string + cudartLibPath string + oneapiLibPath string + vulkanLibPath string + libcapLibPath string + nvmlLibPath string + rocmGPUs []RocmGPUInfo + oneapiGPUs []OneapiGPUInfo + vulkanGPUs []VulkanGPUInfo +) // With our current CUDA compile flags, older than 5.0 will not work properly var CudaComputeMin = [2]C.int{5, 0} @@ -47,152 +72,133 @@ var RocmComputeMin = 9 // TODO find a better way to detect iGPU instead of minimum memory const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU -var CudartLinuxGlobs = []string{ - "/usr/local/cuda/lib64/libcudart.so*", - "/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*", - "/usr/lib/x86_64-linux-gnu/libcudart.so*", - "/usr/lib/wsl/lib/libcudart.so*", - "/usr/lib/wsl/drivers/*/libcudart.so*", - "/opt/cuda/lib64/libcudart.so*", - "/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*", - "/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*", - "/usr/lib/aarch64-linux-gnu/libcudart.so*", - "/usr/local/cuda/lib*/libcudart.so*", - "/usr/lib*/libcudart.so*", - "/usr/local/lib*/libcudart.so*", -} - -var CudartWindowsGlobs = []string{ - "c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll", -} - -var NvcudaLinuxGlobs = []string{ - "/usr/local/cuda*/targets/*/lib/libcuda.so*", - "/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*", - "/usr/lib/*-linux-gnu/libcuda.so*", - "/usr/lib/wsl/lib/libcuda.so*", - "/usr/lib/wsl/drivers/*/libcuda.so*", - "/opt/cuda/lib*/libcuda.so*", - "/usr/local/cuda/lib*/libcuda.so*", - "/usr/lib*/libcuda.so*", - "/usr/local/lib*/libcuda.so*", -} - -var NvcudaWindowsGlobs = []string{ - "c:\\windows\\system*\\nvcuda.dll", -} - -var OneapiWindowsGlobs = []string{ - "c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll", -} - -var OneapiLinuxGlobs = []string{ - "/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*", - "/usr/lib*/libze_intel_gpu.so*", -} - -var VulkanLinuxGlobs = []string{ - "/usr/lib/x86_64-linux-gnu/libvulkan.so*", - "/usr/lib*/libvulkan.so*", -} - -var CapLinuxGlobs = []string{ - "/usr/lib/x86_64-linux-gnu/libcap.so*", - "/usr/lib*/libcap.so*", -} - // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed. // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. 
var CudaTegra string = os.Getenv("JETSON_JETPACK") // Note: gpuMutex must already be held -func initGPUHandles() *handles { +func initCudaHandles() *cudaHandles { // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing - gpuHandles := &handles{} - var cudartMgmtName string - var cudartMgmtPatterns []string - var nvcudaMgmtName string - var nvcudaMgmtPatterns []string - var vulkanMgmtName string - var vulkanMgmtPatterns []string - var libcapMgmtName string - var libcapMgmtPatterns []string - - tmpDir, _ := PayloadsDir() - switch runtime.GOOS { - case "windows": - cudartMgmtName = "cudart64_*.dll" - localAppData := os.Getenv("LOCALAPPDATA") - cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)} - cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...) - // Aligned with driver, we can't carry as payloads - nvcudaMgmtName = "nvcuda.dll" - nvcudaMgmtPatterns = NvcudaWindowsGlobs - case "linux": - cudartMgmtName = "libcudart.so*" - if tmpDir != "" { - // TODO - add "payloads" for subprocess - cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)} - } - cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...) - // Aligned with driver, we can't carry as payloads - nvcudaMgmtName = "libcuda.so*" - nvcudaMgmtPatterns = NvcudaLinuxGlobs - - // Vulkan also needs libcap - vulkanMgmtName = "libvulkan.so*" - vulkanMgmtPatterns = VulkanLinuxGlobs - libcapMgmtName = "libcap.so*" - libcapMgmtPatterns = CapLinuxGlobs - default: - return gpuHandles + cHandles := &cudaHandles{} + // Short Circuit if we already know which library to use + if nvmlLibPath != "" { + cHandles.nvml, _ = LoadNVMLMgmt([]string{nvmlLibPath}) + return cHandles + } + if nvcudaLibPath != "" { + cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath}) + return cHandles + } + if cudartLibPath != "" { + cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath}) + return cHandles } - slog.Debug("Detecting GPUs") - nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns) + slog.Debug("searching for GPU discovery libraries for NVIDIA") + var cudartMgmtPatterns []string + + // Aligned with driver, we can't carry as payloads + nvcudaMgmtPatterns := NvcudaGlobs + + if runtime.GOOS == "windows" { + localAppData := os.Getenv("LOCALAPPDATA") + cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", CudartMgmtName)} + } + tmpDir, _ := PayloadsDir() + if tmpDir != "" { + // TODO - add "payloads" for subprocess + cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", CudartMgmtName)} + } + cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...) 
+ + if len(NvmlGlobs) > 0 { + nvmlLibPaths := FindGPULibs(NvmlMgmtName, NvmlGlobs) + if len(nvmlLibPaths) > 0 { + nvml, libPath := LoadNVMLMgmt(nvmlLibPaths) + if nvml != nil { + slog.Debug("nvidia-ml loaded", "library", libPath) + cHandles.nvml = nvml + nvmlLibPath = libPath + } + } + } + + nvcudaLibPaths := FindGPULibs(NvcudaMgmtName, nvcudaMgmtPatterns) if len(nvcudaLibPaths) > 0 { deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths) if nvcuda != nil { slog.Debug("detected GPUs", "count", deviceCount, "library", libPath) - gpuHandles.nvcuda = nvcuda - gpuHandles.deviceCount = deviceCount - return gpuHandles + cHandles.nvcuda = nvcuda + cHandles.deviceCount = deviceCount + nvcudaLibPath = libPath + return cHandles } } - cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns) + cudartLibPaths := FindGPULibs(CudartMgmtName, cudartMgmtPatterns) if len(cudartLibPaths) > 0 { deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths) if cudart != nil { slog.Debug("detected GPUs", "library", libPath, "count", deviceCount) - gpuHandles.cudart = cudart - gpuHandles.deviceCount = deviceCount - return gpuHandles + cHandles.cudart = cudart + cHandles.deviceCount = deviceCount + cudartLibPath = libPath + return cHandles } } - vulkanLibPaths := FindGPULibs(vulkanMgmtName, vulkanMgmtPatterns) + return cHandles +} - var libcapLibPaths []string - if runtime.GOOS == "linux" { - libcapLibPaths = FindGPULibs(libcapMgmtName, libcapMgmtPatterns) +// Note: gpuMutex must already be held +func initOneAPIHandles() *oneapiHandles { + oHandles := &oneapiHandles{} + + // Short Circuit if we already know which library to use + if oneapiLibPath != "" { + oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath}) + return oHandles + } + + oneapiLibPaths := FindGPULibs(OneapiMgmtName, OneapiGlobs) + if len(oneapiLibPaths) > 0 { + oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths) + } + + return oHandles +} + +// Note: gpuMutex must already be held +func initVulkanHandles() *vulkanHandles { + vHandles := &vulkanHandles{} + + // Short Circuit if we already know which library to use + if vulkanLibPath != "" && libcapLibPath != "" { + vHandles.deviceCount, vHandles.vulkan, _, _ = LoadVulkanMgmt([]string{vulkanLibPath}, []string{libcapLibPath}) + return vHandles + } + + vulkanPaths := FindGPULibs(VulkanMgmtName, VulkanGlobs) + libcapPaths := FindLibCapLibs() + + if len(vulkanPaths) > 0 && len(libcapPaths) > 0 { + vHandles.deviceCount, vHandles.vulkan, vulkanLibPath, libcapLibPath = LoadVulkanMgmt(vulkanPaths, libcapPaths) + } + + return vHandles +} + +func GetCPUInfo() GpuInfoList { + gpuMutex.Lock() + if !bootstrapped { + gpuMutex.Unlock() + GetGPUInfo() } else { - libcapLibPaths = []string{""} + gpuMutex.Unlock() } - - if len(vulkanLibPaths) > 0 && len(libcapLibPaths) > 0 { - deviceCount, vulkan, vkLibPath, capLibPath := LoadVulkanMgmt(vulkanLibPaths, libcapLibPaths) - if vulkan != nil { - slog.Debug("detected GPUs", "library", vkLibPath, capLibPath, "count", deviceCount) - gpuHandles.vulkan = vulkan - gpuHandles.deviceCount = deviceCount - return gpuHandles - } - } - - return gpuHandles + return GpuInfoList{cpus[0].GpuInfo} } func GetGPUInfo() GpuInfoList { @@ -200,141 +206,300 @@ func GetGPUInfo() GpuInfoList { // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries gpuMutex.Lock() defer gpuMutex.Unlock() - - gpuHandles := initGPUHandles() + needRefresh := true + var cHandles *cudaHandles + var oHandles 
*oneapiHandles + var vHandles *vulkanHandles defer func() { - if gpuHandles.cudart != nil { - C.cudart_release(*gpuHandles.cudart) + if cHandles != nil { + if cHandles.cudart != nil { + C.cudart_release(*cHandles.cudart) + } + if cHandles.nvcuda != nil { + C.nvcuda_release(*cHandles.nvcuda) + } + if cHandles.nvml != nil { + C.nvml_release(*cHandles.nvml) + } } - if gpuHandles.nvcuda != nil { - C.nvcuda_release(*gpuHandles.nvcuda) - } - if gpuHandles.vulkan != nil { - C.vk_release(*gpuHandles.vulkan) + if oHandles != nil { + if oHandles.oneapi != nil { + // TODO - is this needed? + C.oneapi_release(*oHandles.oneapi) + } } }() - // All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX - cpuVariant := GetCPUVariant() - if cpuVariant == "" && runtime.GOARCH == "amd64" { - slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.") - } + if !bootstrapped { + slog.Debug("Detecting GPUs") + needRefresh = false + cpuCapability = GetCPUCapability() + var memInfo C.mem_info_t - // On windows we bundle the nvidia library one level above the runner dir - depPath := "" - if runtime.GOOS == "windows" && envconfig.RunnersDir != "" { - depPath = filepath.Dir(envconfig.RunnersDir) - } - - var memInfo C.mem_info_t - resp := []GpuInfo{} - - // NVIDIA and Vulkan first - for i := range gpuHandles.deviceCount { - // TODO once we support CPU compilation variants of GPU libraries refine this... - if cpuVariant == "" && runtime.GOARCH == "amd64" { - continue + mem, err := GetCPUMem() + if err != nil { + slog.Warn("error looking up system memory", "error", err) } - if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil { - gpuInfo := GpuInfo{ - Library: "cuda", + cpus = []CPUInfo{CPUInfo{ + GpuInfo: GpuInfo{ + memInfo: mem, + Library: "cpu", + Variant: cpuCapability, + ID: "0", + }, + }} + + // Fallback to CPU mode if we're lacking required vector extensions on x86 + if cpuCapability < GPURunnerCPUCapability && runtime.GOARCH == "amd64" { + slog.Warn("CPU does not have minimum vector extensions, GPU inference disabled", "required", GPURunnerCPUCapability, "detected", cpuCapability) + bootstrapped = true + // No need to do any GPU discovery, since we can't run on them + return GpuInfoList{cpus[0].GpuInfo} + } + + // On windows we bundle the nvidia library one level above the runner dir + depPath := "" + if runtime.GOOS == "windows" && envconfig.RunnersDir != "" { + depPath = filepath.Dir(envconfig.RunnersDir) + } + + // Load ALL libraries + cHandles = initCudaHandles() + + // NVIDIA + for i := range cHandles.deviceCount { + if cHandles.cudart != nil || cHandles.nvcuda != nil { + gpuInfo := CudaGPUInfo{ + GpuInfo: GpuInfo{ + Library: "cuda", + }, + index: i, + } + var driverMajor int + var driverMinor int + if cHandles.cudart != nil { + C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo) + } else { + C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo) + driverMajor = int(cHandles.nvcuda.driver_major) + driverMinor = int(cHandles.nvcuda.driver_minor) + } + if memInfo.err != nil { + slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) + C.free(unsafe.Pointer(memInfo.err)) + continue + } + if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) { + slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. 
Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor)) + continue + } + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) + gpuInfo.MinimumMemory = cudaMinimumMemory + gpuInfo.DependencyPath = depPath + gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) + gpuInfo.DriverMajor = driverMajor + gpuInfo.DriverMinor = driverMinor + + // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... + cudaGPUs = append(cudaGPUs, gpuInfo) } - var driverMajor int - var driverMinor int - if gpuHandles.cudart != nil { - C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo) + } + + // Intel + oHandles = initOneAPIHandles() + for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ { + if oHandles.oneapi == nil { + // shouldn't happen + slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) + continue + } + devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) + for i := range devCount { + gpuInfo := OneapiGPUInfo{ + GpuInfo: GpuInfo{ + Library: "oneapi", + }, + driverIndex: d, + gpuIndex: int(i), + } + // TODO - split bootstrapping from updating free memory + C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) + // TODO - convert this to MinimumMemory based on testing... + var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. + memInfo.free = C.uint64_t(totalFreeMem) + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) + // TODO dependency path? + oneapiGPUs = append(oneapiGPUs, gpuInfo) + } + } + + // Vulkan + vHandles = initVulkanHandles() + for i := range vHandles.deviceCount { + if vHandles.vulkan != nil { + gpuInfo := VulkanGPUInfo{ + GpuInfo: GpuInfo{ + Library: "vulkan", + }, + index: i, + } + + C.vk_check_vram(*vHandles.vulkan, C.int(i), &memInfo) + if memInfo.err != nil { + slog.Info("error looking up vulkan GPU memory", "error", C.GoString(memInfo.err)) + C.free(unsafe.Pointer(memInfo.err)) + continue + } + + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) + gpuInfo.MinimumMemory = 0 + gpuInfo.DependencyPath = depPath + gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) + gpuInfo.DriverMajor = int(memInfo.major) + gpuInfo.DriverMinor = int(memInfo.minor) + + // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... 
+ vulkanGPUs = append(vulkanGPUs, gpuInfo) + } + } + + rocmGPUs = AMDGetGPUInfo() + bootstrapped = true + } + + // For detected GPUs, load library if not loaded + + // Refresh free memory usage + if needRefresh { + mem, err := GetCPUMem() + if err != nil { + slog.Warn("error looking up system memory", "error", err) + } else { + slog.Debug("updating system memory data", + slog.Group( + "before", + "total", format.HumanBytes2(cpus[0].TotalMemory), + "free", format.HumanBytes2(cpus[0].FreeMemory), + ), + slog.Group( + "now", + "total", format.HumanBytes2(mem.TotalMemory), + "free", format.HumanBytes2(mem.FreeMemory), + ), + ) + cpus[0].FreeMemory = mem.FreeMemory + } + + var memInfo C.mem_info_t + if cHandles == nil && len(cudaGPUs) > 0 { + cHandles = initCudaHandles() + } + for i, gpu := range cudaGPUs { + if cHandles.nvml != nil { + C.nvml_get_free(*cHandles.nvml, C.int(gpu.index), &memInfo.free, &memInfo.total, &memInfo.used) + } else if cHandles.cudart != nil { + C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo) + } else if cHandles.nvcuda != nil { + C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free, &memInfo.total) + memInfo.used = memInfo.total - memInfo.free } else { - C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo) - driverMajor = int(gpuHandles.nvcuda.driver_major) - driverMinor = int(gpuHandles.nvcuda.driver_minor) + // shouldn't happen + slog.Warn("no valid cuda library loaded to refresh vram usage") + break } if memInfo.err != nil { - slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) + slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) C.free(unsafe.Pointer(memInfo.err)) continue } - if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) { - slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor)) + if memInfo.free == 0 { + slog.Warn("error looking up nvidia GPU memory") continue } - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) - gpuInfo.MinimumMemory = cudaMinimumMemory - gpuInfo.DependencyPath = depPath - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.DriverMajor = driverMajor - gpuInfo.DriverMinor = driverMinor - - // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... 
- resp = append(resp, gpuInfo) + slog.Debug("updating cuda memory data", + "gpu", gpu.ID, + "name", gpu.Name, + slog.Group( + "before", + "total", format.HumanBytes2(gpu.TotalMemory), + "free", format.HumanBytes2(gpu.FreeMemory), + ), + slog.Group( + "now", + "total", format.HumanBytes2(uint64(memInfo.total)), + "free", format.HumanBytes2(uint64(memInfo.free)), + "used", format.HumanBytes2(uint64(memInfo.used)), + ), + ) + cudaGPUs[i].FreeMemory = uint64(memInfo.free) } - if gpuHandles.vulkan != nil { - gpuInfo := GpuInfo{ - Library: "vulkan", - } - - C.vk_check_vram(*gpuHandles.vulkan, C.int(i), &memInfo) - if memInfo.err != nil { - slog.Info("error looking up vulkan GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) + if oHandles == nil && len(oneapiGPUs) > 0 { + oHandles = initOneAPIHandles() + } + for i, gpu := range oneapiGPUs { + if oHandles.oneapi == nil { + // shouldn't happen + slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount) continue } + C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo) + // TODO - convert this to MinimumMemory based on testing... + var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. + memInfo.free = C.uint64_t(totalFreeMem) + oneapiGPUs[i].FreeMemory = uint64(memInfo.free) + } - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) - gpuInfo.MinimumMemory = 0 - gpuInfo.DependencyPath = depPath - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.DriverMajor = int(memInfo.major) - gpuInfo.DriverMinor = int(memInfo.minor) + if vHandles == nil && len(vulkanGPUs) > 0 { + vHandles = initVulkanHandles() + } + for i, gpu := range vulkanGPUs { + if vHandles.vulkan == nil { + // shouldn't happen + slog.Warn("nil vulkan handle with device count", "count", oHandles.deviceCount) + continue + } + C.vk_check_vram(*vHandles.vulkan, C.int(gpu.index), &memInfo) + // TODO - convert this to MinimumMemory based on testing... + var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. + memInfo.free = C.uint64_t(totalFreeMem) + vulkanGPUs[i].FreeMemory = uint64(memInfo.free) + } - // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... - resp = append(resp, gpuInfo) + err = RocmGPUInfoList(rocmGPUs).RefreshFreeMemory() + if err != nil { + slog.Debug("problem refreshing ROCm free memory", "error", err) } } - // Then AMD - resp = append(resp, AMDGetGPUInfo()...) 
- + resp := []GpuInfo{} + for _, gpu := range cudaGPUs { + resp = append(resp, gpu.GpuInfo) + } + for _, gpu := range rocmGPUs { + resp = append(resp, gpu.GpuInfo) + } + for _, gpu := range oneapiGPUs { + resp = append(resp, gpu.GpuInfo) + } + for _, gpu := range vulkanGPUs { + resp = append(resp, gpu.GpuInfo) + } if len(resp) == 0 { - C.cpu_check_ram(&memInfo) - if memInfo.err != nil { - slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - return resp - } - gpuInfo := GpuInfo{ - Library: "cpu", - Variant: cpuVariant, - } - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - - resp = append(resp, gpuInfo) + resp = append(resp, cpus[0].GpuInfo) } - return resp } -func GetCPUMem() (memInfo, error) { - var ret memInfo - var info C.mem_info_t - C.cpu_check_ram(&info) - if info.err != nil { - defer C.free(unsafe.Pointer(info.err)) - return ret, fmt.Errorf(C.GoString(info.err)) - } - ret.FreeMemory = uint64(info.free) - ret.TotalMemory = uint64(info.total) - return ret, nil -} - func FindGPULibs(baseLibName string, defaultPatterns []string) []string { // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them var ldPaths []string @@ -431,8 +596,26 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) { return 0, nil, "" } +func LoadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string) { + var resp C.nvml_init_resp_t + resp.ch.verbose = getVerboseState() + for _, libPath := range nvmlLibPaths { + lib := C.CString(libPath) + defer C.free(unsafe.Pointer(lib)) + C.nvml_init(lib, &resp) + if resp.err != nil { + slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err))) + C.free(unsafe.Pointer(resp.err)) + } else { + return &resp.ch, libPath + } + } + return nil, "" +} + func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) { var resp C.oneapi_init_resp_t + num_devices := 0 resp.oh.verbose = getVerboseState() for _, libPath := range oneapiLibPaths { lib := C.CString(libPath) @@ -442,7 +625,10 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) { slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err)) C.free(unsafe.Pointer(resp.err)) } else { - return int(resp.num_devices), &resp.oh, libPath + for i := range resp.oh.num_drivers { + num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i))) + } + return num_devices, &resp.oh, libPath } } return 0, nil, "" diff --git a/gpu/gpu_linux.go b/gpu/gpu_linux.go index a099bf822..2e723c4da 100644 --- a/gpu/gpu_linux.go +++ b/gpu/gpu_linux.go @@ -43,10 +43,26 @@ var OneapiGlobs = []string{ "/usr/lib*/libze_intel_gpu.so*", } +var VulkanGlobs = []string{ + "/usr/lib/x86_64-linux-gnu/libvulkan.so*", + "/usr/lib*/libvulkan.so*", +} + +var capLinuxGlobs = []string{ + "/usr/lib/x86_64-linux-gnu/libcap.so*", + "/usr/lib*/libcap.so*", +} + var CudartMgmtName = "libcudart.so*" var NvcudaMgmtName = "libcuda.so*" var NvmlMgmtName = "" // not currently wired on linux var OneapiMgmtName = "libze_intel_gpu.so" +var VulkanMgmtName = "libvulkan.so*" +var libcapMgmtName = "libcap.so*" + +func FindLibCapLibs() []string { + return FindGPULibs(libcapMgmtName, capLinuxGlobs) +} func GetCPUMem() (memInfo, error) { var mem memInfo diff --git a/gpu/gpu_windows.go b/gpu/gpu_windows.go index f8c2e76fe..328477440 100644 --- 
a/gpu/gpu_windows.go +++ b/gpu/gpu_windows.go @@ -45,6 +45,10 @@ var NvcudaMgmtName = "nvcuda.dll" var NvmlMgmtName = "nvml.dll" var OneapiMgmtName = "ze_intel_gpu64.dll" +func FindLibCapLibs() []string { + return []string{""} +} + func GetCPUMem() (memInfo, error) { memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx} r1, _, err := globalMemoryStatusExProc.Call(uintptr(unsafe.Pointer(&memStatus))) diff --git a/gpu/types.go b/gpu/types.go index 47355959c..b451c0f38 100644 --- a/gpu/types.go +++ b/gpu/types.go @@ -62,6 +62,13 @@ type OneapiGPUInfo struct { } type OneapiGPUInfoList []OneapiGPUInfo +type VulkanGPUInfo struct { + GpuInfo + index int +} + +type VulkanGPUInfoList []VulkanGPUInfo + type GpuInfoList []GpuInfo // Split up the set of gpu info's by Library and variant From 38466f18213ae6f6879e5639d315e8c8e377b602 Mon Sep 17 00:00:00 2001 From: pufferffish Date: Sat, 15 Jun 2024 12:06:43 +0100 Subject: [PATCH 010/172] fix build --- gpu/gpu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpu/gpu.go b/gpu/gpu.go index 11c72e151..4eb82f2cf 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -37,7 +37,7 @@ type oneapiHandles struct { } type vulkanHandles struct { - vulkan *C.vulkan_handle_t + vulkan *C.vk_handle_t deviceCount int } From e3f9ca4009afe2620b27b61b3e0f37053b1d4354 Mon Sep 17 00:00:00 2001 From: KOISHI KOMEIJI FROM TOUHOU 11 Date: Sat, 15 Jun 2024 20:13:15 +0800 Subject: [PATCH 011/172] fix check_perfmon len --- gpu/gpu_info_vulkan.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gpu/gpu_info_vulkan.c b/gpu/gpu_info_vulkan.c index fbe7a5885..17ee43003 100644 --- a/gpu/gpu_info_vulkan.c +++ b/gpu/gpu_info_vulkan.c @@ -5,7 +5,7 @@ int check_perfmon(vk_handle_t* rh) { #ifdef __linux__ cap_t caps; - const cap_value_t cap_list[2] = {CAP_PERFMON}; + const cap_value_t cap_list[1] = {CAP_PERFMON}; if ((*rh->cap_get_bound)(CAP_SETFCAP) < 0) return -1; @@ -14,7 +14,7 @@ int check_perfmon(vk_handle_t* rh) { if (caps == NULL) return -1; - if ((*rh->cap_set_flag)(caps, CAP_EFFECTIVE, 2, cap_list, CAP_SET) == -1) + if ((*rh->cap_set_flag)(caps, CAP_EFFECTIVE, 1, cap_list, CAP_SET) == -1) return -1; if ((*rh->cap_set_proc)(caps) == -1) @@ -221,4 +221,4 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { void vk_release(vk_handle_t rh) { (*rh.vkDestroyInstance)(rh.vk, NULL); -} \ No newline at end of file +} From b958cd2848773e1e37fe0cd1e000aa0ee65f5fff Mon Sep 17 00:00:00 2001 From: DSLstandard Date: Sat, 15 Jun 2024 20:19:19 +0800 Subject: [PATCH 012/172] remove cap_get_bound check --- gpu/gpu_info_vulkan.c | 3 --- 1 file changed, 3 deletions(-) diff --git a/gpu/gpu_info_vulkan.c b/gpu/gpu_info_vulkan.c index 17ee43003..c4cdaa543 100644 --- a/gpu/gpu_info_vulkan.c +++ b/gpu/gpu_info_vulkan.c @@ -7,9 +7,6 @@ int check_perfmon(vk_handle_t* rh) { cap_t caps; const cap_value_t cap_list[1] = {CAP_PERFMON}; - if ((*rh->cap_get_bound)(CAP_SETFCAP) < 0) - return -1; - caps = (*rh->cap_get_proc)(); if (caps == NULL) return -1; From b6554e9b8c6502e2dcebba17bab75bd5235adfff Mon Sep 17 00:00:00 2001 From: pufferffish Date: Sat, 15 Jun 2024 21:11:07 +0100 Subject: [PATCH 013/172] fix vulkan handle releasing --- gpu/gpu.go | 5 +++++ gpu/gpu_info_vulkan.c | 9 ++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/gpu/gpu.go b/gpu/gpu.go index 4eb82f2cf..6cebbd2b9 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -228,6 +228,11 @@ func GetGPUInfo() GpuInfoList { C.oneapi_release(*oHandles.oneapi) } } + if vHandles != nil { + if 
vHandles.vulkan != nil { + C.vk_release(*vHandles.vulkan) + } + } }() if !bootstrapped { diff --git a/gpu/gpu_info_vulkan.c b/gpu/gpu_info_vulkan.c index c4cdaa543..e868dcc1b 100644 --- a/gpu/gpu_info_vulkan.c +++ b/gpu/gpu_info_vulkan.c @@ -213,9 +213,16 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); resp->minor = VK_API_VERSION_MINOR(properties.apiVersion); resp->patch = VK_API_VERSION_PATCH(properties.apiVersion); - } void vk_release(vk_handle_t rh) { + LOG(rh.verbose, "releasing vulkan library\n"); (*rh.vkDestroyInstance)(rh.vk, NULL); + UNLOAD_LIBRARY(rh.vk_handle); + rh.vk_handle = NULL; +#ifdef __linux__ + LOG(rh.verbose, "releasing libcap library\n"); + UNLOAD_LIBRARY(rh.cap_handle); + rh.cap_handle = NULL; +#endif } From ace3d104683748e627b284c21bdfe387536f3b59 Mon Sep 17 00:00:00 2001 From: pufferffish Date: Mon, 23 Sep 2024 18:38:42 +0800 Subject: [PATCH 014/172] fix build on federa 40 --- llm/generate/gen_linux.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llm/generate/gen_linux.sh b/llm/generate/gen_linux.sh index 7e5f531e1..79c449018 100755 --- a/llm/generate/gen_linux.sh +++ b/llm/generate/gen_linux.sh @@ -234,8 +234,8 @@ if [ -z "${OLLAMA_SKIP_VULKAN_GENERATE}" -a -d "${VULKAN_ROOT}" ] && [ -z "${OLL for dep in $(ldd "${BUILD_DIR}/bin/ollama_llama_server" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e vulkan -e cap); do cp "${dep}" "${BUILD_DIR}/bin/" done - cp "${VULKAN_ROOT}/libvulkan.so" "${BUILD_DIR}/bin/" - cp "${CAP_ROOT}/libcap.so" "${BUILD_DIR}/bin/" + cp "${VULKAN_ROOT}/libvulkan.so*" "${BUILD_DIR}/bin/" + cp "${CAP_ROOT}/libcap.so*" "${BUILD_DIR}/bin/" compress fi From e61c32943556142a646e5812654768d9c2ba29c8 Mon Sep 17 00:00:00 2001 From: pufferffish Date: Mon, 23 Sep 2024 18:43:16 +0800 Subject: [PATCH 015/172] fix vulkan on windows --- gpu/gpu_windows.go | 5 +++++ llm/generate/gen_windows.ps1 | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/gpu/gpu_windows.go b/gpu/gpu_windows.go index 5491da963..f9afb6ed8 100644 --- a/gpu/gpu_windows.go +++ b/gpu/gpu_windows.go @@ -40,11 +40,16 @@ var OneapiGlobs = []string{ "c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll", } +var VulkanGlobs = []string{ + "c:\\Windows\\System32\\vulkan-1.dll", +} + var ( CudartMgmtName = "cudart64_*.dll" NvcudaMgmtName = "nvcuda.dll" NvmlMgmtName = "nvml.dll" OneapiMgmtName = "ze_intel_gpu64.dll" + VulkanMgmtName = "vulkan-1.dll" ) func FindLibCapLibs() []string { diff --git a/llm/generate/gen_windows.ps1 b/llm/generate/gen_windows.ps1 index 29ff5ff62..bb92c5121 100644 --- a/llm/generate/gen_windows.ps1 +++ b/llm/generate/gen_windows.ps1 @@ -412,6 +412,21 @@ function build_rocm() { } } +function build_vulkan() { + if (-not "${env:OLLAMA_SKIP_VULKAN_GENERATE}") { + init_vars + $script:buildDir="../build/windows/${script:ARCH}/vulkan" + $script:distDir="$script:DIST_BASE\vulkan" + $script:cmakeDefs += @("-A", "x64", "-DLLAMA_VULKAN=1") + write-host "Building Vulkan" + build + sign + install + } else { + write-host "Skipping Vulkan generation step" + } +} + init_vars if ($($args.count) -eq 0) { git_module_setup @@ -426,6 +441,7 @@ if ($($args.count) -eq 0) { build_cuda build_oneapi build_rocm + build_vulkan } cleanup From 4b74cee096f9bed2d5b64a983575c9d9d4c6fe7d Mon Sep 17 00:00:00 2001 From: yeongbba Date: Sun, 19 Jan 2025 01:30:34 +0900 Subject: [PATCH 016/172] making amdgpu work on arm achitecutre with vulkan --- envconfig/config.go | 2 ++ 
gpu/gpu.go | 2 +- gpu/gpu_linux.go | 4 ++-- llm/generate/gen_common.sh | 4 ++-- llm/generate/gen_linux.sh | 12 +++++++----- llm/generate/gen_windows.ps1 | 2 +- 6 files changed, 15 insertions(+), 11 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index 9c1490a93..239c49fe6 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -175,6 +175,7 @@ var ( CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES") HipVisibleDevices = String("HIP_VISIBLE_DEVICES") RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES") + VkVisibleDevices = String("GGML_VK_VISIBLE_DEVICES") GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL") HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION") ) @@ -263,6 +264,7 @@ func AsMap() map[string]EnvVar { ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible"} ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible"} + ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which VK AMD devices are visible"} ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible"} ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"} ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"} diff --git a/gpu/gpu.go b/gpu/gpu.go index 69279867e..1bd337f19 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -410,7 +410,7 @@ func GetGPUInfo() GpuInfoList { rocmGPUs = AMDGetGPUInfo() bootstrapped = true - if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 { + if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 && len(vulkanGPUs) == 0 { slog.Info("no compatible GPUs were discovered") } } diff --git a/gpu/gpu_linux.go b/gpu/gpu_linux.go index 76df63268..d6f882efb 100644 --- a/gpu/gpu_linux.go +++ b/gpu/gpu_linux.go @@ -53,12 +53,12 @@ var ( ) var VulkanGlobs = []string{ - "/usr/lib/x86_64-linux-gnu/libvulkan.so*", + "/usr/lib/aarch64-linux-gnu/libvulkan.so*", "/usr/lib*/libvulkan.so*", } var capLinuxGlobs = []string{ - "/usr/lib/x86_64-linux-gnu/libcap.so*", + "/usr/lib/aarch64-linux-gnu/libcap.so*", "/usr/lib*/libcap.so*", } diff --git a/llm/generate/gen_common.sh b/llm/generate/gen_common.sh index 3825c155a..2b01e149c 100644 --- a/llm/generate/gen_common.sh +++ b/llm/generate/gen_common.sh @@ -30,7 +30,7 @@ init_vars() { WHOLE_ARCHIVE="-Wl,-force_load" NO_WHOLE_ARCHIVE="" GCC_ARCH="-arch ${ARCH}" - DIST_BASE=../../dist/darwin-${GOARCH}/ + DIST_BASE=../../dist/darwin-${GOARCH} PAYLOAD_BASE=../../build/darwin/${GOARCH} ;; "Linux") @@ -40,7 +40,7 @@ init_vars() { # Cross compiling not supported on linux - Use docker GCC_ARCH="" - DIST_BASE=../../dist/linux-${GOARCH}/ + DIST_BASE=../../dist/linux-${GOARCH} PAYLOAD_BASE=../../build/linux/${GOARCH} ;; *) diff --git a/llm/generate/gen_linux.sh b/llm/generate/gen_linux.sh index 79c449018..17981d543 100755 --- a/llm/generate/gen_linux.sh +++ b/llm/generate/gen_linux.sh @@ -224,9 +224,9 @@ fi if [ -z "${OLLAMA_SKIP_VULKAN_GENERATE}" -a -d "${VULKAN_ROOT}" ] && [ -z "${OLLAMA_SKIP_VULKAN_GENERATE}" -a -d "${CAP_ROOT}" ]; then echo "Vulkan and capabilities libraries detected - building dynamic Vulkan library" init_vars - - 
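The envconfig change above only registers GGML_VK_VISIBLE_DEVICES; the filtering itself presumably happens downstream in the ggml Vulkan backend, which the variable name follows by convention. As a hedged illustration of what such a filter usually amounts to (hypothetical helper, not ggml code), a comma-separated list of device indices is parsed and any enumerated device outside the set is skipped:

    #include <cstdio>
    #include <cstdlib>
    #include <sstream>
    #include <string>
    #include <unordered_set>

    // Hypothetical helper: turn "0,2" into {0, 2}. An unset or empty
    // variable is treated as "no filter".
    static std::unordered_set<size_t> parse_visible_devices(const char * env) {
        std::unordered_set<size_t> allowed;
        if (env == nullptr || *env == '\0') {
            return allowed;
        }
        std::stringstream ss(env);
        std::string tok;
        while (std::getline(ss, tok, ',')) {
            if (!tok.empty()) {
                allowed.insert(std::stoul(tok));
            }
        }
        return allowed;
    }

    int main() {
        const auto allowed = parse_visible_devices(std::getenv("GGML_VK_VISIBLE_DEVICES"));
        for (size_t idx = 0; idx < 4; ++idx) {   // pretend 4 Vulkan devices were enumerated
            const bool visible = allowed.empty() || allowed.count(idx) > 0;
            std::printf("device %zu: %s\n", idx, visible ? "visible" : "filtered out");
        }
        return 0;
    }

Running the sketch with GGML_VK_VISIBLE_DEVICES=0,2 would report devices 1 and 3 as filtered out.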
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_VULKAN=1" - BUILD_DIR="../build/linux/${ARCH}/vulkan" + RUNNER=vulkan + CMAKE_DEFS="-DCMAKE_SKIP_RPATH=on -DBUILD_SHARED_LIBS=on -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off" + BUILD_DIR="../build/linux/${ARCH}/${RUNNER}" EXTRA_LIBS="-L${VULKAN_ROOT} -L${CAP_ROOT} -lvulkan -lcap" build @@ -234,8 +234,10 @@ if [ -z "${OLLAMA_SKIP_VULKAN_GENERATE}" -a -d "${VULKAN_ROOT}" ] && [ -z "${OLL for dep in $(ldd "${BUILD_DIR}/bin/ollama_llama_server" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e vulkan -e cap); do cp "${dep}" "${BUILD_DIR}/bin/" done - cp "${VULKAN_ROOT}/libvulkan.so*" "${BUILD_DIR}/bin/" - cp "${CAP_ROOT}/libcap.so*" "${BUILD_DIR}/bin/" + cp ${VULKAN_ROOT}/libvulkan.so* "${BUILD_DIR}/bin/" + cp ${CAP_ROOT}/libcap.so* "${BUILD_DIR}/bin/" + install + dist compress fi diff --git a/llm/generate/gen_windows.ps1 b/llm/generate/gen_windows.ps1 index bb92c5121..e8bba29a5 100644 --- a/llm/generate/gen_windows.ps1 +++ b/llm/generate/gen_windows.ps1 @@ -417,7 +417,7 @@ function build_vulkan() { init_vars $script:buildDir="../build/windows/${script:ARCH}/vulkan" $script:distDir="$script:DIST_BASE\vulkan" - $script:cmakeDefs += @("-A", "x64", "-DLLAMA_VULKAN=1") + $script:cmakeDefs += @("-A", "x64", "-DDGGML_VULKAN=1") write-host "Building Vulkan" build sign From 6d7579b567cdb03154521e43103e9113a8cbd336 Mon Sep 17 00:00:00 2001 From: yeongbba Date: Sun, 19 Jan 2025 12:41:08 +0900 Subject: [PATCH 017/172] add x86_64 lines in VulkanGlobs and capLinuxGlobs --- gpu/gpu_linux.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gpu/gpu_linux.go b/gpu/gpu_linux.go index d6f882efb..1251c6e8e 100644 --- a/gpu/gpu_linux.go +++ b/gpu/gpu_linux.go @@ -53,11 +53,13 @@ var ( ) var VulkanGlobs = []string{ + "/usr/lib/x86_64-linux-gnu/libvulkan.so*", "/usr/lib/aarch64-linux-gnu/libvulkan.so*", "/usr/lib*/libvulkan.so*", } var capLinuxGlobs = []string{ + "/usr/lib/x86_64-linux-gnu/libvulkan.so*", "/usr/lib/aarch64-linux-gnu/libcap.so*", "/usr/lib*/libcap.so*", } From 2bf59a512b938739f3bb8c3cdd034d817692ed45 Mon Sep 17 00:00:00 2001 From: yeongbba Date: Sun, 19 Jan 2025 12:51:10 +0900 Subject: [PATCH 018/172] add aarch64 lines in vulkanGlobs and capLinuxGlobs --- discover/gpu_linux.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/discover/gpu_linux.go b/discover/gpu_linux.go index 840ea435a..a2c1d8715 100644 --- a/discover/gpu_linux.go +++ b/discover/gpu_linux.go @@ -58,11 +58,13 @@ var ( var VulkanGlobs = []string{ "/usr/lib/x86_64-linux-gnu/libvulkan.so*", + "/usr/lib/aarch64-linux-gnu/libvulkan.so*", "/usr/lib*/libvulkan.so*", } var capLinuxGlobs = []string{ "/usr/lib/x86_64-linux-gnu/libcap.so*", + "/usr/lib/aarch64-linux-gnu/libvulkan.so*", "/usr/lib*/libcap.so*", } From 0d277d32db8a56aa42e2fdeea2c81202a364c8da Mon Sep 17 00:00:00 2001 From: tomaThomas Date: Sat, 25 Jan 2025 11:23:25 +0100 Subject: [PATCH 019/172] Fix variable name --- discover/gpu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index 15d0b99a4..1b6a3f075 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -440,7 +440,7 @@ func GetGPUInfo() GpuInfoList { gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) gpuInfo.MinimumMemory = 0 - gpuInfo.DependencyPath = depPath + gpuInfo.DependencyPath = depPaths gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) 
gpuInfo.DriverMajor = int(memInfo.major) gpuInfo.DriverMinor = int(memInfo.minor) From 2d443b3dd660a1fd2760d64538512df93648b4bb Mon Sep 17 00:00:00 2001 From: pufferffish Date: Mon, 3 Feb 2025 14:46:59 +0000 Subject: [PATCH 020/172] Add vulkan build patch from @jmorganca --- CMakeLists.txt | 13 +++++++++++++ ml/backend/ggml/ggml/.rsync-filter | 3 +++ 2 files changed, 16 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 19d9bd8f9..05f8e2c47 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -110,3 +110,16 @@ if(CMAKE_HIP_COMPILER) endforeach() endif() endif() + +find_package(Vulkan) +if(Vulkan_FOUND) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) + set(OLLAMA_VULKAN_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/vulkan) + install(TARGETS ggml-vulkan + RUNTIME_DEPENDENCIES + PRE_INCLUDE_REGEXES vulkan + PRE_EXCLUDE_REGEXES ".*" + RUNTIME DESTINATION ${OLLAMA_VULKAN_INSTALL_DIR} COMPONENT Vulkan + LIBRARY DESTINATION ${OLLAMA_VULKAN_INSTALL_DIR} COMPONENT Vulkan + ) +endif() diff --git a/ml/backend/ggml/ggml/.rsync-filter b/ml/backend/ggml/ggml/.rsync-filter index c5acbe490..09d67f270 100644 --- a/ml/backend/ggml/ggml/.rsync-filter +++ b/ml/backend/ggml/ggml/.rsync-filter @@ -12,6 +12,8 @@ include src/ggml-cuda/ include src/ggml-cuda/template-instances/ include src/ggml-hip/ include src/ggml-metal/ +include src/ggml-vulkan/ +include src/ggml-vulkan/vulkan-shaders include *.c include *.h include *.cpp @@ -19,4 +21,5 @@ include *.cu include *.cuh include *.m include *.metal +include *.comp exclude * From 449e5c07aeb9c034f530869a55f0fe3f44ee88dc Mon Sep 17 00:00:00 2001 From: Antoine Viallon Date: Tue, 4 Feb 2025 11:51:17 +0100 Subject: [PATCH 021/172] Sync vendored ggml to add Vulkan support --- Makefile.sync | 2 +- .../ggml/ggml/src/ggml-vulkan/CMakeLists.txt | 92 + .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8745 +++++++++++++++++ .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 9 + .../src/ggml-vulkan/vulkan-shaders/acc.comp | 29 + .../src/ggml-vulkan/vulkan-shaders/add.comp | 29 + .../ggml-vulkan/vulkan-shaders/argsort.comp | 69 + .../src/ggml-vulkan/vulkan-shaders/clamp.comp | 17 + .../ggml-vulkan/vulkan-shaders/concat.comp | 41 + .../vulkan-shaders/contig_copy.comp | 42 + .../src/ggml-vulkan/vulkan-shaders/copy.comp | 20 + .../src/ggml-vulkan/vulkan-shaders/cos.comp | 17 + .../vulkan-shaders/dequant_f32.comp | 20 + .../vulkan-shaders/dequant_funcs.comp | 118 + .../vulkan-shaders/dequant_funcs_cm2.comp | 325 + .../vulkan-shaders/dequant_head.comp | 13 + .../vulkan-shaders/dequant_iq4_nl.comp | 32 + .../vulkan-shaders/dequant_q2_k.comp | 34 + .../vulkan-shaders/dequant_q3_k.comp | 42 + .../vulkan-shaders/dequant_q4_0.comp | 30 + .../vulkan-shaders/dequant_q4_1.comp | 32 + .../vulkan-shaders/dequant_q4_k.comp | 68 + .../vulkan-shaders/dequant_q5_0.comp | 34 + .../vulkan-shaders/dequant_q5_1.comp | 35 + .../vulkan-shaders/dequant_q5_k.comp | 70 + .../vulkan-shaders/dequant_q6_k.comp | 33 + .../vulkan-shaders/dequant_q8_0.comp | 31 + .../vulkan-shaders/diag_mask_inf.comp | 34 + .../src/ggml-vulkan/vulkan-shaders/div.comp | 27 + .../vulkan-shaders/flash_attn_cm2.comp | 289 + .../src/ggml-vulkan/vulkan-shaders/gelu.comp | 25 + .../vulkan-shaders/gelu_quick.comp | 23 + .../vulkan-shaders/generic_binary_head.comp | 64 + .../vulkan-shaders/generic_head.comp | 9 + .../vulkan-shaders/generic_unary_head.comp | 56 + .../ggml-vulkan/vulkan-shaders/get_rows.comp | 28 + .../vulkan-shaders/get_rows_quant.comp | 39 + .../vulkan-shaders/group_norm.comp | 66 + 
.../ggml-vulkan/vulkan-shaders/im2col.comp | 87 + .../vulkan-shaders/leaky_relu.comp | 22 + .../src/ggml-vulkan/vulkan-shaders/mul.comp | 27 + .../mul_mat_split_k_reduce.comp | 48 + .../vulkan-shaders/mul_mat_vec.comp | 152 + .../vulkan-shaders/mul_mat_vec_base.comp | 118 + .../vulkan-shaders/mul_mat_vec_nc.comp | 71 + .../vulkan-shaders/mul_mat_vec_p021.comp | 73 + .../vulkan-shaders/mul_mat_vec_q2_k.comp | 115 + .../vulkan-shaders/mul_mat_vec_q3_k.comp | 103 + .../vulkan-shaders/mul_mat_vec_q4_k.comp | 133 + .../vulkan-shaders/mul_mat_vec_q5_k.comp | 162 + .../vulkan-shaders/mul_mat_vec_q6_k.comp | 112 + .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 631 ++ .../vulkan-shaders/mul_mm_cm2.comp | 328 + .../src/ggml-vulkan/vulkan-shaders/norm.comp | 44 + .../src/ggml-vulkan/vulkan-shaders/pad.comp | 28 + .../ggml-vulkan/vulkan-shaders/pool2d.comp | 74 + .../src/ggml-vulkan/vulkan-shaders/relu.comp | 21 + .../ggml-vulkan/vulkan-shaders/repeat.comp | 26 + .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 42 + .../ggml-vulkan/vulkan-shaders/rope_head.comp | 49 + .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 37 + .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 37 + .../src/ggml-vulkan/vulkan-shaders/scale.comp | 24 + .../src/ggml-vulkan/vulkan-shaders/silu.comp | 22 + .../src/ggml-vulkan/vulkan-shaders/sin.comp | 17 + .../ggml-vulkan/vulkan-shaders/soft_max.comp | 174 + .../ggml-vulkan/vulkan-shaders/square.comp | 17 + .../ggml-vulkan/vulkan-shaders/sum_rows.comp | 37 + .../src/ggml-vulkan/vulkan-shaders/tanh.comp | 20 + .../vulkan-shaders/test_coopmat2_support.comp | 7 + .../vulkan-shaders/timestep_embedding.comp | 41 + .../src/ggml-vulkan/vulkan-shaders/types.comp | 323 + .../ggml-vulkan/vulkan-shaders/upscale.comp | 36 + .../vulkan-shaders/vulkan-shaders-gen.cpp | 594 ++ .../src/ggml-vulkan/vulkan-shaders/wkv6.comp | 87 + 75 files changed, 14627 insertions(+), 1 deletion(-) create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp create mode 100644 
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp create mode 100644 
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp diff --git a/Makefile.sync b/Makefile.sync index 3001487de..78333fd62 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -32,7 +32,7 @@ PATCHES=$(wildcard llama/patches/*.patch) apply-patches: $(addsuffix ed, $(PATCHES)) %.patched: %.patch - @if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi + @if git -c commit.gpgSign=false -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi .PHONY: checkout checkout: $(WORKDIR) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt new file mode 100644 index 000000000..9501de736 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt @@ -0,0 +1,92 @@ +find_package(Vulkan COMPONENTS glslc REQUIRED) + +if (Vulkan_FOUND) + message(STATUS "Vulkan found") + + ggml_add_backend_library(ggml-vulkan + ggml-vulkan.cpp + ../../include/ggml-vulkan.h + ) + + # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. + # If it's not, there will be an error to stderr. 
+ # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") + message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") + else() + message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + endif() + + target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) + target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + + # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build + # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector + if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) + endif() + + if (GGML_VULKAN_CHECK_RESULTS) + add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) + endif() + + if (GGML_VULKAN_DEBUG) + add_compile_definitions(GGML_VULKAN_DEBUG) + endif() + + if (GGML_VULKAN_MEMORY_DEBUG) + add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG) + endif() + + if (GGML_VULKAN_SHADER_DEBUG_INFO) + add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + endif() + + if (GGML_VULKAN_PERF) + add_compile_definitions(GGML_VULKAN_PERF) + endif() + + if (GGML_VULKAN_VALIDATE) + add_compile_definitions(GGML_VULKAN_VALIDATE) + endif() + + if (GGML_VULKAN_RUN_TESTS) + add_compile_definitions(GGML_VULKAN_RUN_TESTS) + endif() + + add_subdirectory(vulkan-shaders) + + set (_ggml_vk_genshaders_cmd vulkan-shaders-gen) + set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) + set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) + set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) + set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) + + file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") + + add_custom_command( + OUTPUT ${_ggml_vk_header} + ${_ggml_vk_source} + + COMMAND "$/${_ggml_vk_genshaders_cmd}" + --glslc ${Vulkan_GLSLC_EXECUTABLE} + --input-dir ${_ggml_vk_input_dir} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + --target-cpp ${_ggml_vk_source} + --no-clean + + DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd} + COMMENT "Generate vulkan shaders" + ) + + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header}) + +else() + message(WARNING "Vulkan not found") +endif() diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp new file mode 100644 index 000000000..d75cd6d61 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -0,0 +1,8745 @@ +#include "ggml-vulkan.h" +#include +#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS) +#include +#include "ggml-cpu.h" +#endif + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-vulkan-shaders.hpp" + +#define VK_API_VERSION VK_API_VERSION_1_2 + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + 
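CEIL_DIV is the usual rounding-up integer division used when an element count has to be covered by fixed-size workgroups at dispatch time. A tiny standalone check of that identity (illustrative only, not part of the backend):

    #include <cassert>
    #include <cstdint>

    // Same formula as CEIL_DIV above.
    static uint32_t ceil_div(uint32_t m, uint32_t n) {
        return (m + n - 1) / n;
    }

    int main() {
        assert(ceil_div(256, 32) == 8);   // exact multiple
        assert(ceil_div(257, 32) == 9);   // one extra group for the remainder
        assert(ceil_div(1, 32) == 1);
        // ceil_div(n, wg_size) groups of wg_size threads always cover n elements,
        // and never by more than one extra group's worth.
        for (uint32_t n = 1; n < 1000; ++n) {
            assert(ceil_div(n, 32) * 32 >= n);
            assert(ceil_div(n, 32) * 32 < n + 32);
        }
        return 0;
    }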
+#define VK_VENDOR_ID_AMD 0x1002 +#define VK_VENDOR_ID_APPLE 0x106b +#define VK_VENDOR_ID_INTEL 0x8086 +#define VK_VENDOR_ID_NVIDIA 0x10de + +#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32 + +#define GGML_VK_MAX_NODES 8192 + +#define MAX_VK_BUFFERS 256 + +#define VK_CHECK(err, msg) \ + do { \ + vk::Result err_ = (err); \ + if (err_ != vk::Result::eSuccess) { \ + fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \ + #err, to_string(err_).c_str(), __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + +#ifdef GGML_VULKAN_DEBUG +#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl +#else +#define VK_LOG_DEBUG(msg) ((void) 0) +#endif // GGML_VULKAN_DEBUG + +struct ggml_backend_vk_context; + +struct vk_queue { + uint32_t queue_family_index; + vk::Queue queue; + vk::CommandPool pool; + uint32_t cmd_buffer_idx; + std::vector cmd_buffers; + + vk::PipelineStageFlags stage_flags; + + bool transfer_only; +}; + +struct vk_pipeline_struct { + std::string name; + vk::ShaderModule shader_module; + vk::DescriptorSetLayout dsl; + std::vector descriptor_pools; + std::vector descriptor_sets; + uint32_t descriptor_set_idx; + vk::PipelineLayout layout; + vk::Pipeline pipeline; + uint32_t push_constant_size; + uint32_t parameter_count; + std::array wg_denoms; + uint32_t align; +}; + +typedef std::shared_ptr vk_pipeline; +typedef std::weak_ptr vk_pipeline_ref; + +static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); + +struct vk_matmul_pipeline_struct { + vk_pipeline l, m, s; + vk_pipeline a_l, a_m, a_s; +}; + +typedef std::shared_ptr vk_matmul_pipeline; + +struct vk_matmul_pipeline2 { + vk_matmul_pipeline2() { + f16acc = std::make_shared(); + f32acc = std::make_shared(); + } + vk_matmul_pipeline f32acc; + vk_matmul_pipeline f16acc; +}; + +struct vk_device_struct; +typedef std::shared_ptr vk_device; +typedef std::weak_ptr vk_device_ref; + +struct vk_buffer_struct; +typedef std::shared_ptr vk_buffer; +typedef std::weak_ptr vk_buffer_ref; + +struct ggml_backend_vk_buffer_type_context { + std::string name; + vk_device device; +}; + +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); +static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { + /* .get_name = */ ggml_backend_vk_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +#ifdef GGML_VULKAN_MEMORY_DEBUG +class vk_memory_logger; +#endif +#ifdef GGML_VULKAN_PERF +class vk_perf_logger; +#endif +static void ggml_vk_destroy_buffer(vk_buffer& buf); + +static constexpr uint32_t mul_mat_vec_max_cols = 8; + +struct vk_device_struct { + std::mutex mutex; + + vk::PhysicalDevice physical_device; + vk::PhysicalDeviceProperties properties; + std::string name; + uint64_t max_memory_allocation_size; + bool fp16; + bool pipeline_robustness; + vk::Device device; + uint32_t vendor_id; + vk_queue compute_queue; + 
vk_queue transfer_queue; + bool single_queue; + uint32_t subgroup_size; + uint32_t shader_core_count; + bool uma; + bool float_controls_rte_fp16; + + bool subgroup_size_control; + uint32_t subgroup_min_size; + uint32_t subgroup_max_size; + bool subgroup_require_full_support; + + bool coopmat_support; + bool coopmat_acc_f32_support; + bool coopmat_acc_f16_support; + uint32_t coopmat_m; + uint32_t coopmat_n; + uint32_t coopmat_k; + bool coopmat2; + + size_t idx; + + bool mul_mat_l; + bool mul_mat_m; + bool mul_mat_s; + bool mul_mat_id_l; + bool mul_mat_id_m; + bool mul_mat_id_s; + + vk_matmul_pipeline pipeline_matmul_f32; + vk_matmul_pipeline pipeline_matmul_f32_f16; + vk_matmul_pipeline2 pipeline_matmul_f16; + vk_matmul_pipeline2 pipeline_matmul_f16_f32; + vk_pipeline pipeline_matmul_split_k_reduce; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; + + vk_matmul_pipeline pipeline_matmul_id_f32; + vk_matmul_pipeline2 pipeline_matmul_id_f16; + vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; + vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; + vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; + vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_acc_f32; + vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; + vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; + vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; + vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; + vk_pipeline pipeline_upscale_f32; + vk_pipeline pipeline_scale_f32; + vk_pipeline pipeline_sqr_f32; + vk_pipeline pipeline_sin_f32; + vk_pipeline pipeline_cos_f32; + vk_pipeline pipeline_clamp_f32; + vk_pipeline pipeline_pad_f32; + vk_pipeline pipeline_repeat_f32; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; + vk_pipeline pipeline_norm_f32; + vk_pipeline pipeline_group_norm_f32; + vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_gelu_f32; + vk_pipeline pipeline_gelu_quick_f32; + vk_pipeline pipeline_silu_f32; + vk_pipeline pipeline_relu_f32; + vk_pipeline pipeline_leaky_relu_f32; + vk_pipeline pipeline_tanh_f32; + vk_pipeline pipeline_diag_mask_inf_f32; + vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; + vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; + vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; + vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; + vk_pipeline pipeline_argsort_f32; + vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; + vk_pipeline pipeline_timestep_embedding_f32; + vk_pipeline pipeline_pool2d_f32; + vk_pipeline pipeline_rwkv_wkv6_f32; + + // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline 
pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + + std::unordered_map pipelines; + std::unordered_map pipeline_descriptor_set_requirements; + + std::vector> pinned_memory; + + vk::Fence fence; + vk_buffer sync_staging; + + ggml_backend_buffer_type buffer_type; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + std::unique_ptr memory_logger; +#endif +#ifdef GGML_VULKAN_PERF + std::unique_ptr perf_logger; +#endif + + ~vk_device_struct() { + VK_LOG_DEBUG("destroy device " << name); + + device.destroyFence(fence); + + ggml_vk_destroy_buffer(sync_staging); + + device.destroyCommandPool(compute_queue.pool); + if (!single_queue) { + device.destroyCommandPool(transfer_queue.pool); + } + + for (auto& pipeline : pipelines) { + if (pipeline.second.expired()) { + continue; + } + + vk_pipeline pl = pipeline.second.lock(); + ggml_vk_destroy_pipeline(device, pl); + } + pipelines.clear(); + + device.destroy(); + } +}; + +struct vk_buffer_struct { + vk::Buffer buffer = VK_NULL_HANDLE; + vk::DeviceMemory device_memory = VK_NULL_HANDLE; + vk::MemoryPropertyFlags memory_property_flags; + void * ptr; + size_t size = 0; + + vk_device device; + + ~vk_buffer_struct() { + if (size == 0) { + return; + } + VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")"); + + device->device.freeMemory(device_memory); + device->device.destroyBuffer(buffer); + } +}; + +struct vk_subbuffer { + vk_buffer buffer; + uint64_t offset; + uint64_t size; + + operator vk::DescriptorBufferInfo() const { + return { buffer->buffer, offset, size }; + } +}; + +struct vk_semaphore { + vk::Semaphore s; + uint64_t value; +}; + +struct vk_submission { + vk::CommandBuffer buffer; + std::vector wait_semaphores; + std::vector signal_semaphores; +}; + +typedef std::vector vk_sequence; + +struct vk_mat_mat_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t k_split; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; +}; +struct vk_mat_vec_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; +}; + +struct vk_mat_mat_id_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; +}; +struct vk_mat_vec_id_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t nei0; uint32_t ne11; +}; + +struct vk_flash_attn_push_constants { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb02; + uint32_t nb03; + uint32_t nb12; + uint32_t nb13; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float 
logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; +}; + +struct vk_op_push_constants { + uint32_t KX; + uint32_t KY; + float param1; + float param2; +}; + +struct vk_op_unary_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + float param1; float param2; + uint32_t ne0_012mp; uint32_t ne0_012L; + uint32_t ne0_01mp; uint32_t ne0_01L; + uint32_t ne0_0mp; uint32_t ne0_0L; + uint32_t ne1_012mp; uint32_t ne1_012L; + uint32_t ne1_01mp; uint32_t ne1_01L; + uint32_t ne1_0mp; uint32_t ne1_0L; +}; +static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) +{ + // compute L = ceil(log2(d)); + L = 0; + while (L < 32 && (uint32_t{1} << L) < d) { + L++; + } + + mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1); +} + +template void init_pushconst_fastdiv(T &p) { + GGML_UNUSED(p); + static_assert(!std::is_const::value, "unexpected type"); +} + +template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) { + // Compute magic values to divide by these six numbers. + init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L); + init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L); + init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L); + init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L); + init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L); + init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L); +} + +struct vk_op_binary_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23; + uint32_t misalign_offsets; + float param1; float param2; int32_t param3; +}; + +struct vk_op_diag_mask_push_constants { + uint32_t ncols; + uint32_t rows_per_channel; + int32_t n_past; +}; + +struct vk_op_rope_push_constants { + uint32_t ncols; + uint32_t n_dims; + float freq_scale; + uint32_t p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint32_t has_ff; +}; + +struct vk_op_soft_max_push_constants { + uint32_t KX; + uint32_t KY; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; + uint32_t nrows_x; +}; + +struct vk_op_argsort_push_constants { + uint32_t ncols; + uint32_t ncols_pad; + int32_t order; +}; + +struct vk_op_im2col_push_constants { + uint32_t batch_offset; uint32_t offset_delta; + uint32_t IC; + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t KW; uint32_t KH; + uint32_t pelements; + uint32_t CHW; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; + int32_t d0; int32_t d1; +}; + +struct vk_op_timestep_embedding_push_constants { + uint32_t nb1; + uint32_t dim; + 
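The divide-by-constant comment above (init_fastdiv_values) is terse; the following self-contained check, illustrative only and not part of the patch, precomputes mp and L exactly as that function does and verifies the identity n / d == (mulhi(n, mp) + n) >> L across a few divisors. The intermediate sum is kept in 64 bits here so the check does not depend on how the shader handles the carry.

    #include <cassert>
    #include <cstdint>

    // Same computation as init_fastdiv_values above.
    static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) {
        L = 0;
        while (L < 32 && (uint32_t{1} << L) < d) {
            L++;
        }
        mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
    }

    static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
        const uint64_t hi = (uint64_t(n) * mp) >> 32;   // mulhi(n, mp)
        return uint32_t((hi + n) >> L);                 // 64-bit sum, then shift
    }

    int main() {
        // e.g. d = 3 gives L = 2 and mp = 0x55555556
        for (uint32_t d : {1u, 2u, 3u, 7u, 960u, 4096u, 123457u}) {
            uint32_t mp, L;
            init_fastdiv_values(d, mp, L);
            for (uint32_t n : {0u, 1u, d - 1, d, d + 1, 123456789u, 0xFFFFFFFFu}) {
                assert(fastdiv(n, mp, L) == n / d);
            }
        }
        return 0;
    }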
uint32_t max_period; +}; + +struct vk_op_pool2d_push_constants { + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t OC; + uint32_t pelements; + uint32_t op; + int32_t k0; int32_t k1; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; +}; + +struct vk_op_rwkv_wkv6_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; + +// Allow pre-recording command buffers +struct vk_staging_memcpy { + vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} + + void * dst; + const void * src; + size_t n; +}; + +struct vk_op_upscale_push_constants { + uint32_t ne; uint32_t a_offset; uint32_t d_offset; + uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; + float sf0; float sf1; float sf2; float sf3; +}; + +struct vk_context_struct { + vk_submission * s; + std::vector seqs; + + int exit_tensor_idx; + + std::vector in_memcpys; + std::vector out_memcpys; + + vk_queue * q; +}; +typedef std::shared_ptr vk_context; +typedef std::weak_ptr vk_context_ref; + +struct ggml_vk_garbage_collector { + std::vector tl_semaphores; + std::vector semaphores; + std::vector events; + std::vector temp_buffers; + std::vector contexts; +}; + +#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG) +#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl + +static std::string format_size(size_t size) { + const size_t kib = 1024; + const size_t mib = kib * 1024; + const size_t gib = mib * 1024; + + std::ostringstream oss; + oss << std::fixed << std::setprecision(2); + + if (size >= gib) { + oss << static_cast(size) / gib << " GiB"; + } else if (size >= mib) { + oss << static_cast(size) / mib << " MiB"; + } else if (size >= kib) { + oss << static_cast(size) / kib << " KiB"; + } else { + oss << size << " B"; + } + + return oss.str(); +} + +static std::mutex log_mutex; + +class vk_memory_logger { +public: + vk_memory_logger(): total_device(0), total_host(0) {} + void log_allocation(vk_buffer_ref buf_ref, size_t size); + void log_deallocation(vk_buffer_ref buf_ref); + +private: + std::map allocations; // Track allocations + size_t total_device; + size_t total_host; +}; +#else +#define VK_LOG_MEMORY(msg) ((void) 0) +#endif // GGML_VULKAN_MEMORY_DEBUG + +#if defined(GGML_VULKAN_PERF) + +class vk_perf_logger { +public: + void print_timings() { + std::cerr << "----------------\nVulkan Timings:" << std::endl; + for (const auto& t : timings) { + uint64_t total = 0; + for (const auto& time : t.second) { + total += time; + } + std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl; + } + + timings.clear(); + } + + void log_timing(const ggml_tensor * node, uint64_t time) { + if (node->op == GGML_OP_UNARY) { + timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); + return; + } + if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { + const uint64_t m = node->src[0]->ne[1]; + const uint64_t n = node->src[1]->ne[1]; + const uint64_t k = node->src[1]->ne[0]; + std::string name = ggml_op_name(node->op); + if (n == 1) { + name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k); + } else { + name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); + } + timings[name].push_back(time); + return; + } + timings[ggml_op_name(node->op)].push_back(time); + } +private: + std::map> timings; +}; +#endif // GGML_VULKAN_PERF + +struct 
ggml_backend_vk_context { + std::string name; + + vk_device device; + + size_t semaphore_idx, event_idx; + ggml_vk_garbage_collector gc; + size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; + vk_buffer prealloc_x, prealloc_y, prealloc_split_k; + vk::Fence fence; + + vk_buffer buffer_pool[MAX_VK_BUFFERS]; + + vk_context_ref compute_ctx; + vk_context_ref transfer_ctx; + + std::vector tensor_ctxs; +}; + +static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT + +static uint64_t vk_tensor_offset(const ggml_tensor * tensor) { + if (tensor->view_src) { + return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base; + } + return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; +} + +struct ggml_backend_vk_buffer_context { + vk_device_ref device; + vk_buffer dev_buffer; + std::string name; + + ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) : + device(device), + dev_buffer(dev_buffer), + name(name) { + } + + ~ggml_backend_vk_buffer_context() { + ggml_vk_destroy_buffer(dev_buffer); + } +}; + +#ifdef GGML_VULKAN_MEMORY_DEBUG +void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { + std::lock_guard guard(log_mutex); + vk_buffer buf = buf_ref.lock(); + const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); + const std::string type = device ? "device" : "host"; + allocations[buf->buffer] = size; + total_device += device ? size : 0; + total_host += device ? 0 : size; + VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); +} + +void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { + if (buf_ref.expired() || buf_ref.lock()->size == 0) { + return; + } + + std::lock_guard guard(log_mutex); + vk_buffer buf = buf_ref.lock(); + const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); + std::string type = device ? "device" : "host"; + auto it = allocations.find(buf->buffer); + total_device -= device ? it->second : 0; + total_host -= device ? 0 : it->second; + if (it != allocations.end()) { + VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". 
Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); + allocations.erase(it); + } else { + VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer); + } +} +#endif // GGML_VULKAN_MEMORY_DEBUG + +struct vk_instance_t { + vk::Instance instance; + + std::vector device_indices; + vk_device devices[GGML_VK_MAX_DEVICES]; +}; + +static bool vk_instance_initialized = false; +static vk_instance_t vk_instance; + +#ifdef GGML_VULKAN_CHECK_RESULTS +static size_t vk_skip_checks; +static size_t vk_output_tensor; + +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); +static void ggml_vk_check_results_0(ggml_tensor * tensor); +static void ggml_vk_check_results_1(ggml_tensor * tensor); +#endif + +typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +static void ggml_backend_vk_free(ggml_backend_t backend); + +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; + +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, + uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector specialization_constants, + uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { + VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << + ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << + ", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); + GGML_ASSERT(parameter_count > 0); + GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT + + pipeline = std::make_shared(); + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + + vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); + + std::vector dsl_binding; + std::vector dsl_binding_flags; + for (uint32_t i = 0; i < parameter_count; i++) { + dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); + dsl_binding_flags.push_back({}); + } + + vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; + + vk::PushConstantRange pcr( + vk::ShaderStageFlagBits::eCompute, + 0, + pipeline->push_constant_size + ); + + vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( + {}, + dsl_binding); + descriptor_set_layout_create_info.setPNext(&dslbfci); + pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); + + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + 
pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); + + pipeline->descriptor_set_idx = 0; + + vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr); + pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info); + + std::vector specialization_entries(specialization_constants.size()); + + for (size_t i = 0; i < specialization_constants.size(); i++) { + specialization_entries[i].constantID = i; + specialization_entries[i].offset = i * sizeof(uint32_t); + specialization_entries[i].size = sizeof(uint32_t); + } + + vk::SpecializationInfo specialization_info( + specialization_entries.size(), + specialization_entries.data(), + specialization_constants.size() * sizeof(uint32_t), + specialization_constants.data() + ); + + vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{}; + + if (device->subgroup_require_full_support && require_full_subgroups) { + pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT; + } + + vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( + pipeline_shader_stage_create_flags, + vk::ShaderStageFlagBits::eCompute, + pipeline->shader_module, + entrypoint.c_str(), + &specialization_info); + + vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info; + pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size; + if (device->subgroup_size_control && required_subgroup_size > 0) { + GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size); + pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info); + } + + vk::ComputePipelineCreateInfo compute_pipeline_create_info( + vk::PipelineCreateFlags{}, + pipeline_shader_create_info, + pipeline->layout); + + vk::PipelineRobustnessCreateInfoEXT rci; + + if (device->pipeline_robustness && disable_robustness) { + rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + compute_pipeline_create_info.setPNext(&rci); + } + + pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + + { + std::lock_guard guard(device->mutex); + device->pipelines.insert({ pipeline->name, pipeline }); + } + + { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + + // "Progress bar" for shader compiles + static uint32_t total_compile_count = 0; + if ((total_compile_count++ % 10) == 0) { + std::cerr << "."; + } + } + compile_count_cond.notify_all(); +} + +static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")"); + for (auto& pool : pipeline->descriptor_pools) { + device.destroyDescriptorPool(pool); + } + pipeline->descriptor_pools.clear(); + pipeline->descriptor_sets.clear(); + pipeline->descriptor_set_idx = 0; + + device.destroyDescriptorSetLayout(pipeline->dsl); + + device.destroyPipelineLayout(pipeline->layout); + + device.destroyShaderModule(pipeline->shader_module); + + device.destroyPipeline(pipeline->pipeline); +} + +static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { + 
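+    // Descriptive note (added commentary): this only records how many descriptor sets
+    // the pipeline will need for the upcoming graph evaluation; the actual allocation
+    // is deferred to ggml_pipeline_allocate_descriptor_sets().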
VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); + device->pipeline_descriptor_set_requirements[pipeline->name] += n; +} + +static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { + std::lock_guard guard(device->mutex); + + for (auto& pair : device->pipeline_descriptor_set_requirements) { + vk_pipeline pipeline = device->pipelines.at(pair.first).lock(); + const uint64_t n = pair.second; + + VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")"); + + if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) { + // Enough descriptors are available + continue; + } + + uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size(); + uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; + uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + while (to_alloc > 0) { + const uint32_t alloc_count = std::min(pool_remaining, to_alloc); + to_alloc -= alloc_count; + pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + if (pool_idx >= pipeline->descriptor_pools.size()) { + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); + } + + std::vector layouts(alloc_count); + for (uint32_t i = 0; i < alloc_count; i++) { + layouts[i] = pipeline->dsl; + } + vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data()); + std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); + pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end()); + + pool_idx++; + } + } +} + +static void ggml_pipeline_cleanup(vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")"); + pipeline->descriptor_set_idx = 0; +} + +static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) { + VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); + std::lock_guard guard(device->mutex); + + if (q.cmd_buffers.size() > q.cmd_buffer_idx) { + // Reuse command buffer + return q.cmd_buffers[q.cmd_buffer_idx++]; + } + + vk::CommandBufferAllocateInfo command_buffer_alloc_info( + q.pool, + vk::CommandBufferLevel::ePrimary, + 1); + const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); + auto buf = cmd_buffers.front(); + + q.cmd_buffers.push_back(buf); + q.cmd_buffer_idx++; + + return buf; +} + +static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector wait_semaphores, std::vector signal_semaphores) { + VK_LOG_DEBUG("ggml_vk_create_submission()"); + vk_submission s; + s.buffer = ggml_vk_create_cmd_buffer(device, q); + s.wait_semaphores = std::move(wait_semaphores); + s.signal_semaphores = std::move(signal_semaphores); + return s; +} + +static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { + if (ctx->seqs.empty()) { + if (fence) { + ctx->q->queue.submit({}, fence); + } + return; + } + VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")"); + + std::vector> tl_wait_vals; + std::vector> tl_signal_vals; + std::vector> tl_wait_semaphores; + 
std::vector> tl_signal_semaphores; + std::vector tl_submit_infos; + std::vector submit_infos; + int idx = -1; + std::vector> stage_flags; + + size_t reserve = 0; + + for (const auto& sequence : ctx->seqs) { + reserve += sequence.size(); + } + + // Pre-reserve vectors to prevent reallocation, which invalidates pointers + tl_wait_semaphores.reserve(reserve); + tl_wait_vals.reserve(reserve); + tl_signal_semaphores.reserve(reserve); + tl_signal_vals.reserve(reserve); + tl_submit_infos.reserve(reserve); + submit_infos.reserve(reserve); + stage_flags.reserve(reserve); + + for (const auto& sequence : ctx->seqs) { + for (const auto& submission : sequence) { + stage_flags.push_back({}); + idx++; + tl_wait_vals.push_back({}); + tl_wait_semaphores.push_back({}); + tl_signal_vals.push_back({}); + tl_signal_semaphores.push_back({}); + for (size_t i = 0; i < submission.wait_semaphores.size(); i++) { + stage_flags[idx].push_back(ctx->q->stage_flags); + tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value); + tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s); + } + for (size_t i = 0; i < submission.signal_semaphores.size(); i++) { + tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value); + tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s); + } + tl_submit_infos.push_back({ + (uint32_t) submission.wait_semaphores.size(), + tl_wait_vals[idx].data(), + (uint32_t) submission.signal_semaphores.size(), + tl_signal_vals[idx].data(), + }); + tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo; + tl_submit_infos[idx].pNext = nullptr; + vk::SubmitInfo si{ + (uint32_t) submission.wait_semaphores.size(), + tl_wait_semaphores[idx].data(), + stage_flags[idx].data(), + 1, + &submission.buffer, + (uint32_t) submission.signal_semaphores.size(), + tl_signal_semaphores[idx].data(), + }; + si.setPNext(&tl_submit_infos[idx]); + submit_infos.push_back(si); + } + } + + ctx->q->queue.submit(submit_infos, fence); + + ctx->seqs.clear(); +} + +static uint32_t ggml_vk_find_queue_family_index(std::vector& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) { + VK_LOG_DEBUG("ggml_vk_find_queue_family_index()"); + const uint32_t qfsize = queue_family_props.size(); + + // Try with avoid preferences first + for (uint32_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) { + return i; + } + } + + // Fall back to only required + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) { + return i; + } + } + + // Fall back to reusing compute queue + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) { + return i; + } + } + + // Fall back to ignoring min_num_queries + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueFlags & required) { + return i; + } + } + + // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations. 
+ // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional. + if (compute_index >= 0) { + return compute_index; + } + + std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl; + + for(auto &q_family : queue_family_props) { + std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl; + } + abort(); +} + +static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) { + VK_LOG_DEBUG("ggml_vk_create_queue()"); + std::lock_guard guard(device->mutex); + + q.queue_family_index = queue_family_index; + q.transfer_only = transfer_only; + + vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index); + q.pool = device->device.createCommandPool(command_pool_create_info_compute); + + q.cmd_buffer_idx = 0; + + q.queue = device->device.getQueue(queue_family_index, queue_index); + + q.stage_flags = stage_flags; +} + +static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); + ctx->gc.contexts.emplace_back(result); + result->q = &q; + return result; +} + +static vk_context ggml_vk_create_temporary_context(vk_queue& q) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); + result->q = &q; + return result; +} + +static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); + ctx->gc.semaphores.push_back({ semaphore, 0 }); + return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1]; +} + +static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); + if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) { + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); + ctx->gc.tl_semaphores.push_back({ semaphore, 0 }); + } + return &ctx->gc.tl_semaphores[ctx->semaphore_idx++]; +} + +static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { + if (ctx->event_idx >= ctx->gc.events.size()) { + ctx->gc.events.push_back(ctx->device->device.createEvent({})); + } + return ctx->gc.events[ctx->event_idx++]; +} + +static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) { + VK_LOG_DEBUG("ggml_vk_queue_cleanup()"); + std::lock_guard guard(device->mutex); + + // Requires command buffers to be done + device->device.resetCommandPool(q.pool); + q.cmd_buffer_idx = 0; +} + +static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { + for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { + vk::MemoryType memory_type = mem_props->memoryTypes[i]; + if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) && + (flags & 
memory_type.propertyFlags) == flags && + mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) { + return static_cast(i); + } + } + return UINT32_MAX; +} + +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { + VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")"); + if (size > device->max_memory_allocation_size) { + throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); + } + + std::lock_guard guard(device->mutex); + + vk_buffer buf = std::make_shared(); + + if (size == 0) { + buf->size = 0; + return buf; + } + + vk::BufferCreateInfo buffer_create_info{ + vk::BufferCreateFlags(), + size, + vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst, + vk::SharingMode::eExclusive, + 0, + nullptr, + }; + + buf->buffer = device->device.createBuffer(buffer_create_info); + + vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); + + vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); + + uint32_t memory_type_index = UINT32_MAX; + + memory_type_index = find_properties(&mem_props, &mem_req, req_flags); + buf->memory_property_flags = req_flags; + + if (memory_type_index == UINT32_MAX && fallback_flags) { + memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); + buf->memory_property_flags = fallback_flags; + } + + if (memory_type_index == UINT32_MAX) { + device->device.destroyBuffer(buf->buffer); + throw vk::OutOfDeviceMemoryError("No suitable memory type found"); + } + + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); + } catch (const vk::SystemError& e) { + if (buf->memory_property_flags != fallback_flags) { + // Try again with fallback flags + memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); + buf->memory_property_flags = fallback_flags; + + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); + } + catch (const vk::SystemError& e) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } else { + // Out of Host/Device memory, clean up buffer + device->device.destroyBuffer(buf->buffer); + throw e; + } + } + buf->ptr = nullptr; + + if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + } + + device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); + + buf->device = device; + buf->size = size; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + device->memory_logger->log_allocation(buf, size); +#endif + + return buf; +} + +static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { + try { + return ggml_vk_create_buffer(device, size, req_flags, fallback_flags); + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." 
<< std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } +} + +static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { + vk_buffer buf; + try { + if (device->uma) { + // Fall back to host memory type + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + } else { + // use rebar if available, otherwise fallback to device only visible memory + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); + } + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } + + return buf; +} + +static void ggml_vk_destroy_buffer(vk_buffer& buf) { + if (buf == nullptr) { + return; + } + +#ifdef GGML_VULKAN_MEMORY_DEBUG + if (buf->device != nullptr) { + buf->device->memory_logger->log_deallocation(buf); + } +#endif + + buf.reset(); +} + +static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { + return { buf, 0, VK_WHOLE_SIZE }; +} + +static void ggml_vk_sync_buffers(vk_context& ctx) { + VK_LOG_DEBUG("ggml_vk_sync_buffers()"); + + const bool transfer_queue = ctx->q->transfer_only; + + ctx->s->buffer.pipelineBarrier( + ctx->q->stage_flags, + ctx->q->stage_flags, + {}, + { { + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) } + } }, + {}, + {} + ); +} + +static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events) { + VK_LOG_DEBUG("ggml_vk_wait_events()"); + if (events.empty()) { + return; + } + + ctx->s->buffer.waitEvents( + events, + ctx->q->stage_flags, + ctx->q->stage_flags, + {}, + {}, + {} + ); +} + +// number of rows/cols for flash attention shader +static constexpr uint32_t flash_attention_num_small_rows = 32; +static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { + GGML_UNUSED(clamp); + + // small rows, large cols + if (small_rows) { + return {flash_attention_num_small_rows, 128}; + } + // small cols to reduce register count + if (ggml_is_quantized(type) || D == 256) { + return {64, 32}; + } + return {64, 64}; +}; + +static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id) { + // Needs to be kept up to date on shader changes + const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; + const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const uint32_t warps = warptile[0] / warptile[10]; + + const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; + const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; + const uint32_t coopmat_stage = device->coopmat_support ? 
warptile[7] * warptile[8] / warps * sizeof(float) : 0; + + return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize; +} + +static void ggml_vk_load_shaders(vk_device& device) { + VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); + + std::cerr << "ggml_vulkan: Compiling shaders"; + + // some shaders have a minimum subgroup size + const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); + const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); + + // mulmat + std::vector l_warptile, m_warptile, s_warptile, + l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, + l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, + l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; + std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, + l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, + l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, + l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; + + uint32_t l_align, m_align, s_align; + if (device->coopmat2) { + // spec constants and tile sizes for non-quant matmul/matmul_id + l_warptile = { 256, 128, 256, 64 }; + m_warptile = { 256, 128, 128, 64 }; + s_warptile = { 128, 64, 64, 64 }; + l_wg_denoms = {128, 256, 1 }; + m_wg_denoms = {128, 128, 1 }; + s_wg_denoms = { 64, 64, 1 }; + + // spec constants and tile sizes for quant matmul (non-Qi_K) + l_warptile_mmq = { 256, 128, 256, 64 }; + m_warptile_mmq = { 256, 128, 128, 64 }; + s_warptile_mmq = { 256, 128, 128, 64 }; + l_mmq_wg_denoms = { 128, 256, 1 }; + m_mmq_wg_denoms = { 128, 128, 1 }; + s_mmq_wg_denoms = { 128, 128, 1 }; + + // spec constants and tile sizes for quant matmul (Qi_K) + l_warptile_mmq_k = { 256, 128, 512, 16 }; + m_warptile_mmq_k = { 256, 128, 256, 16 }; + s_warptile_mmq_k = { 256, 32, 128, 64 }; + l_mmq_wg_denoms_k = { 128, 512, 1 }; + m_mmq_wg_denoms_k = { 128, 256, 1 }; + s_mmq_wg_denoms_k = { 32, 128, 1 }; + + // spec constants and tile sizes for quant matmul_id + l_warptile_mmqid = { 256, 128, 128, 16 }; + m_warptile_mmqid = { 256, 128, 64, 16 }; + s_warptile_mmqid = { 256, 64, 64, 16 }; + l_mmqid_wg_denoms = { 128, 128, 1 }; + m_mmqid_wg_denoms = { 128, 64, 1 }; + s_mmqid_wg_denoms = { 64, 64, 1 }; + + l_align = 128; + m_align = 64; + s_align = 32; + } else { + // Matrix cores require different warp group sizes + const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2; + const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4; + const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_s = device->coopmat_support ? 
device->coopmat_k : 1; + + l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; + m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; + + l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; + m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; + + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; + m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; + s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; + l_align = 128; + m_align = 64; + s_align = 32; + + // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders + // and tile sizes, this should handle 16KB, 32KB, and 48KB+. + // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders. + // But the numbers happen to work out for 32KB shared memory size that when using the medium + // size there's enough room for everything, and we assert for this. + uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); + if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { + l_warptile = m_warptile; + l_wg_denoms = m_wg_denoms; + shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); + GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); + } + if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { + // assert mul_mat_mat_id shaders will fit. + GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); + } + + shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); + if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { + if (device->properties.limits.maxComputeSharedMemorySize == 32768) { + l_warptile_mmq = m_warptile_mmq; + l_mmq_wg_denoms = m_mmq_wg_denoms; + } else { + l_warptile_mmq = s_warptile_mmq; + l_mmq_wg_denoms = s_mmq_wg_denoms; + } + shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); + GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); + } + if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { + // assert mul_mat_mat_id shaders will fit. + GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); + } + // Disable medium and large matrix multiplication if not enough shared memory is available + // Check mmq warptiles as the largest configuration + // Throw an error if not enough for any matrix multiplication is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) { + std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." 
<< std::endl; + throw std::runtime_error("Shared memory size too small for matrix multiplication."); + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) { + device->mul_mat_m = false; + device->mul_mat_l = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) { + device->mul_mat_l = false; + } + + // Disable mul_mat_id if not enough shared memory is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) { + device->mul_mat_id_s = false; + device->mul_mat_id_m = false; + device->mul_mat_id_l = false; + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) { + device->mul_mat_id_m = false; + device->mul_mat_id_l = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) { + device->mul_mat_id_l = false; + } + } + + device->pipeline_matmul_f32 = std::make_shared(); + device->pipeline_matmul_f32_f16 = std::make_shared(); + + device->pipeline_matmul_id_f32 = std::make_shared(); + + std::vector> compiles; + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, + uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, + uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + { + // wait until fewer than N compiles are in progress + uint32_t N = std::max(1u, std::thread::hardware_concurrency()); + std::unique_lock guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, + parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size)); + }; + +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + + auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; + }; + + auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + uint32_t wg_size = (small_rows && (D % 32) == 0) ? 
256 : 128; + auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); + return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; + }; + +#define CREATE_FA2(TYPE, NAMELC, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ + +#define CREATE_FA(TYPE, NAMELC) \ + CREATE_FA2(TYPE, NAMELC, 64) \ + CREATE_FA2(TYPE, NAMELC, 80) \ + CREATE_FA2(TYPE, NAMELC, 96) \ + CREATE_FA2(TYPE, NAMELC, 112) \ + CREATE_FA2(TYPE, NAMELC, 128) \ + CREATE_FA2(TYPE, NAMELC, 256) + + CREATE_FA(GGML_TYPE_F16, f16) + CREATE_FA(GGML_TYPE_Q4_0, q4_0) + CREATE_FA(GGML_TYPE_Q4_1, q4_1) + CREATE_FA(GGML_TYPE_Q5_0, q5_0) + 
CREATE_FA(GGML_TYPE_Q5_1, q5_1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0) + // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently + //CREATE_FA(GGML_TYPE_Q2_K, q2_k) + //CREATE_FA(GGML_TYPE_Q3_K, q3_k) + //CREATE_FA(GGML_TYPE_Q4_K, q4_k) + //CREATE_FA(GGML_TYPE_Q5_K, q5_k) + //CREATE_FA(GGML_TYPE_Q6_K, q6_k) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) +#undef CREATE_FA + + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) + + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, 
, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) +#undef CREATE_MM +#undef CREATE_MM2 + } else +#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat_support) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->coopmat_acc_f16_support) { \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + if (device->coopmat_acc_f32_support) { \ + CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + if (device->coopmat_acc_f16_support) { + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } else { + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, 
matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } + + // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. + if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + if (device->coopmat_acc_f16_support) { + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } else { + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + 
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } + } +#undef CREATE_MM2 +#undef CREATE_MM + } else if (device->fp16) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. 
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM2 +#undef CREATE_MM + } else { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + if (device->mul_mat 
## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. 
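+        // The _id pipeline variants implement MUL_MAT_ID (per-token expert selection in mixture-of-experts models); they need extra shared memory for the row_ids buffer, which is what the check below guards against.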
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM + } + + // mul mat vec + + // the number of rows computed per shader depends on GPU model and quant + uint32_t rm_stdq = 1; + uint32_t rm_kq = 2; + if (device->vendor_id == VK_VENDOR_ID_AMD) { + if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN + rm_stdq = 2; + rm_kq = 4; + } + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) + rm_stdq = 2; + + for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, 
true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], 
"mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + } + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, 
mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); + + // dequant shaders + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + + // get_rows + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + 
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + 
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + 
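+    // The remaining element-wise and tensor ops below (upscale, scale, unary math, clamp, pad, repeat, activations, soft_max, rope, argsort, ...) are plain compute pipelines configured through push constants.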
ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + + ggml_vk_create_pipeline(device, 
device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } + + ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + } + + ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + + for (auto &c : compiles) { + c.wait(); + } + std::cerr << "Done!" 
<< std::endl;
+}
+
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
+
+static vk_device ggml_vk_get_device(size_t idx) {
+    VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
+
+    if (vk_instance.devices[idx] == nullptr) {
+        VK_LOG_DEBUG("Initializing new vk_device");
+        vk_device device = std::make_shared<vk_device_struct>();
+        vk_instance.devices[idx] = device;
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+        device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
+#endif
+#ifdef GGML_VULKAN_PERF
+        device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
+#endif
+
+        size_t dev_num = vk_instance.device_indices[idx];
+
+        std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices();
+
+        if (dev_num >= physical_devices.size()) {
+            std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
+            throw std::runtime_error("Device not found");
+        }
+
+        device->physical_device = physical_devices[dev_num];
+        const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
+
+        bool fp16_storage = false;
+        bool fp16_compute = false;
+        bool maintenance4_support = false;
+        bool sm_builtins = false;
+        bool amd_shader_core_properties2 = false;
+        bool pipeline_robustness = false;
+        bool coopmat2_support = false;
+        device->coopmat_support = false;
+
+        // Check if maintenance4 is supported
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
+                maintenance4_support = true;
+            } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
+                fp16_storage = true;
+            } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
+                fp16_compute = true;
+            } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
+                sm_builtins = true;
+            } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
+                amd_shader_core_properties2 = true;
+            } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
+                pipeline_robustness = true;
+            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
+                device->subgroup_size_control = true;
+            } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
+                       !getenv("GGML_VK_DISABLE_COOPMAT")) {
+                device->coopmat_support = true;
+                device->coopmat_m = 0;
+                device->coopmat_n = 0;
+                device->coopmat_k = 0;
+            } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
+                       !getenv("GGML_VK_DISABLE_COOPMAT2")) {
+                coopmat2_support = true;
+            }
+        }
+
+        vk::PhysicalDeviceProperties2 props2;
+        vk::PhysicalDeviceMaintenance3Properties props3;
+        vk::PhysicalDeviceMaintenance4Properties props4;
+        vk::PhysicalDeviceSubgroupProperties subgroup_props;
+        vk::PhysicalDeviceDriverProperties driver_props;
+        vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
+        vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
+        vk::PhysicalDeviceVulkan12Properties vk12_props;
+        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+
+        props2.pNext = &props3;
+        props3.pNext = &subgroup_props;
+        subgroup_props.pNext = &driver_props;
+        driver_props.pNext = &vk12_props;
+
+        VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
+
+        if (maintenance4_support) {
+            last_struct->pNext = (VkBaseOutStructure *)&props4;
+            last_struct = (VkBaseOutStructure *)&props4;
+        }
+        if
(sm_builtins) { + last_struct->pNext = (VkBaseOutStructure *)&sm_props; + last_struct = (VkBaseOutStructure *)&sm_props; + } + if (amd_shader_core_properties2) { + last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + } + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_props; + } + +#if defined(VK_NV_cooperative_matrix2) + vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props; + last_struct = (VkBaseOutStructure *)&coopmat2_props; + } +#endif + + device->physical_device.getProperties2(&props2); + device->properties = props2.properties; + + const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); + + if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { + device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); + } else if (maintenance4_support) { + device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); + } else { + device->max_memory_allocation_size = props3.maxMemoryAllocationSize; + } + + device->vendor_id = device->properties.vendorID; + device->subgroup_size = subgroup_props.subgroupSize; + device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + if (sm_builtins) { + device->shader_core_count = sm_props.shaderSMCount; + } else if (amd_shader_core_properties2) { + device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else { + device->shader_core_count = 0; + } + device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; + + device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + + if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) { + device->coopmat_support = false; + } + + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); + + // Try to find a non-graphics compute queue and transfer-focused queues + const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); + const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); + + const float priorities[] = { 1.0f, 1.0f }; + device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; + + std::vector device_queue_create_infos; + if (compute_queue_family_index != transfer_queue_family_index) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1}); + } else if(!device->single_queue) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities}); + } else { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); + } + vk::DeviceCreateInfo device_create_info; + 
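+        // Chain the optional feature structs (Vulkan 1.1/1.2, pipeline robustness, subgroup size control, cooperative matrix) into VkPhysicalDeviceFeatures2 so a single vkGetPhysicalDeviceFeatures2 call below reports what the device actually supports.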
std::vector device_extensions; + vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); + + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = nullptr; + device_features2.features = (VkPhysicalDeviceFeatures)device_features; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + VkPhysicalDeviceVulkan12Features vk12_features; + vk12_features.pNext = nullptr; + vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; + vk11_features.pNext = &vk12_features; + + last_struct = (VkBaseOutStructure *)&vk12_features; + + VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; + pl_robustness_features.pNext = nullptr; + pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; + pl_robustness_features.pipelineRobustness = VK_FALSE; + + if (pipeline_robustness) { + last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features; + last_struct = (VkBaseOutStructure *)&pl_robustness_features; + device_extensions.push_back("VK_EXT_pipeline_robustness"); + } + + VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; + subgroup_size_control_features.pNext = nullptr; + subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; + subgroup_size_control_features.computeFullSubgroups = false; + subgroup_size_control_features.subgroupSizeControl = false; + + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; + } + + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (device->coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } + +#if defined(VK_NV_cooperative_matrix2) + VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; + coopmat2_features.pNext = nullptr; + coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; + last_struct = (VkBaseOutStructure *)&coopmat2_features; + device_extensions.push_back("VK_NV_cooperative_matrix2"); + } +#endif + + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); + + device->fp16 = device->fp16 && vk12_features.shaderFloat16; + + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; + + if (device->subgroup_size_control) { + device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; + device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; + } + + device->subgroup_size_control = device->subgroup_size_control && + (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) && + subgroup_size_control_features.subgroupSizeControl; + + if (device->subgroup_size_control) { + device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; + 
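+            // Adding the extension to device_extensions here enables it at vkCreateDevice time; pipeline creation later relies on it to request full subgroups or a specific subgroup size.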
device_extensions.push_back("VK_EXT_subgroup_size_control"); + } + + device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; + + if (coopmat2_support) { +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (coopmat2_features.cooperativeMatrixWorkgroupScope && + coopmat2_features.cooperativeMatrixFlexibleDimensions && + coopmat2_features.cooperativeMatrixReductions && + coopmat2_features.cooperativeMatrixConversions && + coopmat2_features.cooperativeMatrixPerElementOperations && + coopmat2_features.cooperativeMatrixTensorAddressing && + coopmat2_features.cooperativeMatrixBlockLoads && + vk12_features.bufferDeviceAddress) { + + std::vector flexible_dimensions; + uint32_t count = 0; + + PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV = + (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV) + vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV"); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr); + + VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {}; + empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV; + flexible_dimensions.resize(count, empty_prop); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data()); + + bool found_fp16_128 = false, + found_fp16_256 = false, + found_fp32_128 = false, + found_fp32_256 = false; + // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 + // with 32x16x16 and 256 with 32x32x16. + for (auto &prop : flexible_dimensions) { + if (prop.saturatingAccumulation == VK_FALSE && + prop.scope == VK_SCOPE_WORKGROUP_KHR && + prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } + } + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } + } + } + } + if (found_fp16_128 && found_fp16_256 && + found_fp32_128 && found_fp32_256 && + coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { + device->coopmat2 = true; + } + } +#endif + } + + if (!vk11_features.storageBuffer16BitAccess) { + std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." 
<< std::endl; + throw std::runtime_error("Unsupported device"); + } + + device_extensions.push_back("VK_KHR_16bit_storage"); + +#ifdef GGML_VULKAN_VALIDATE + device_extensions.push_back("VK_KHR_shader_non_semantic_info"); +#endif + + if (device->fp16) { + device_extensions.push_back("VK_KHR_shader_float16_int8"); + } + + if (device->coopmat_support) { + // Query supported shapes + std::vector cm_props; + + PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR = + (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"); + + uint32_t cm_props_num; + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr); + + cm_props.resize(cm_props_num); + + for (auto& prop : cm_props) { + prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; + } + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data()); + + VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size()); + + for (auto& prop : cm_props) { + VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope)); + + if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f32_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f32_support = true; + } + } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f16_support = true; + } + } + } + } + + if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { + // No suitable matmul mode found + GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. 
Disabling matrix cores.\n"); + device->coopmat_support = false; + } + } + + if (device->coopmat_support) { + device_extensions.push_back("VK_KHR_cooperative_matrix"); + } + + device->name = GGML_VK_NAME + std::to_string(idx); + + device_create_info = { + vk::DeviceCreateFlags(), + device_queue_create_infos, + {}, + device_extensions + }; + device_create_info.setPNext(&device_features2); + device->device = device->physical_device.createDevice(device_create_info); + + // Queues + ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false); + + // Shaders + // Disable matmul tile sizes early if performance low or not supported + switch (device->vendor_id) { +#ifndef GGML_VULKAN_RUN_TESTS + case VK_VENDOR_ID_AMD: + case VK_VENDOR_ID_INTEL: + device->mul_mat_l = false; + device->mul_mat_m = true; + device->mul_mat_s = true; + device->mul_mat_id_l = false; + device->mul_mat_id_m = true; + device->mul_mat_id_s = true; + break; + case VK_VENDOR_ID_APPLE: + device->mul_mat_l = false; + device->mul_mat_m = true; + device->mul_mat_s = false; + device->mul_mat_id_l = false; + device->mul_mat_id_m = true; + device->mul_mat_id_s = false; + break; +#endif + default: + device->mul_mat_l = true; + device->mul_mat_m = true; + device->mul_mat_s = true; + device->mul_mat_id_l = true; + device->mul_mat_id_m = true; + device->mul_mat_id_s = true; + break; + } + + ggml_vk_load_shaders(device); + + if (!device->single_queue) { + const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; + ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); + } else { + // TODO: Use pointer or reference to avoid copy + device->transfer_queue = device->compute_queue; + } + + device->buffer_type = { + /* .iface = */ ggml_backend_vk_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx), + /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device }, + }; + + device->fence = device->device.createFence({}); + + device->idx = idx; + + return device; + } + + return vk_instance.devices[idx]; +} + +static void ggml_vk_print_gpu_info(size_t idx) { + GGML_ASSERT(idx < vk_instance.device_indices.size()); + size_t dev_num = vk_instance.device_indices[idx]; + VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")"); + GGML_ASSERT(vk_instance_initialized); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + if (dev_num >= devices.size()) { + std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." 
<< std::endl; + throw std::runtime_error("Device not found"); + } + + vk::PhysicalDevice physical_device = devices[dev_num]; + std::vector ext_props = physical_device.enumerateDeviceExtensionProperties(); + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + physical_device.getProperties2(&props2); + + const size_t subgroup_size = subgroup_props.subgroupSize; + const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + + bool fp16_storage = false; + bool fp16_compute = false; + bool coopmat_support = false; + bool coopmat2_support = false; + + for (auto properties : ext_props) { + if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { + fp16_storage = true; + } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { + fp16_compute = true; + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + coopmat_support = true; +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; +#endif + } + } + + if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) { + coopmat_support = false; + } + + const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); + bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; + + bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + + vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures(); + + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = nullptr; + device_features2.features = (VkPhysicalDeviceFeatures)device_features; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + VkPhysicalDeviceVulkan12Features vk12_features; + vk12_features.pNext = nullptr; + vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; + vk11_features.pNext = &vk12_features; + + // Pointer to the last chain element + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; + + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } + + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); + + fp16 = fp16 && vk12_features.shaderFloat16; + + coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix; + + std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? 
"KHR_coopmat" : "none"; + + std::string device_name = props2.properties.deviceName.data(); + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str()); + + if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { + GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); + } +} + +static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); + +void ggml_vk_instance_init() { + if (vk_instance_initialized) { + return; + } + VK_LOG_DEBUG("ggml_vk_instance_init()"); + + vk_instance_initialized = true; + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; + + const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); + const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); +#ifdef __APPLE__ + const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); +#endif + + std::vector layers; + + if (validation_ext) { + layers.push_back("VK_LAYER_KHRONOS_validation"); + } + std::vector extensions; + if (validation_ext) { + extensions.push_back("VK_EXT_validation_features"); + } +#ifdef __APPLE__ + if (portability_enumeration_ext) { + extensions.push_back("VK_KHR_portability_enumeration"); + } +#endif + vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); +#ifdef __APPLE__ + if (portability_enumeration_ext) { + instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; + } +#endif + + std::vector features_enable; + vk::ValidationFeaturesEXT validation_features; + + if (validation_ext) { + features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; + validation_features = { + features_enable, + {}, + }; + validation_features.setPNext(nullptr); + instance_create_info.setPNext(&validation_features); + GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); + } + vk_instance.instance = vk::createInstance(instance_create_info); + + size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + + // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan + char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); + if (devices_env != nullptr) { + std::string devices(devices_env); + std::replace(devices.begin(), devices.end(), ',', ' '); + + std::stringstream ss(devices); + size_t tmp; + while (ss >> tmp) { + if(tmp >= num_available_devices) { + std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl; + throw std::runtime_error("Invalid Vulkan device index"); + } + vk_instance.device_indices.push_back(tmp); + } + } else { + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + // Make sure at least one device exists + if (devices.empty()) { + std::cerr << "ggml_vulkan: Error: No devices found." 
<< std::endl; + GGML_ABORT("fatal error"); + } + + // Default to using all dedicated GPUs + for (size_t i = 0; i < devices.size(); i++) { + vk::PhysicalDeviceProperties2 new_props; + vk::PhysicalDeviceDriverProperties new_driver; + vk::PhysicalDeviceIDProperties new_id; + new_props.pNext = &new_driver; + new_driver.pNext = &new_id; + devices[i].getProperties2(&new_props); + + if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) { + // Check if there are two physical devices corresponding to the same GPU + auto old_device = std::find_if( + vk_instance.device_indices.begin(), + vk_instance.device_indices.end(), + [&devices, &new_id](const size_t k){ + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceIDProperties old_id; + old_props.pNext = &old_id; + devices[k].getProperties2(&old_props); + return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); + } + ); + if (old_device == vk_instance.device_indices.end()) { + vk_instance.device_indices.push_back(i); + } else { + // There can be two physical devices corresponding to the same GPU if there are 2 different drivers + // This can cause error when splitting layers aross the devices, need to keep only 1 + VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID"); + + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceDriverProperties old_driver; + old_props.pNext = &old_driver; + devices[*old_device].getProperties2(&old_props); + + std::map driver_priorities {}; + int old_priority = std::numeric_limits::max(); + int new_priority = std::numeric_limits::max(); + + // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id + // Smaller number -> higher priority + switch (old_props.properties.vendorID) { + case VK_VENDOR_ID_AMD: + driver_priorities[vk::DriverId::eMesaRadv] = 1; + driver_priorities[vk::DriverId::eAmdOpenSource] = 2; + driver_priorities[vk::DriverId::eAmdProprietary] = 3; + break; + case VK_VENDOR_ID_INTEL: + driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1; + driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2; + break; + case VK_VENDOR_ID_NVIDIA: + driver_priorities[vk::DriverId::eNvidiaProprietary] = 1; +#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235 + driver_priorities[vk::DriverId::eMesaNvk] = 2; +#endif + break; + } + + if (driver_priorities.count(old_driver.driverID)) { + old_priority = driver_priorities[old_driver.driverID]; + } + if (driver_priorities.count(new_driver.driverID)) { + new_priority = driver_priorities[new_driver.driverID]; + } + + if (new_priority < old_priority) { + auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device); + vk_instance.device_indices.erase(r, vk_instance.device_indices.end()); + vk_instance.device_indices.push_back(i); + + VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName); + } + else { + VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl); + } + } + } + } + + // If no dedicated GPUs found, fall back to GPU 0 + if (vk_instance.device_indices.empty()) { + vk_instance.device_indices.push_back(0); + } + } + GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); + + for (size_t i = 
0; i < vk_instance.device_indices.size(); i++) { + ggml_vk_print_gpu_info(i); + } +} + +static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { + VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")"); + ggml_vk_instance_init(); + GGML_ASSERT(idx < vk_instance.device_indices.size()); + + ctx->name = GGML_VK_NAME + std::to_string(idx); + + ctx->device = ggml_vk_get_device(idx); + + ctx->semaphore_idx = 0; + ctx->event_idx = 0; + + ctx->prealloc_size_x = 0; + ctx->prealloc_size_y = 0; + ctx->prealloc_size_split_k = 0; + + ctx->fence = ctx->device->device.createFence({}); + +#ifdef GGML_VULKAN_CHECK_RESULTS + const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); + vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); + const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR"); + vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor)); +#endif +} + +static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) { + VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant[type]; +} + +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f32; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f32_f16; + } + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f32acc; + } + } + + if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { + return nullptr; + } + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + if (ctx->device->coopmat2) { + assert(src1_type == GGML_TYPE_F16); + return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc; + } + return ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); + GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); + + switch (a_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1]; +} + +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()"); + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f32; + } + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f32acc; + } + } + + GGML_ASSERT(src1_type == GGML_TYPE_F32); + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); + GGML_ASSERT(b_type == GGML_TYPE_F32); + + switch (a_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; +} + +static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { + VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")"); + VK_LOG_MEMORY("ggml_vk_pool_malloc"); + + int best_i = -1; + size_t best_size = std::numeric_limits::max(); //smallest unused buffer that fits our needs + int worst_i = -1; + size_t worst_size = 0; //largest unused buffer seen so far + for (int i = 0; i < MAX_VK_BUFFERS; ++i) { + vk_buffer &b = ctx->buffer_pool[i]; + if (b != nullptr && b->size >= size && b->size < best_size) { + best_i = i; + best_size = b->size; + } + if (b != nullptr && b->size > worst_size) { + worst_i = i; + worst_size = b->size; + } + } + if(best_i != -1) { + //found the smallest buffer that fits our needs + vk_buffer b = ctx->buffer_pool[best_i]; + ctx->buffer_pool[best_i].reset(); + return b; + } + if(worst_i != -1) { + //no buffer that fits our needs, resize largest one to save memory + vk_buffer& b = ctx->buffer_pool[worst_i]; + ggml_vk_destroy_buffer(b); + } + + return ggml_vk_create_buffer_device(ctx->device, size); +} + +static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) { + VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")"); + for (int i = 0; i < MAX_VK_BUFFERS; ++i) { + vk_buffer& b = ctx->buffer_pool[i]; + if (b == nullptr) { + b = buffer; + return; + } + } + std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl; + ggml_vk_destroy_buffer(buffer); +} + +// Returns an available temporary buffer that may only be used temporarily, it will be reused +static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) { + // Try to find existing temp buffer with enough capacity + for (auto& buffer : ctx->gc.temp_buffers) { + if (buffer->size >= size) { + return buffer; + } + } + + VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")"); + + // Otherwise create new buffer + vk_buffer buf = ggml_vk_pool_malloc(ctx, size); + ctx->gc.temp_buffers.push_back(buf); + + return buf; +} + +static void * ggml_vk_host_malloc(vk_device& device, size_t size) { + VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); + vk_buffer buf = ggml_vk_create_buffer(device, size, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + + if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) { + fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n", + size/1024.0/1024.0); + device->device.freeMemory(buf->device_memory); + device->device.destroyBuffer(buf->buffer); + return nullptr; + } + + 
device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); + + return buf->ptr; +} + +static void ggml_vk_host_free(vk_device& device, void* ptr) { + if (ptr == nullptr) { + return; + } + VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); + vk_buffer buf; + size_t index; + for (size_t i = 0; i < device->pinned_memory.size(); i++) { + const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); + const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); + if (ptr >= addr && ptr < endr) { + buf = std::get<2>(device->pinned_memory[i]); + index = i; + break; + } + } + if (buf == nullptr) { + fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n"); + return; + } + + ggml_vk_destroy_buffer(buf); + + device->pinned_memory.erase(device->pinned_memory.begin() + index); +} + +static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { + buf = nullptr; + buf_offset = 0; + for (size_t i = 0; i < device->pinned_memory.size(); i++) { + const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); + const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); + if (ptr >= addr && ptr < endr) { + buf = std::get<2>(device->pinned_memory[i]); + buf_offset = ((const uint8_t *)ptr) - addr; + break; + } + } +} + +static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) { + vk_submission s; + s.buffer = ggml_vk_create_cmd_buffer(device, q); + if (one_time) { + s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); + } else { + s.buffer.begin({ vk::CommandBufferUsageFlags{} }); + } + + return s; +} + + + +static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array elements) { + const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); + const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); + const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); + VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {"; + for (auto& buffer : descriptor_buffer_infos) { + std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; + } + std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); + GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size()); + GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count); + + vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++]; + vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; + ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); + + subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants); + subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); + subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, + pipeline->layout, + 0, + { descriptor_set }, + {}); + subctx->s->buffer.dispatch(wg0, wg1, wg2); +} + +static void ggml_vk_end_submission(vk_submission& s, std::vector wait_semaphores, std::vector signal_semaphores) { + s.buffer.end(); + + s.wait_semaphores = std::move(wait_semaphores); + s.signal_semaphores = 
std::move(signal_semaphores);
+}
+
+static void ggml_vk_ctx_end(vk_context& ctx) {
+    VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
+    if (ctx->s == nullptr) {
+        return;
+    }
+
+    ctx->s->buffer.end();
+    ctx->s = nullptr;
+}
+
+static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
+    VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
+    if (subctx->s != nullptr) {
+        ggml_vk_ctx_end(subctx);
+    }
+
+    subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) });
+    subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
+}
+
+static size_t ggml_vk_align_size(size_t width, size_t align) {
+    VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
+    return CEIL_DIV(width, align) * align;
+}
+
+static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys = nullptr) {
+    if (memcpys == nullptr) {
+        memcpy(dst, src, size);
+    } else {
+        memcpys->emplace_back(dst, src, size);
+    }
+}
+
+static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
+    if (device->sync_staging == nullptr || device->sync_staging->size < size) {
+        VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
+        ggml_vk_destroy_buffer(device->sync_staging);
+        device->sync_staging = ggml_vk_create_buffer_check(device, size,
+            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
+            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+    }
+}
+
+static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
+    VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
+    GGML_ASSERT(!ggml_is_contiguous(tensor));
+    // Buffer is already mapped
+    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
+        std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write."
<< std::endl; + GGML_ABORT("fatal error"); + } + // Check if src is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset); + + const uint64_t ne0 = tensor->ne[0]; + const uint64_t ne1 = tensor->ne[1]; + const uint64_t ne2 = tensor->ne[2]; + const uint64_t ne3 = tensor->ne[3]; + const uint64_t nb0 = tensor->nb[0]; + const uint64_t nb1 = tensor->nb[1]; + const uint64_t nb2 = tensor->nb[2]; + const uint64_t nb3 = tensor->nb[3]; + const ggml_type type = tensor->type; + const uint64_t ts = ggml_type_size(type); + const uint64_t bs = ggml_blck_size(type); + + const uint64_t dstnb0 = ts; + const uint64_t dstnb1 = dstnb0*(ne0/bs); + const uint64_t dstnb2 = dstnb1*ne1; + const uint64_t dstnb3 = dstnb2*ne2; + + const uint64_t ne = ggml_nelements(tensor); + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + std::vector slices; + + for (uint64_t i3 = 0; i3 < ne3; i3++) { + for (uint64_t i2 = 0; i2 < ne2; i2++) { + // Find longest contiguous slice + if (ne1*nb1 == dstnb2) { + slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 }); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + if (ne0*nb0/bs == dstnb1) { + slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 }); + } else { + const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; + const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + for (uint64_t i0 = 0; i0 < ne0; i0++) { + slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); + } + } + } + } + } + } + + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + return; + } + + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + + // Staging buffer required + vk_buffer& staging = ctx->device->sync_staging; + const uint64_t copy_size = ts*ne/bs; + ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); + VkBufferCopy buf_copy{ 0, offset, copy_size }; + + ggml_vk_sync_buffers(subctx); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + + for (uint64_t i3 = 0; i3 < ne3; i3++) { + for (uint64_t i2 = 0; i2 < ne2; i2++) { + // Find longest contiguous slice + if (ne1*nb1 == dstnb2) { + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + if (ne0*nb0/bs == dstnb1) { + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys); + } else { + const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; + const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + for (uint64_t i0 = 0; i0 < ne0; i0++) { + deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys); + } + } + } + } + } + } +} + +static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + std::cerr << "ggml_vulkan: 
buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl; + GGML_ABORT("fatal error"); + } + // Check if src is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(dst->device, src, buf, buf_offset); + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + std::vector slices(1); + if (width == spitch) { + // Only do single write if stride is equal + slices[0].srcOffset = buf_offset; + slices[0].dstOffset = offset; + slices[0].size = width * height; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = buf_offset + i * spitch; + slices[i].dstOffset = offset + i * width; + slices[i].size = width; + } + } + + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + return; + } + VK_LOG_DEBUG("STAGING"); + + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + + // Staging buffer required + const size_t copy_size = width*height; + ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); + + vk_buffer& staging_buffer = dst->device->sync_staging; + + VkBufferCopy buf_copy = { + 0, + offset, + copy_size}; + + ggml_vk_sync_buffers(subctx); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + + if (width == spitch) { + deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); + } else { + for (size_t i = 0; i < height; i++) { + deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); + } + } +} + +static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); + return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); +} + +static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { + VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); + + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); + } + } else { + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + ggml_vk_ctx_begin(dst->device, subctx); + ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + ggml_vk_ctx_end(subctx); + + for (auto& cpy : subctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ggml_vk_submit(subctx, dst->device->fence); + VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); + dst->device->device.resetFences({ dst->device->fence }); + } +} + +static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); + ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); +} + +static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { + 
VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")"); + GGML_ASSERT(width > 0); + GGML_ASSERT(height > 0); + GGML_ASSERT(src != nullptr); + + // TODO: staging_offset is not used + + // Check if dst is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(src->device, dst, buf, buf_offset); + + std::vector slices(1); + if (width == spitch && width == dpitch) { + // Only do single write if stride is equal + slices[0].srcOffset = offset; + slices[0].dstOffset = buf_offset; + slices[0].size = width * height; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = offset + i * spitch; + slices[i].dstOffset = buf_offset + i * dpitch; + slices[i].size = width; + } + } + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); + + return; + } + VK_LOG_DEBUG("STAGING"); + + if (!sync_staging) { + GGML_ABORT("Asynchronous read from non-pinned memory not supported"); + } + + // Fall back to staging buffer + const size_t copy_size = dpitch * height; + ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); + + vk_buffer& staging_buffer = src->device->sync_staging; + + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); + + deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); +} + +static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) { + return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); +} + +static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); + + // If the device is not an UMA device the memory is host-accessible through rebar. While writing + // through PCIe is sufficient fast reading back data from PCIe is slower than going through + // the HW device to host copy path. 
+ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { + GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); + + memcpy(dst, (uint8_t *) src->ptr + offset, size); + } else { + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + ggml_vk_ctx_begin(src->device, subctx); + ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); + ggml_vk_ctx_end(subctx); + + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); + src->device->device.resetFences({ src->device->fence }); + + for (auto& cpy : subctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + } +} + +static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); + // Make sure both buffers are on same device + GGML_ASSERT(src->device == dst->device); + + VkBufferCopy bc{ src_offset, dst_offset, size }; + + vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); +} + +static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { + if (src->device == dst->device) { + VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); + // Copy within the device + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + ggml_vk_ctx_begin(src->device, subctx); + ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); + ggml_vk_ctx_end(subctx); + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); + src->device->device.resetFences({ src->device->fence }); + } else { + VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); + // Copy device to device + ggml_vk_ensure_sync_staging_buffer(src->device, size); + ggml_vk_ensure_sync_staging_buffer(dst->device, size); + + // Copy to src staging buffer + ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); + // memcpy to dst staging buffer + memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size); + // Copy to dst buffer + ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size); + } +} + +static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); + + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + ggml_vk_ctx_begin(dst->device, subctx); + subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); + ggml_vk_ctx_end(subctx); + + ggml_vk_submit(subctx, dst->device->fence); + VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); + dst->device->device.resetFences({ dst->device->fence }); +} + +static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); + + uint32_t split_k = 1; + if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) { + // If k is 'large' and the SMs will fill less than halfway, use split_k. 
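+    // Worked example with illustrative numbers (not taken from a specific GPU): say
+    // shader_core_count = 32, wg_denoms = {128, 128}, m = 256, n = 128, k = 4096.
+    // Then m_tiles = 2 and n_tiles = 1, so only 2 tiles would be in flight, which is
+    // less than 32 / 2 = 16. Since k >= 2048, split_k = 32 / 2 = 16, and the clamp
+    // below reduces it to 4: the K dimension is split four ways and the partial
+    // results are summed by the split_k_reduce pipeline afterwards.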
+ uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); + uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); + if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) { + split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); + // Clamp to 2 or 4 + split_k = std::min(split_k, 4u); + if (split_k == 3) { + split_k = 2; + } + } + } + + return split_k; +} + +static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); + + if (ctx->device->coopmat2) { + if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) { + return aligned ? mmp->a_l : mmp->l; + } + if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_l : mmp->l; +} + +static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align; +} + +static void ggml_vk_matmul( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) { + VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? 
split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")"); + ggml_vk_sync_buffers(subctx); + if (split_k == 1) { + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); + return; + } + + GGML_ASSERT(batch_stride_d == m * n); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 }; + // Make sure enough workgroups get assigned for split k to work + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_vk_sync_buffers(subctx); + const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); +} + +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); + + if (ctx->device->coopmat2) { + if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) { + return aligned ? mmp->a_l : mmp->l; + } + if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? 
mmp->a_l : mmp->l; +} + +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align; +} + +static void ggml_vk_matmul_id( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) { + VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << + "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << + "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << + "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); + ggml_vk_sync_buffers(subctx); + const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, + nei0, nei1, nbi1, ne11 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as }); +} + +static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { + + // Choose "contiguous copy" shader if src/dst are contiguous + bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); + + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_cpy_f32_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f16; + } else { + return ctx->device->pipeline_cpy_f32_f16; + } + } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f16; + } else { + return ctx->device->pipeline_cpy_f16_f16; + } + } + + std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; + GGML_ABORT("fatal error"); +} + +static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) { + VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << 
tensor->nb[3] << "), "; + std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); + const int tensor_type_size = ggml_type_size(tensor->type); + + const uint32_t ne = ggml_nelements(tensor); + std::array elements; + + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + vk_op_unary_push_constants pc = { + (uint32_t)ne, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]), + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + init_pushconst_fastdiv(pc); + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); +} + +static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src1); + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + + if (qx_needs_dequant) { + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const int x_ne = ne01 * ne00; + const int y_ne = ne11 * ne10; + const int d_ne = ne11 * ne01; + + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); + const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned); + + const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || + (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); + } + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + } + if (y_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, 
to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + // compute + ggml_vk_matmul( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + ne01, ne11, ne10, + ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, + split_k, ne12*ne13, ne02, ne12, r2, r3 + ); // NOLINT +} + +static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + // batch_n indicates that we need to compute a few vector results, and this assumes + // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides. 
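+    // Example (assuming 4 <= mul_mat_vec_max_cols): a src1 with ne = {ne10, 4, 1, 1}
+    // multiplies four vectors against the same matrix. batch_n becomes true, the
+    // mul_mat_vec pipeline variant for ne11 columns is picked further down,
+    // stride_batch_x is set to 0 (A is reused for every result), and stride_batch_y /
+    // stride_batch_d carry the row strides ne10 / ne20, so all four results are
+    // produced by a single dispatch.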
+ GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1); + bool batch_n = ne11 > 1; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne11 * ne01; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride + uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01; + uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11); + uint32_t stride_batch_d = batch_n ? 
ne20 : (ne20*ne21); + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + uint32_t groups_x = ne01; + uint32_t groups_z = 1; + + if (ne01 > max_groups_x) { + groups_z = 64; + groups_x = CEIL_DIV(groups_x, groups_z); + } + + // compute + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + stride_batch_x, stride_batch_y, stride_batch_d, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, + sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); +} + +static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); + GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + // const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + // const uint64_t ne13 = src1->ne[3]; + + GGML_ASSERT(ne11 == 1); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src1_uma = d_Qy != nullptr; + } + + const uint64_t x_ne = ne00 * ne01 * ne02; + const uint64_t y_ne = ne10 * ne11 * ne12; + const uint64_t d_ne = ne01 * ne11 * ne12; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t d_sz = sizeof(float) * d_ne; + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_Qx = src0_buf_ctx->dev_buffer; + const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + + const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; + + const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; + + // compute + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); +} + +static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) 
{ + VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(!ggml_is_permuted(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + // const uint64_t ne03 = src0->ne[3]; + + const uint64_t nb01 = src0->nb[1]; + const uint64_t nb02 = src0->nb[2]; + + // const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + // const uint64_t ne13 = src1->ne[3]; + + GGML_ASSERT(ne11 == 1); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src1_uma = d_Qy != nullptr; + } + + const uint64_t d_ne = ne01 * ne11 * ne12; + + const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); + const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); + + const uint64_t qx_sz = ggml_nbytes(src0); + const uint64_t qy_sz = ggml_nbytes(src1); + const uint64_t d_sz = sizeof(float) * d_ne; + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_Qx = src0_buf_ctx->dev_buffer; + const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + + const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; + + const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t d_shader_offset = d_buf_offset - 
d_buffer_offset; + + // compute + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); +} + +static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); + if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && + // detect 0213 permutation, and batch size of 1 + src0->nb[0] <= src0->nb[2] && + src0->nb[2] <= src0->nb[1] && + src0->nb[1] <= src0->nb[3] && + src1->nb[0] <= src1->nb[2] && + src1->nb[2] <= src1->nb[1] && + src1->nb[1] <= src1->nb[3] && + src0->ne[3] == 1 && + src1->ne[3] == 1) { + ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun); + } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && + !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { + ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun); + // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) + // when ne12 and ne13 are one. + } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); + } else { + ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); + } +} + +static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); + 
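+    // Indirect "mul_mat_id" matrix multiply: the I32 `ids` tensor selects, for each output row,
+    // which of the n_as (= ne02) expert slices of src0 is multiplied with src1 (the path used by
+    // mixture-of-experts models). The number of (expert, token) selections handled per dispatch
+    // is capped by the nei0 * nei1 assert below.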
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t nei0 = ids->ne[0]; + const uint64_t nei1 = ids->ne[1]; + GGML_ASSERT(nei0 * nei1 <= 3072); + + const uint32_t nbi1 = ids->nb[1]; + const uint32_t nbi2 = ids->nb[2]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t n_as = ne02; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + ids_uma = d_ids != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + + if (qx_needs_dequant) { + GGML_ABORT("fatal error"); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne21 * ne20; + + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1)); + const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned); + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + } + if (y_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / 
ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + // compute + ggml_vk_matmul_id( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, + ne01, ne21, ne10, ne10, ne10, ne01, + stride_batch_x, stride_batch_y, ne20*ne21, + n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11 + ); // NOLINT +} + +static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t nei0 = ids->ne[0]; + const uint64_t nei1 = ids->ne[1]; + + const uint64_t nbi2 = ids->nb[2]; + + GGML_ASSERT(nei1 == 1); + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + ids_uma = d_ids != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne21 * ne20; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if(!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + uint32_t groups_x = ne01; + uint32_t groups_z = 1; + + if (ne01 > max_groups_x) { + groups_z = 64; + groups_x = CEIL_DIV(groups_x, groups_z); + } + + // compute + const vk_mat_vec_id_push_constants pc = { + 
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), + (uint32_t)nei0, (uint32_t)ne11, + }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, + vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, + sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z }); +} + +static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")"); + if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + } else { + ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + } +} + +static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; + std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; + std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const uint32_t nem1 = mask ? mask->ne[1] : 0; + const uint32_t nbm1 = mask ? 
mask->nb[1] : 0; + + const uint32_t D = neq0; + const uint32_t N = neq1; + const uint32_t KV = nek1; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(nev1 == nek1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + assert(dst->type == GGML_TYPE_F32); + assert(q->type == GGML_TYPE_F32); + assert(k->type == v->type); + + vk_pipeline *pipelines; + // XXX TODO other backends may be changing accumulator precision to default to f32 soon + bool f32acc = dst->op_params[3] == GGML_PREC_F32; + bool small_rows = N <= flash_attention_num_small_rows; + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; + default: + assert(!"unsupported D value"); + return; + } + assert(pipelines); + + bool aligned = (KV % pipelines[1]->align) == 0; + vk_pipeline pipeline = pipelines[aligned]; + assert(pipeline); + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head_kv = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; + size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; + + bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); + ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset); + Q_uma = d_Q != nullptr; + K_uma = d_K != nullptr; + V_uma = d_V != nullptr; + D_uma = d_D != nullptr; + if (mask) { + ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset); + M_uma = d_M != nullptr; + } + } + + + ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * q_buf_ctx = 
(ggml_backend_vk_buffer_context *)q->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + + if (!Q_uma) { + d_Q = q_buf_ctx->dev_buffer; + q_buf_offset = vk_tensor_offset(q) + q->view_offs; + } + if (!K_uma) { + d_K = k_buf_ctx->dev_buffer; + k_buf_offset = vk_tensor_offset(k) + k->view_offs; + } + if (!V_uma) { + d_V = v_buf_ctx->dev_buffer; + v_buf_offset = vk_tensor_offset(v) + v->view_offs; + } + if (!D_uma) { + d_D = d_buf_ctx->dev_buffer; + d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + if (!M_uma) { + d_M = d_Q; + m_buf_offset = q_buf_offset; + if (mask) { + ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context; + d_M = m_buf_ctx->dev_buffer; + m_buf_offset = vk_tensor_offset(mask) + mask->view_offs; + } + } + + const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); +} + +static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { + switch (op) { + case GGML_OP_GET_ROWS: + GGML_ASSERT(src1->type == GGML_TYPE_I32); + if (dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_get_rows[src0->type]; + } + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_get_rows_f32[src0->type]; + } + return nullptr; + case GGML_OP_ACC: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_acc_f32; + } + return nullptr; + case GGML_OP_ADD: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; + } + return nullptr; + case GGML_OP_MUL: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; + } + return nullptr; + case GGML_OP_DIV: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32; + } + return nullptr; + case GGML_OP_CONCAT: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_concat_f32; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_concat_f16; + } + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_concat_i32; + } + return nullptr; + case GGML_OP_UPSCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_upscale_f32; + } + return nullptr; + case GGML_OP_SCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_scale_f32; + } + return nullptr; + case GGML_OP_SQR: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sqr_f32; + } + return nullptr; + case GGML_OP_SIN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sin_f32; + } + return nullptr; + case GGML_OP_COS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cos_f32; + } + return nullptr; + case GGML_OP_CLAMP: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_clamp_f32; + } + return nullptr; + case GGML_OP_PAD: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pad_f32; + } + return nullptr; + case GGML_OP_REPEAT: + if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { + return ctx->device->pipeline_repeat_f32; + } + return nullptr; + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); + case GGML_OP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_norm_f32; + } + return nullptr; + case GGML_OP_GROUP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_group_norm_f32; + } + return nullptr; + case GGML_OP_RMS_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rms_norm_f32; + } + return nullptr; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_SILU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_silu_f32; + } + break; + case GGML_UNARY_OP_GELU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_gelu_f32; + } + break; + case GGML_UNARY_OP_GELU_QUICK: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_gelu_quick_f32; + } + break; + case GGML_UNARY_OP_RELU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_relu_f32; + } + break; + case GGML_UNARY_OP_TANH: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_tanh_f32; + } + break; + default: + break; + } + return nullptr; + case GGML_OP_DIAG_MASK_INF: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_diag_mask_inf_f32; + } + return nullptr; + case GGML_OP_SOFT_MAX: + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + + if 
(src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; + } + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; + } + return nullptr; + case GGML_OP_ROPE: + { + const int mode = ((const int32_t *) dst->op_params)[2]; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + + if (is_neox) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_neox_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_neox_f16; + } + } else { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_norm_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_norm_f16; + } + } + return nullptr; + } + case GGML_OP_ARGSORT: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_argsort_f32; + } + return nullptr; + case GGML_OP_SUM_ROWS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sum_rows_f32; + } + return nullptr; + case GGML_OP_IM2COL: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_f32_f16; + } + return nullptr; + case GGML_OP_TIMESTEP_EMBEDDING: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_timestep_embedding_f32; + } + return nullptr; + case GGML_OP_POOL_2D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pool2d_f32; + } + return nullptr; + case GGML_OP_RWKV_WKV6: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv6_f32; + } + return nullptr; + case GGML_OP_LEAKY_RELU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_leaky_relu_f32; + } + return nullptr; + default: + return nullptr; + } + + GGML_UNUSED(src2); +} + +static bool ggml_vk_op_supports_incontiguous(ggml_op op) { + switch (op) { + case GGML_OP_CPY: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_REPEAT: + return true; + default: + return false; + } +} + +static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t) +{ + return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));; +} + +template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + GGML_UNUSED(p); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(dst); + static_assert(!std::is_const::value, "unexpected type"); + GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); + GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); + GGML_ASSERT(!src2 || 
get_misalign_bytes(ctx, src2) == 0); + GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0)); + + p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; + + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.a_offset = a_offset; + p.d_offset = d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template +static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + if (src1 != nullptr) { + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + } + if (src2 != nullptr) { + std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; + } + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? 
"dryrun" : "") << ")"); + GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT + GGML_ASSERT(dst->buffer != nullptr); + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + const uint64_t ne0 = ne00 * ne01; + + const bool use_src1 = src1 != nullptr; + const uint64_t ne10 = use_src1 ? src1->ne[0] : 0; + const uint64_t ne11 = use_src1 ? src1->ne[1] : 0; + const uint64_t ne12 = use_src1 ? src1->ne[2] : 0; + const uint64_t ne13 = use_src1 ? src1->ne[3] : 0; + const uint64_t ne1 = ne10 * ne11; + // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0; + + const bool use_src2 = src2 != nullptr; + const uint64_t ne20 = use_src2 ? src2->ne[0] : 0; + const uint64_t ne21 = use_src2 ? src2->ne[1] : 0; + const uint64_t ne22 = use_src2 ? src2->ne[2] : 0; + const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; + const uint64_t ne2 = ne20 * ne21; + + const uint64_t ned0 = dst->ne[0]; + const uint64_t ned1 = dst->ne[1]; + const uint64_t ned2 = dst->ne[2]; + const uint64_t ned3 = dst->ne[3]; + const uint64_t ned = ned0 * ned1; + + init_pushconst_fastdiv(pc); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); + + if (pipeline == nullptr) { + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type); + if (src1 != nullptr) { + std::cerr << " and " << ggml_type_name(src1->type); + } + std::cerr << " to " << ggml_type_name(dst->type) << std::endl; + GGML_ABORT("fatal error"); + } + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; + ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; + + vk_buffer d_X = nullptr; + size_t x_buf_offset = 0; + vk_buffer d_Y = nullptr; + size_t y_buf_offset = 0; + vk_buffer d_Z = nullptr; + size_t z_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool src2_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset); + src0_uma = d_X != nullptr; + if (use_src1) { + ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset); + src1_uma = d_Y != nullptr; + } + if (use_src2) { + ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset); + src2_uma = d_Z != nullptr; + } + } + + uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0; + uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0; + uint64_t z_sz = use_src2 ? 
ggml_type_size(src2->type) * ne2 : 0; + uint64_t d_sz = ggml_type_size(dst->type) * ned; + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + + // Workaround for tiny tensor inputs on ROPE + if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) { + y_sz = VK_WHOLE_SIZE; + } + + GGML_ASSERT(d_D != nullptr); + uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + if(!src0_uma) { + d_X = src0_buf_ctx->dev_buffer; + x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_X != nullptr); + } + if (use_src1 && !src1_uma) { + d_Y = src1_buf_ctx->dev_buffer; + y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Y != nullptr); + } + if (use_src2 && !src2_uma) { + d_Z = src2_buf_ctx->dev_buffer; + z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; + GGML_ASSERT(d_Z != nullptr); + } + // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets. + init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst); + x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + + if (op_supports_incontiguous) { + x_sz = ggml_nbytes(src0); + y_sz = use_src1 ? ggml_nbytes(src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) : 0; + d_sz = ggml_nbytes(dst); + + if (x_buf_offset + x_sz >= d_X->size) { + x_sz = VK_WHOLE_SIZE; + } + if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { + y_sz = VK_WHOLE_SIZE; + } + if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { + z_sz = VK_WHOLE_SIZE; + } + if (d_buf_offset + d_sz >= d_D->size) { + d_sz = VK_WHOLE_SIZE; + } + } + + std::array elements; + + // Single call if dimension 2 is contiguous + GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); + + switch (op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_SOFT_MAX: + case GGML_OP_SUM_ROWS: + { + const uint32_t nr = ggml_nrows(src0); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } break; + case GGML_OP_GROUP_NORM: + { + const uint32_t num_groups = dst->op_params[0]; + elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; + } break; + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_ROPE: + elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; + break; + case GGML_OP_GET_ROWS: + elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; + break; + case GGML_OP_ARGSORT: + elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; + break; + case GGML_OP_IM2COL: + { + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t batch = src1->ne[is_2D ? 
3 : 2]; + + elements = { OW * KW * KH, OH, batch * IC }; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + const uint32_t dim = dst->op_params[0]; + uint32_t half_ceil = (dim + 1) / 2; + elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; + } break; + case GGML_OP_POOL_2D: + { + const uint32_t N = dst->ne[3]; + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + elements = { N * OC * OH * OW, 1, 1}; + } break; + case GGML_OP_ADD: + case GGML_OP_DIV: + case GGML_OP_MUL: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_REPEAT: + case GGML_OP_CPY: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_UNARY: + { + const uint32_t ne = ggml_nelements(dst); + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + } break; + default: + elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; + break; + } + + if (!op_supports_incontiguous) { + if (x_sz != VK_WHOLE_SIZE) { + x_sz *= ne02 * ne03; + } + if (use_src1 && y_sz != VK_WHOLE_SIZE) { + y_sz *= ne12 * ne13; + } + if (use_src2 && z_sz != VK_WHOLE_SIZE) { + z_sz *= ne22 * ne23; + } + if (d_sz != VK_WHOLE_SIZE) { + d_sz *= ned2 * ned3; + } + } + + if (op == GGML_OP_SOFT_MAX) { + // Empty src1 is possible in soft_max, but the shader needs a buffer + vk_subbuffer subbuf_y; + if (use_src1) { + subbuf_y = { d_Y, y_buf_offset, y_sz }; + } else { + subbuf_y = { d_X, 0, x_sz }; + } + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (op == GGML_OP_ROPE) { + // Empty src2 is possible in rope, but the shader needs a buffer + vk_subbuffer subbuf_z; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; + } + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (op == GGML_OP_IM2COL) { + // im2col uses only src1 and dst buffers + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (use_src2) { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (use_src1) { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } +} + +static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const 
uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, offset, + }, dryrun); +} + +static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], 
(uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * k = dst->src[0]; + const ggml_tensor * v = dst->src[1]; + const ggml_tensor * r = dst->src[2]; + const ggml_tensor * tf = dst->src[3]; + const ggml_tensor * td = dst->src[4]; + const ggml_tensor * state = dst->src[5]; + + GGML_ASSERT(!ggml_is_quantized(k->type)); + GGML_ASSERT(!ggml_is_quantized(v->type)); + GGML_ASSERT(!ggml_is_quantized(r->type)); + GGML_ASSERT(!ggml_is_quantized(tf->type)); + GGML_ASSERT(!ggml_is_quantized(td->type)); + GGML_ASSERT(!ggml_is_quantized(state->type)); + GGML_ASSERT(dst->buffer != nullptr); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); + GGML_ASSERT(pipeline != 
nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; + ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; + ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; + ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr; + size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0; + bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, k->data, d_K, k_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_offset); + ggml_vk_host_get(ctx->device, r->data, d_R, r_offset); + ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset); + ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset); + ggml_vk_host_get(ctx->device, state->data, d_State, state_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); + + K_uma = d_K != nullptr; + V_uma = d_V != nullptr; + R_uma = d_R != nullptr; + TF_uma = d_TF != nullptr; + TD_uma = d_TD != nullptr; + STATE_uma = d_State != nullptr; + DST_uma = d_D != nullptr; + } + + if (!K_uma) { + d_K = k_buf_ctx->dev_buffer; + k_offset = vk_tensor_offset(k) + k->view_offs; + } + if (!V_uma) { + d_V = v_buf_ctx->dev_buffer; + v_offset = vk_tensor_offset(v) + v->view_offs; + } + if (!R_uma) { + d_R = r_buf_ctx->dev_buffer; + r_offset = vk_tensor_offset(r) + r->view_offs; + } + if (!TF_uma) { + d_TF = tf_buf_ctx->dev_buffer; + tf_offset = vk_tensor_offset(tf) + tf->view_offs; + } + if (!TD_uma) { + d_TD = td_buf_ctx->dev_buffer; + td_offset = vk_tensor_offset(td) + td->view_offs; + } + if (!STATE_uma) { + d_State = state_buf_ctx->dev_buffer; + state_offset = vk_tensor_offset(state) + state->view_offs; + } + if (!DST_uma) { + d_D = dst_buf_ctx->dev_buffer; + dst_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + const uint64_t k_size = ggml_nbytes(k); + const uint64_t v_size = ggml_nbytes(v); + const uint64_t r_size = ggml_nbytes(r); + const uint64_t tf_size = ggml_nbytes(tf); + const uint64_t td_size = ggml_nbytes(td); + const uint64_t state_size = ggml_nbytes(state); + const uint64_t dst_size = ggml_nbytes(dst); + + std::array elements = { + (uint32_t)(pc.B * pc.H), + 1, + 1 + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_K, k_offset, k_size }, + vk_subbuffer{ d_V, v_offset, v_size }, + vk_subbuffer{ d_R, r_offset, r_size }, + vk_subbuffer{ d_TF, tf_offset, tf_size }, + vk_subbuffer{ d_TD, td_offset, td_size }, + vk_subbuffer{ d_State, state_offset, state_size }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); +} + +static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = 
false) { + const size_t seq_length = dst->src[0]->ne[3]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[2]; + const size_t n_seqs = dst->src[5]->ne[1]; + + ggml_vk_op_f32_rwkv6( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + dryrun + ); +} + +static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + int * op_params = (int *)dst->op_params; + + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, op_params[0], + }, dryrun); +} + +static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + + const float sf0 = (float)dst->ne[0] / src0->ne[0]; + const float sf1 = (float)dst->ne[1] / src0->ne[1]; + const float sf2 = (float)dst->ne[2] / src0->ne[2]; + const float sf3 = (float)dst->ne[3] / src0->ne[3]; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { + (uint32_t)ggml_nelements(dst), 0, 0, + (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], + sf0, sf1, sf2, sf3, + }, dryrun); +} + +static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void 
ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) 
dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], op_params[1], + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + 
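+// The element-wise wrappers above all funnel into ggml_vk_op_f32. For GGML_OP_NORM
+// (just above) the push constants are { ne00, ne01, eps, 0.0f }, with eps taken from
+// dst->op_params[0]. As a rough, illustrative CPU sketch of the per-row computation
+// (assuming contiguous float32 data; `row`, `out`, `ne00` and `eps` are placeholder
+// names for this sketch, not part of the backend):
+//
+//     float mean = 0.0f;
+//     for (uint32_t i = 0; i < ne00; i++) mean += row[i];
+//     mean /= (float)ne00;
+//     float var = 0.0f;
+//     for (uint32_t i = 0; i < ne00; i++) { const float d = row[i] - mean; var += d * d; }
+//     var /= (float)ne00;
+//     const float scale = 1.0f / sqrtf(var + eps);
+//     for (uint32_t i = 0; i < ne00; i++) out[i] = (row[i] - mean) * scale;
+//
+// GGML_OP_RMS_NORM (below) skips the mean subtraction and scales by
+// 1.0f / sqrtf(mean(x*x) + eps) instead.
+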
+static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const int * int_op_params = (const int *)dst->op_params; + const float * float_op_params = (const float *)dst->op_params; + + const uint32_t num_groups = int_op_params[0]; + const float eps = float_op_params[1]; + const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); +} + +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); +} + +static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + + float scale = op_params[0]; + float max_bias = op_params[1]; + + const uint32_t ncols = (uint32_t)src0->ne[0]; + const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); + const uint32_t nrows_y = (uint32_t)src0->ne[1]; + + const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + ncols, + src1 != nullptr ? 
nrows_y : (uint32_t)0, + scale, max_bias, + m0, m1, + n_head_log2, + nrows_x, + }, dryrun); +} + +static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const int n_dims = ((int32_t *) dst->op_params)[1]; + // const int mode = ((int32_t *) dst->op_params)[2]; + // const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + const float freq_base = ((float *) dst->op_params)[5]; + const float freq_scale = ((float *) dst->op_params)[6]; + const float ext_factor = ((float *) dst->op_params)[7]; + const float attn_factor = ((float *) dst->op_params)[8]; + const float beta_fast = ((float *) dst->op_params)[9]; + const float beta_slow = ((float *) dst->op_params)[10]; + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { + (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, + src2 != nullptr, + }, dryrun); +} + +static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; + + uint32_t ncols = src0->ne[0]; + + uint32_t ncols_pad = 1; + while (ncols_pad < ncols) { + ncols_pad *= 2; + } + + GGML_ASSERT(ncols_pad <= 1024); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { + ncols, + ncols_pad, + op_params[0], + }, dryrun); +} + +static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const int32_t s0 = dst->op_params[0]; + const int32_t s1 = dst->op_params[1]; + const int32_t p0 = dst->op_params[2]; + const int32_t p1 = dst->op_params[3]; + const int32_t d0 = dst->op_params[4]; + const int32_t d1 = dst->op_params[5]; + + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t IW = src1->ne[0]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const uint32_t batch_offset = src1->nb[is_2D ? 
3 : 2] / 4; // nb is byte offset, src is type float32 + + const uint32_t pelements = OW * KW * KH; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { + batch_offset, offset_delta, + IC, IW, IH, OW, OH, KW, KH, + pelements, + IC * KH * KW, + s0, s1, p0, p1, d0, d1, + }, dryrun); +} + +static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t dim = dst->op_params[0]; + const uint32_t max_period = dst->op_params[1]; + const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { + nb1, dim, max_period, + }, dryrun); +} + +static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + uint32_t op = static_cast(dst->op_params[0]); + const int32_t k1 = dst->op_params[1]; + const int32_t k0 = dst->op_params[2]; + const int32_t s1 = dst->op_params[3]; + const int32_t s0 = dst->op_params[4]; + const int32_t p1 = dst->op_params[5]; + const int32_t p0 = dst->op_params[6]; + + const uint32_t IH = src0->ne[1]; + const uint32_t IW = src0->ne[0]; + + const uint32_t N = dst->ne[3]; + + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + + const uint32_t parallel_elements = N * OC * OH * OW; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, { + IW, IH, OW, OH, OC, + parallel_elements, + op, + k0, k1, s0, s1, p0, p1, + }, dryrun); +} + +static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const float * op_params = (const float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); +} + +#ifdef GGML_VULKAN_RUN_TESTS +static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) { + if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) { + float val; + if (type == GGML_TYPE_F32) { + val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0); + } else if (type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0)); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +template +static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) { + VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")"); + const size_t x_ne = m * k * batch; + const size_t y_ne = k * n * batch; + const size_t d_ne = m * n * batch; + + vk_pipeline p; + std::string shname; + if (shader_size == 0) { + if (std::is_same() && std::is_same()) { + p 
= ctx->device->pipeline_matmul_f32->a_s; + shname = "F32_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_s; + shname = "F32_F16_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s; + shname = "F16_F32_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_s; + shname = "F16_ALIGNED_S"; + } else { + GGML_ABORT("fatal error"); + } + } else if (shader_size == 1) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_m; + shname = "F32_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_m; + shname = "F32_F16_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m; + shname = "F16_F32_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_m; + shname = "F16_ALIGNED_M"; + } else { + GGML_ABORT("fatal error"); + } + } else if (shader_size == 2) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_l; + shname = "F32_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_l; + shname = "F32_F16_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l; + shname = "F16_F32_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_l; + shname = "F16_ALIGNED_L"; + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ASSERT(0); + } + + const size_t kpad = ggml_vk_align_size(k, p->align); + + if (k != kpad) { + if (shader_size == 0) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->s; + shname = "F32_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->s; + shname = "F32_F16_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->s; + shname = "F16_F32_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->s; + shname = "F16_S"; + } + } else if (shader_size == 1) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->m; + shname = "F32_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->m; + shname = "F32_F16_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->m; + shname = "F16_F32_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->m; + shname = "F16_M"; + } + } else if (shader_size == 2) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->l; + shname = "F32_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->l; + shname = "F32_F16_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->l; + shname = "F16_F32_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->l; + shname = "F16_L"; + } + } + } + + ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + + if (ctx->prealloc_split_k == 
nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); + } + } + + ggml_pipeline_allocate_descriptor_sets(ctx->device); + + vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + + X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne); + Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne); + float* d = (float *) malloc(sizeof(float) * d_ne); + + for (size_t i = 0; i < x_ne; i++) { + if (std::is_same()) { + x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = 1.0f; + // x[i] = i + 1; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + } else if (std::is_same()) { + x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // x[i] = ggml_fp32_to_fp16(1.0f); + // x[i] = ggml_fp32_to_fp16(i + 1); + // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + } else { + GGML_ABORT("fatal error"); + } + } + for (size_t i = 0; i < y_ne; i++) { + if (std::is_same()) { + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 1.0f : 0.0f; + // y[i] = i + 1; + } else if (std::is_same()) { + y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + // y[i] = ggml_fp32_to_fp16(i + 1); + } else { + GGML_ABORT("fatal error"); + } + } + + ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); + ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ggml_vk_ctx_begin(ctx->device, subctx); + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1 + ); + } + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + + auto end = std::chrono::high_resolution_clock::now(); + double time = std::chrono::duration_cast(end-begin).count() / 1000.0; + + // copy dst to host + ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne); + + float * d_chk = (float *) malloc(sizeof(float) * d_ne); + + ggml_init_params iparams = { + /*.mem_size =*/ 1024*1024*1024, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ggml_ctx = ggml_init(iparams); + + ggml_type src0_type; + ggml_type src1_type; + + if (std::is_same()) { + src0_type = GGML_TYPE_F32; + } else if (std::is_same()) { + src0_type = GGML_TYPE_F16; + } else { + GGML_ABORT("fatal error"); + } + if (std::is_same()) { + src1_type = GGML_TYPE_F32; + } else if (std::is_same()) { + src1_type = GGML_TYPE_F16; + } else { + GGML_ABORT("fatal error"); + } + + ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, 
m, batch); + ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch); + ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); + + src0_ggml->data = x; + src1_ggml->data = y; + tensor_ggml->data = d_chk; + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_ggml); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); + + ggml_free(ggml_ctx); + + double avg_err = 0.0; + int first_err_n = -1; + int first_err_m = -1; + int first_err_b = -1; + + for (size_t i = 0; i < m*n*batch; i++) { + double err = std::fabs(d[i] - d_chk[i]); + avg_err += err; + + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { + first_err_b = i / (m * n); + first_err_n = (i % (m * n)) / m; + first_err_m = (i % (m * n)) % m; + } + } + + avg_err /= m * n; + + double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1 || std::isnan(avg_err)) { + std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << "Expected result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + if (split_k > 1) { + float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); + ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); + + std::cerr << "d_buf0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf2: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf3: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + free(split_k_buf); + } + } + + free(d_chk); + + ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); + ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); + + ggml_vk_destroy_buffer(d_X); + ggml_vk_destroy_buffer(d_Y); + ggml_vk_destroy_buffer(d_D); + + ggml_pipeline_cleanup(p); + ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce); + + free(x); + free(y); + free(d); +} + +static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + i3 = std::max(i3, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < 
tensor->ne[3]) { + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) { + ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr); +} + +static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) { + if (quant == GGML_TYPE_F32) { + memcpy(to, from, sizeof(float) * ne); + return; + } + + const auto * tt = ggml_get_type_traits(quant); + + ggml_to_float_t dequant_fn = tt->to_float; + + dequant_fn(from, to, ne); +} + +static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { + VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")"); + const size_t x_sz = sizeof(float) * ne; + const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne; + const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); + float * x = (float *) malloc(x_sz); + void * qx = malloc(qx_sz); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal); + float * x_ref = (float *) malloc(x_sz); + ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16); + + for (size_t i = 0; i < ne; i++) { + x[i] = rand() / (float)RAND_MAX; + } + + vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant); + + ggml_vk_quantize_data(x, qx, ne, quant); + ggml_vk_dequantize_data(qx, x_ref, ne, quant); + + ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); + + ggml_pipeline_allocate_descriptor_sets(ctx->device); + + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ggml_vk_ctx_begin(ctx->device, subctx); + const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; + ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + + auto end = std::chrono::high_resolution_clock::now(); + + double ms_dequant = std::chrono::duration_cast(end-begin).count() / 1000.0; + ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16); + + int first_err = -1; + + double avg_err = 0.0; + for (size_t i = 0; i < ne; i++) { + double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i])); + avg_err += error; + + if (first_err < 0 && error > 0.05) { + first_err = i; + } + } + + avg_err /= ne; + + std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1) { + std::cerr << "first_error = " << first_err << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + for (int i = std::max(0, 
first_err - 5); i < std::min((int)ne, first_err + 5); i++) { + std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", "; + } + std::cerr << std::endl << "Expected result: " << std::endl << std::endl; + for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { + std::cerr << x_ref[i] << ", "; + } + std::cerr << std::endl; + } + + ggml_vk_destroy_buffer(x_buf); + ggml_vk_destroy_buffer(qx_buf); + + free(x); + free(qx); + free(x_ref); + free(x_chk); +} + +static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) { + VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); + const size_t x_ne = m * k * batch; + const size_t y_ne = k * n * batch; + const size_t d_ne = m * n * batch; + + vk_pipeline p; + std::string shname; + if (shader_size == 0) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; + } else if (shader_size == 1) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; + } else if (shader_size == 2) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; + } else { + GGML_ASSERT(0); + } + + const size_t kpad = ggml_vk_align_size(k, p->align); + + if (k != kpad) { + if (shader_size == 0) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s; + shname = std::string(ggml_type_name(quant)) + "_S"; + } else if (shader_size == 1) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m; + shname = std::string(ggml_type_name(quant)) + "_M"; + } else if (shader_size == 2) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l; + shname = std::string(ggml_type_name(quant)) + "_L"; + } else { + GGML_ASSERT(0); + } + } + + const size_t x_sz = sizeof(float) * x_ne; + const size_t y_sz = sizeof(float) * y_ne; + const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); + const size_t d_sz = sizeof(float) * d_ne; + float * x = (float *) malloc(x_sz); + float * y = (float *) malloc(y_sz); + void * qx = malloc(qx_sz); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + float * d = (float *) malloc(d_sz); + float * d_chk = (float *) malloc(d_sz); + + for (size_t i = 0; i < x_ne; i++) { + x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + } + + ggml_vk_quantize_data(x, qx, x_ne, quant); + + for (size_t i = 0; i < y_ne; i++) { + // y[i] = rand() / (float)RAND_MAX; + y[i] = (i % k == i / k) ? 
1.0f : 0.0f; + } + + ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + + if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); + } + } + + ggml_pipeline_allocate_descriptor_sets(ctx->device); + + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); + ggml_vk_buffer_write(y_buf, 0, y, y_sz); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ggml_vk_ctx_begin(ctx->device, subctx); + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1 + ); + } + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + + auto end = std::chrono::high_resolution_clock::now(); + + double time_ms = std::chrono::duration_cast(end-begin).count() / 1000.0; + ggml_vk_buffer_read(d_buf, 0, d, d_sz); + + ggml_init_params iparams = { + /*.mem_size =*/ 1024*1024*1024, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ggml_ctx = ggml_init(iparams); + + ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch); + ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch); + ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); + + src0_ggml->data = qx; + src1_ggml->data = y; + tensor_ggml->data = d_chk; + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_ggml); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); + + ggml_free(ggml_ctx); + + double avg_err = 0.0; + int first_err_n = -1; + int first_err_m = -1; + int first_err_b = -1; + + for (size_t i = 0; i < m*n*batch; i++) { + double err = std::fabs(d[i] - d_chk[i]); + avg_err += err; + + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { + first_err_b = i / (m * n); + first_err_n = (i % (m * n)) / m; + first_err_m = (i % (m * n)) % m; + } + } + + avg_err /= m * n; + + double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.01 || std::isnan(avg_err)) { + std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "Expected result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + if (split_k > 1) { + float * split_k_buf = (float *) 
malloc(sizeof(float) * d_ne * split_k); + ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); + + std::cerr << "d_buf0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf2: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf3: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + free(split_k_buf); + } + } + + ggml_vk_destroy_buffer(qx_buf); + ggml_vk_destroy_buffer(y_buf); + ggml_vk_destroy_buffer(d_buf); + + free(x); + free(qx); + free(y); + free(d); + free(d_chk); +} +#endif + +static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { +#if defined(GGML_VULKAN_RUN_TESTS) + const std::vector vals { + 512, 512, 128, + 128, 512, 512, + 4096, 512, 4096, + 11008, 512, 4096, + 4096, 512, 11008, + 32000, 512, 4096, + 8, 8, 8, + 100, 46, 576, + 623, 111, 128, + 100, 46, 558, + 512, 1, 256, + 128, 110, 622, + 511, 511, 127, + 511, 511, 7, + 511, 511, 17, + 49, 49, 128, + 128, 49, 49, + 4096, 49, 4096, + }; + const size_t num_it = 100; + + for (size_t i = 0; i < vals.size(); i += 3) { + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2); + std::cerr << '\n' << std::endl; + + if (vals[i + 2] % 32 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0); + std::cerr << '\n' << std::endl; + } + + if (vals[i + 2] % 256 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i 
+ 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
+            std::cerr << '\n';
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
+            std::cerr << '\n';
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
+            std::cerr << '\n' << std::endl;
+        }
+    }
+
+    GGML_ABORT("fatal error");
+#endif
+
+    if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")");
+        // Resize buffer
+        if (ctx->prealloc_x != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_x);
+        }
+        ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);
+    }
+    if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")");
+        // Resize buffer
+        if (ctx->prealloc_y != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_y);
+        }
+        ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
+    }
+    if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
+        // Resize buffer
+        if (ctx->prealloc_split_k != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_split_k);
+        }
+        ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
+    }
+}
+
+static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence);
+
+// Returns true if the node has enqueued work into the queue, false otherwise.
+// If submit is true, all operations queued so far are submitted to Vulkan, overlapping command list creation with GPU execution.
+static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){ + if (ggml_is_empty(node) || !node->buffer) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); + ctx->semaphore_idx = 0; + + const ggml_tensor * src0 = node->src[0]; + const ggml_tensor * src1 = node->src[1]; + const ggml_tensor * src2 = node->src[2]; + const ggml_tensor * src3 = node->src[3]; + + switch (node->op) { + // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + return false; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + break; + default: + return false; + } + break; + case GGML_OP_REPEAT: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_ACC: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ARGSORT: + case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: + case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: + break; + default: + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; + GGML_ABORT("fatal error"); + return false; + } + + vk_context compute_ctx; + + if (!dryrun) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + } else { + switch (node->op) { + case GGML_OP_REPEAT: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_UNARY: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_ARGSORT: + case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_LEAKY_RELU: + { + // These operations all go through ggml_vk_op_f32, so short-circuit and + // do the only thing needed for the dryrun. 
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return false; + } + default: + break; + } + } + + switch (node->op) { + case GGML_OP_REPEAT: + ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ACC: + ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_GET_ROWS: + ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ADD: + ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_MUL: + ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_DIV: + ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_CONCAT: + ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_UPSCALE: + ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SCALE: + ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SQR: + ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SIN: + ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COS: + ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CLAMP: + ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_PAD: + ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_NORM: + ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_GROUP_NORM: + ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_RMS_NORM: + ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); + break; + default: + return false; + } + break; + case GGML_OP_DIAG_MASK_INF: + ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SOFT_MAX: + ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ROPE: + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + case GGML_OP_ARGSORT: + ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SUM_ROWS: + ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_IM2COL: + ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_TIMESTEP_EMBEDDING: + ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_POOL_2D: + ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_LEAKY_RELU: + ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_MUL_MAT: + ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_MUL_MAT_ID: + ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + + case GGML_OP_FLASH_ATTN_EXT: + ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); + + break; + + case GGML_OP_RWKV_WKV6: + ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + + break; + default: + return false; + } + + if (dryrun) { + return 
false; + } + + ctx->tensor_ctxs[node_idx] = compute_ctx; + +#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF) + // Force context reset on each node so that each tensor ends up in its own context + // and can be run and compared to its CPU equivalent separately + last_node = true; +#endif + + if (submit || last_node) { + ggml_vk_ctx_end(compute_ctx); + + // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward + if (last_node) { + compute_ctx->exit_tensor_idx = node_idx_begin; + } + else { + compute_ctx->exit_tensor_idx = -1; + } + + ctx->compute_ctx.reset(); + + bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false); + if (!ok) { + if (node->op == GGML_OP_UNARY) { + std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } + else { + std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; + } + } + + } + return true; +} + +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){ + ggml_backend_buffer * buf = nullptr; + + switch (tensor->op) { + case GGML_OP_ADD: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + case GGML_OP_ARGSORT: + case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: + case GGML_OP_LEAKY_RELU: + case GGML_OP_REPEAT: + buf = tensor->buffer; + + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + buf = tensor->buffer; + break; + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_FLASH_ATTN_EXT: + buf = tensor->buffer; + + break; + default: + return false; + } + + if (buf == nullptr) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); + + vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); + + // always wait for the GPU work to be done for the last submit + if (tensor_idx == subctx->exit_tensor_idx) { + use_fence = true; + } + + // Only run if ctx hasn't been submitted yet + if (!subctx->seqs.empty()) { +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_0(tensor); + use_fence = true; +#endif + + // Do staging buffer copies + for (auto& cpy : subctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + 
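// [Editor's illustration, not part of the patch] The in_memcpys flush just above
// and the out_memcpys flush after the fence wait below implement deferred staging
// copies: while commands are recorded, host-side copies are only queued; inputs
// are copied into staging memory immediately before the submit, and outputs are
// copied back out only once the fence has signalled. A minimal standalone model
// of that bookkeeping; PendingCopy, StagedContext and fake_submit_and_wait are
// invented names.
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

struct PendingCopy { void * dst; const void * src; size_t n; };

struct StagedContext {
    std::vector<PendingCopy> in_memcpys;   // host -> staging, flushed before submit
    std::vector<PendingCopy> out_memcpys;  // staging -> host, flushed after the fence
};

// Stand-in for ggml_vk_submit + waitForFences; here everything runs synchronously.
static void fake_submit_and_wait() { std::puts("submitted and waited on fence"); }

int main() {
    StagedContext ctx;
    char staging_in[16] = {}, staging_out[16] = {};
    const char input[16] = "input tensor";
    char output[16] = {};

    // Recording time: remember the copies, do not perform them yet.
    ctx.in_memcpys.push_back({staging_in, input, sizeof(input)});
    ctx.out_memcpys.push_back({output, staging_out, sizeof(staging_out)});

    for (auto & cpy : ctx.in_memcpys) std::memcpy(cpy.dst, cpy.src, cpy.n);
    fake_submit_and_wait();
    std::memcpy(staging_out, staging_in, sizeof(staging_in));   // pretend GPU result
    for (auto & cpy : ctx.out_memcpys) std::memcpy(cpy.dst, cpy.src, cpy.n);

    ctx.in_memcpys.clear();
    ctx.out_memcpys.clear();
    std::printf("read back: %s\n", output);
    return 0;
}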
ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); + + if (use_fence) { + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); + + ctx->device->device.resetFences({ ctx->fence }); + } +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_1(tensor); +#endif + } + + if (tensor_idx == subctx->exit_tensor_idx) { + // Do staging buffer copies + for (auto& cpy : subctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + subctx->in_memcpys.clear(); + subctx->out_memcpys.clear(); + } + + return true; +} + +// Clean up after graph processing is done +static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); + for (auto& buffer : ctx->gc.temp_buffers) { + ggml_vk_pool_free(ctx, buffer); + } + ctx->gc.temp_buffers.clear(); + + for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) { + vk_pipeline_ref plr = ctx->device->pipelines[dsr.first]; + + if (plr.expired()) { + continue; + } + + vk_pipeline pl = plr.lock(); + ggml_pipeline_cleanup(pl); + } + + ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); + ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); + + for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { + ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); + } + ctx->gc.semaphores.clear(); + + for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) { + ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s }); + } + ctx->gc.tl_semaphores.clear(); + ctx->semaphore_idx = 0; + + ctx->event_idx = 0; + + for (auto& event : ctx->gc.events) { + ctx->device->device.resetEvent(event); + } + + ctx->tensor_ctxs.clear(); + ctx->gc.contexts.clear(); + ctx->device->pipeline_descriptor_set_requirements.clear(); +} + +// Clean up on backend free +static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); + ggml_vk_graph_cleanup(ctx); + + ggml_vk_destroy_buffer(ctx->prealloc_x); + ggml_vk_destroy_buffer(ctx->prealloc_y); + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + + for (auto& buffer : ctx->buffer_pool) { + ggml_vk_destroy_buffer(buffer); + } + + ctx->prealloc_size_x = 0; + ctx->prealloc_size_y = 0; + ctx->prealloc_size_split_k = 0; + + for (auto& event : ctx->gc.events) { + ctx->device->device.destroyEvent(event); + } + ctx->gc.events.clear(); + + ctx->device->device.destroyFence(ctx->fence); +} + +static int ggml_vk_get_device_count() { + ggml_vk_instance_init(); + + return vk_instance.device_indices.size(); +} + +static void ggml_vk_get_device_description(int device, char * description, size_t description_size) { + ggml_vk_instance_init(); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + vk::PhysicalDeviceProperties props; + devices[device].getProperties(&props); + + snprintf(description, description_size, "%s", props.deviceName.data()); +} + +// backend interface + +#define UNUSED GGML_UNUSED + +// device backend + +static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name; +} + +static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { + VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()"); + ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + ggml_vk_destroy_buffer(ctx->dev_buffer); + delete ctx; +} + +static void * 
ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { + return vk_ptr_base; + + UNUSED(buffer); +} + +static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + } +} + +static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + if (ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + + return true; + } + return false; + + UNUSED(buffer); +} + +static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size); +} + +static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { + /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, + /* .get_base = */ ggml_backend_vk_buffer_get_base, + /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, + /* .clear = */ ggml_backend_vk_buffer_clear, + /* .reset = */ NULL, +}; + +// vk buffer type +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")"); + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + + vk_buffer dev_buffer = nullptr; + try { + dev_buffer = 
ggml_vk_create_buffer_device(ctx->device, size); + } catch (const vk::SystemError& e) { + return nullptr; + } + + ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name); + + return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size); +} + +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + return ctx->device->properties.limits.minStorageBufferOffsetAlignment; +} + +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + return ctx->device->max_memory_allocation_size; +} + +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_nbytes(tensor); + + UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { + ggml_vk_instance_init(); + + VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")"); + + vk_device dev = ggml_vk_get_device(dev_num); + + return &dev->buffer_type; +} + +// host buffer type + +static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) { + return GGML_VK_NAME "_Host"; + + UNUSED(buft); +} + +static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { + return GGML_VK_NAME "_Host"; + + UNUSED(buffer); +} + +static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); + ggml_vk_host_free(vk_instance.devices[0], buffer->context); +} + +static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")"); + + size += 32; // Behave like the CPU buffer type + void * ptr = nullptr; + try { + ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); + } catch (vk::SystemError& e) { + std::cerr << "ggml_vulkan: Failed to allocate pinned memory." 
<< std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + // fallback to cpu buffer + return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + } + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer; + + return buffer; + + UNUSED(buft); +} + +static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment; + + UNUSED(buft); +} + +// Should be changed to return device-specific host buffer type +// but that probably requires changes in llama.cpp +ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = { + /* .iface = */ { + /* .get_name = */ ggml_backend_vk_host_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0), + /* .context = */ nullptr, + }; + + // Make sure device 0 is initialized + ggml_vk_instance_init(); + ggml_vk_get_device(0); + + return &ggml_backend_vk_buffer_type_host; +} + + +// backend + +static const char * ggml_backend_vk_name(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return ctx->name.c_str(); +} + +static void ggml_backend_vk_free(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")"); + + ggml_vk_cleanup(ctx); + + delete ctx; + delete backend; +} + +static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return &ctx->device->buffer_type; +} + +static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context 
*)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + return true; + } + + return false; +} + +static void ggml_backend_vk_synchronize(ggml_backend_t backend) { + VK_LOG_DEBUG("ggml_backend_vk_synchronize()"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if(ctx->transfer_ctx.expired()) { + return; + } + + vk_context transfer_ctx = ctx->transfer_ctx.lock(); + + ggml_vk_ctx_end(transfer_ctx); + + for (auto& cpy : transfer_ctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ggml_vk_submit(transfer_ctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + + for (auto& cpy : transfer_ctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ctx->transfer_ctx.reset(); +} + +static bool ggml_vk_is_empty(ggml_tensor * node) { + return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; +} + +static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); + } + ggml_vk_preallocate_buffers(ctx); + ggml_pipeline_allocate_descriptor_sets(ctx->device); + + 
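// [Editor's illustration, not part of the patch] The dryrun loop above lets every
// node register its needs before any work is recorded: each op requests descriptor
// sets for its pipeline via ggml_pipeline_request_descriptor_sets(), and
// ggml_pipeline_allocate_descriptor_sets() then allocates everything in one go.
// A standalone model of that request/allocate split; requirements,
// request_descriptor_sets and allocate_descriptor_sets are invented names.
#include <cstdio>
#include <map>
#include <string>

static std::map<std::string, int> requirements;   // pipeline name -> sets needed

static void request_descriptor_sets(const std::string & pipeline, int n) {
    requirements[pipeline] += n;                   // dryrun: only count
}

static void allocate_descriptor_sets() {           // single up-front allocation
    for (const auto & kv : requirements) {
        std::printf("allocating %d descriptor set(s) for %s\n", kv.second, kv.first.c_str());
    }
}

int main() {
    // Pass 1: every node registers what it will need.
    request_descriptor_sets("matmul_f16", 1);
    request_descriptor_sets("soft_max_f32", 1);
    request_descriptor_sets("matmul_f16", 1);
    // Pass 2: allocate once, then record and submit the real work.
    allocate_descriptor_sets();
    return 0;
}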
int last_node = cgraph->n_nodes - 1; + + // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly + while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { + last_node -= 1; + } + + // Reserve tensor context space for all nodes + ctx->tensor_ctxs.resize(cgraph->n_nodes); + + bool first_node_in_batch = true; // true if next node will be first node in a batch + int submit_node_idx = 0; // index to first node in a batch + + // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution. + // Start with a smaller count to get work submitted right away, and increase it after each submit. + int nodes_per_submit = 20; + int submitted_nodes = 0; + int submit_count = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + if (first_node_in_batch) { + submit_node_idx = i; + } + + bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node); + + bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); + + if (enqueued) { + ++submitted_nodes; + +#ifndef GGML_VULKAN_CHECK_RESULTS + if (first_node_in_batch) { + first_node_in_batch = false; + } +#endif + } + + if (submit) { + first_node_in_batch = true; + submitted_nodes = 0; + switch (submit_count) { + case 0: + nodes_per_submit = 50; + break; + default: + nodes_per_submit = 100; + break; + } + submit_count++; + } + } + +#ifdef GGML_VULKAN_PERF + ctx->device->perf_logger->print_timings(); +#endif + + ggml_vk_graph_cleanup(ctx); + + return GGML_STATUS_SUCCESS; + + UNUSED(backend); +} + +// TODO: enable async and synchronize +static ggml_backend_i ggml_backend_vk_interface = { + /* .get_name = */ ggml_backend_vk_name, + /* .free = */ ggml_backend_vk_free, + /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, + /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async, + /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, + /* .synchronize = */ NULL, // ggml_backend_vk_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_vk_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +static ggml_guid_t ggml_backend_vk_guid() { + static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; + return &guid; +} + +ggml_backend_t ggml_backend_vk_init(size_t dev_num) { + VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")"); + + ggml_backend_vk_context * ctx = new ggml_backend_vk_context; + ggml_vk_init(ctx, dev_num); + + ggml_backend_t vk_backend = new ggml_backend { + /* .guid = */ ggml_backend_vk_guid(), + /* .interface = */ ggml_backend_vk_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), + /* .context = */ ctx, + }; + + return vk_backend; +} + +bool ggml_backend_is_vk(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid()); +} + +int ggml_backend_vk_get_device_count() { + return ggml_vk_get_device_count(); +} + +void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + int dev_idx = vk_instance.device_indices[device]; + ggml_vk_get_device_description(dev_idx, description, description_size); +} + +void 
ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; + + vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + + for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { + if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total = heap.size; + *free = heap.size; + break; + } + } +} + +////////////////////////// + +struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; + ggml_backend_vk_get_device_memory(ctx->device, free, total); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) { + UNUSED(dev); + return ggml_backend_vk_host_buffer_type(); +} + +static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { + UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_vk_device_get_name(dev); + props->description = ggml_backend_vk_device_get_description(dev); + props->type = ggml_backend_vk_device_get_type(dev); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ true, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { + UNUSED(params); + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_init(ctx->device); +} + +static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + return ggml_is_contiguous(op->src[0]); + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) { + // If there's not enough shared memory for row_ids and the result tile, fallback to CPU + return false; + } + switch (op->src[0]->type) { + case 
GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return false; + } + struct ggml_tensor * a; + struct ggml_tensor * b; + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != b->ne[3]) { + return false; + } + if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) || + !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { + return false; + } + + return true; + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + if (!ggml_vk_get_device(ctx->device)->coopmat2) { + return false; + } + switch (op->src[0]->ne[0]) { + case 64: + case 80: + case 96: + case 112: + case 128: + case 256: + break; + default: + return false; + } + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + if (op->type != GGML_TYPE_F32) { + return false; + } + if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { + return false; + } + // It's straightforward to support different K/V dequant, but would + // significantly increase the number of pipelines + if (op->src[1]->type != op->src[2]->type) { + return false; + } + switch (op->src[1]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently + //case GGML_TYPE_Q2_K: + //case GGML_TYPE_Q3_K: + //case GGML_TYPE_Q4_K: + //case GGML_TYPE_Q5_K: + //case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return false; + } + return true; + } + case GGML_OP_GET_ROWS: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + } + } break; + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_DUP: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1] != nullptr ? 
op->src[1]->type : src0_type; + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + return true; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return true; + } + return false; + } break; + case GGML_OP_REPEAT: + return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + case GGML_OP_ROPE: + { + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & GGML_ROPE_TYPE_VISION) { + return false; + } + return ggml_is_contiguous(op->src[0]); + } + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_ADD: + case GGML_OP_ACC: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ARGSORT: + case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: + case GGML_OP_LEAKY_RELU: + return true; + default: + return false; + } + + UNUSED(dev); +} + +static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) { + return false; + } + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return buft_ctx->device->idx == ctx->device; +} + +static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + const int min_batch_size = 32; + + return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + + UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_vk_device_i = { + /* .get_name = */ ggml_backend_vk_device_get_name, + /* .get_description = */ ggml_backend_vk_device_get_description, + /* .get_memory = */ ggml_backend_vk_device_get_memory, + /* .get_type = */ ggml_backend_vk_device_get_type, + /* .get_props = */ ggml_backend_vk_device_get_props, + /* .init_backend = */ ggml_backend_vk_device_init, + /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_vk_device_supports_op, + /* .supports_buft = */ ggml_backend_vk_device_supports_buft, + /* .offload_op = */ ggml_backend_vk_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) { + UNUSED(reg); + return GGML_VK_NAME; +} + +static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) { + UNUSED(reg); + return ggml_backend_vk_get_device_count(); +} + +static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) { + static std::vector devices; + + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + for 
(int i = 0; i < ggml_backend_vk_get_device_count(); i++) { + ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; + char desc[256]; + ggml_backend_vk_get_device_description(i, desc, sizeof(desc)); + ctx->device = i; + ctx->name = GGML_VK_NAME + std::to_string(i); + ctx->description = desc; + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, + /* .context = */ ctx, + }); + } + initialized = true; + } + } + + GGML_ASSERT(device < devices.size()); + return devices[device]; +} + +static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = { + /* .get_name = */ ggml_backend_vk_reg_get_name, + /* .get_device_count = */ ggml_backend_vk_reg_get_device_count, + /* .get_device = */ ggml_backend_vk_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_vk_reg() { + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_vk_reg_i, + /* .context = */ nullptr, + }; + + return ® +} + +// Extension availability +static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions) { +#ifdef GGML_VULKAN_VALIDATE + bool portability_enumeration_ext = false; + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + return true; + } + } + if (!portability_enumeration_ext) { + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; + } +#endif + return false; + + UNUSED(instance_extensions); +} +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions) { +#ifdef __APPLE__ + bool portability_enumeration_ext = false; + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + return true; + } + } + if (!portability_enumeration_ext) { + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." 
<< std::endl; + } +#endif + return false; + + UNUSED(instance_extensions); +} + +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) { + switch (props.vendorID) { + case VK_VENDOR_ID_INTEL: + // Intel drivers don't support coopmat properly yet + return false; + case VK_VENDOR_ID_AMD: + if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { + // Workaround for AMD proprietary driver reporting support on all GPUs + const std::string name = props.deviceName; + return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs + name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs + name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs + } + return true; + default: + return true; + } +} + +// checks + +#ifdef GGML_VULKAN_CHECK_RESULTS +static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector& done, int level = 0) { + if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) { + return; + } + for (int j = 0; j < level; j++) { + std::cerr << " "; + } + std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl; + + done.push_back(tensor); + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] != nullptr) { + ggml_vk_print_graph_origin(tensor->src[i], done, level + 1); + } + } +} + +static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) { + if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + i3 = std::max(i3, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_I32) { + val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) { + void * tensor_data = tensor->data; + + const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer); + + if (is_gpu) { + const size_t tensor_size = ggml_nbytes(tensor); + tensor_data = malloc(tensor_size); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_buffer buffer_gpu = buf_ctx->dev_buffer; + 
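// [Editor's illustration, not part of the patch] The result-checking helpers in
// this block address tensor elements through byte strides, as in
// ggml_vk_print_tensor_area above:
//     *(const float *) ((const char *) data + i3*nb[3] + i2*nb[2] + i1*nb[1] + i0*nb[0])
// A standalone sketch of that addressing for a small contiguous float tensor;
// ne, nb and element_at are local names used only for this example.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // A 4-D tensor of shape ne = {4, 3, 2, 1} stored contiguously as float.
    const std::size_t ne[4] = {4, 3, 2, 1};
    std::size_t nb[4];
    nb[0] = sizeof(float);                          // byte stride between ne[0] neighbours
    for (int i = 1; i < 4; ++i) nb[i] = nb[i - 1] * ne[i - 1];

    std::vector<float> storage(ne[0] * ne[1] * ne[2] * ne[3]);
    for (std::size_t i = 0; i < storage.size(); ++i) storage[i] = (float) i;

    const char * data = (const char *) storage.data();
    auto element_at = [&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
        return *(const float *) (data + i3*nb[3] + i2*nb[2] + i1*nb[1] + i0*nb[0]);
    };

    // Element (i0=2, i1=1, i2=1, i3=0) sits at flat index 2 + 1*4 + 1*12 = 18.
    std::printf("element = %.1f\n", element_at(2, 1, 1, 0));   // prints 18.0
    return 0;
}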
ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size); + } + + std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; + std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl; + if (tensor->src[0] != nullptr) { + std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl; + } + if (tensor->src[1] != nullptr) { + std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl; + } + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + + if (is_gpu) { + free(tensor_data); + } +} + +void * comp_result; +size_t comp_size; +size_t comp_nb[GGML_MAX_DIMS]; +size_t check_counter = 0; +static void ggml_vk_check_results_0(ggml_tensor * tensor) { + if (tensor->op == GGML_OP_TRANSPOSE) { + return; + } + + check_counter++; + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { + return; + } + + VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")"); + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + ggml_tensor * src2 = tensor->src[2]; + ggml_tensor * src3 = tensor->src[3]; + + struct ggml_init_params iparams = { + /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ggml_ctx = ggml_init(iparams); + + struct ggml_tensor * src0_clone = nullptr; + struct ggml_tensor * src1_clone = nullptr; + struct ggml_tensor * src2_clone = nullptr; + struct ggml_tensor * src3_clone = nullptr; + struct ggml_tensor * tensor_clone = nullptr; + + size_t src0_size; + size_t src1_size; + size_t src2_size; + size_t src3_size; + + void * src0_buffer = nullptr; + void * src1_buffer = nullptr; + void * src2_buffer = nullptr; + void * src3_buffer = nullptr; + + if (src0 != nullptr) { + src0_clone = ggml_dup_tensor(ggml_ctx, src0); + + src0_size = ggml_nbytes(src0); + + src0_buffer = malloc(src0_size); + src0_clone->data = src0_buffer; + if (ggml_backend_buffer_is_host(src0->buffer)) { + memcpy(src0_clone->data, src0->data, src0_size); + memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(src0->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = 
(ggml_backend_vk_buffer_context *)src0->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src0) + src0->view_offs; + if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { + for (int i3 = 0; i3 < src0->ne[3]; i3++) { + for (int i2 = 0; i2 < src0->ne[2]; i2++) { + const int idx = i3*src0->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]); + } + } + + src0_clone->nb[0] = src0->nb[0]; + src0_clone->nb[1] = src0->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1]; + } + } else { + if (offset + src0_size >= buffer_gpu->size) { + src0_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size); + memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(src0, "src0"); + } + } + if (src1 != nullptr) { + src1_clone = ggml_dup_tensor(ggml_ctx, src1); + + src1_size = ggml_nbytes(src1); + + src1_buffer = malloc(src1_size); + src1_clone->data = src1_buffer; + if (ggml_backend_buffer_is_host(src1->buffer)) { + memcpy(src1_clone->data, src1->data, src1_size); + memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(src1->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src1) + src1->view_offs; + if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { + for (int i3 = 0; i3 < src1->ne[3]; i3++) { + for (int i2 = 0; i2 < src1->ne[2]; i2++) { + const int idx = i3*src1->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]); + } + } + + src1_clone->nb[0] = src1->nb[0]; + src1_clone->nb[1] = src1->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1]; + } + } else { + if (offset + src1_size >= buffer_gpu->size) { + src1_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size); + memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(src1, "src1"); + } + } + if (src2 != nullptr) { + src2_clone = ggml_dup_tensor(ggml_ctx, src2); + + src2_size = ggml_nbytes(src2); + + src2_buffer = malloc(src2_size); + src2_clone->data = src2_buffer; + if (ggml_backend_buffer_is_host(src2->buffer)) { + memcpy(src2_clone->data, src2->data, src2_size); + memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(src2->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src2) + src2->view_offs; + if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) { + for (int i3 = 0; i3 < src2->ne[3]; i3++) { + for (int i2 = 0; i2 < src2->ne[2]; i2++) { + const int idx = i3*src2->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char 
*)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]); + } + } + + src2_clone->nb[0] = src2->nb[0]; + src2_clone->nb[1] = src2->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1]; + } + } else { + if (offset + src2_size >= buffer_gpu->size) { + src2_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size); + memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(src2, "src2"); + } + } + if (src3 != nullptr) { + src3_clone = ggml_dup_tensor(ggml_ctx, src3); + + src3_size = ggml_nbytes(src3); + + src3_buffer = malloc(src3_size); + src3_clone->data = src3_buffer; + if (ggml_backend_buffer_is_host(src3->buffer)) { + memcpy(src3_clone->data, src3->data, src3_size); + memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(src3->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src3) + src3->view_offs; + if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) { + for (int i3 = 0; i3 < src3->ne[3]; i3++) { + for (int i2 = 0; i2 < src3->ne[2]; i2++) { + const int idx = i3*src3->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]); + } + } + + src3_clone->nb[0] = src3->nb[0]; + src3_clone->nb[1] = src3->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1]; + } + } else { + if (offset + src3_size >= buffer_gpu->size) { + src3_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size); + memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(src3, "src3"); + } + } + + if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { + const float *params = (const float *)tensor->op_params; + tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]); + } else if (tensor->op == GGML_OP_MUL_MAT) { + tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_MUL_MAT_ID) { + tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone); + } else if (tensor->op == GGML_OP_MUL) { + tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_DIV) { + tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_CONCAT) { + tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_UPSCALE) { + tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_SCALE) { + tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]); + } else if (tensor->op == GGML_OP_SQR) { + tensor_clone = ggml_sqr(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_SIN) { + tensor_clone = ggml_sin(ggml_ctx, src0_clone); + } else if (tensor->op == 
GGML_OP_COS) { + tensor_clone = ggml_cos(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_CLAMP) { + tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + } else if (tensor->op == GGML_OP_PAD) { + tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]); + } else if (tensor->op == GGML_OP_REPEAT) { + tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor); + } else if (tensor->op == GGML_OP_ADD) { + tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_ACC) { + tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + } else if (tensor->op == GGML_OP_NORM) { + tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_GROUP_NORM) { + tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); + } else if (tensor->op == GGML_OP_RMS_NORM) { + tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_SOFT_MAX) { + if (src1 != nullptr) { + tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + } else { + tensor_clone = ggml_soft_max(ggml_ctx, src0_clone); + } + } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_ROPE) { + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; + const float freq_base = ((float *) tensor->op_params)[5]; + const float freq_scale = ((float *) tensor->op_params)[6]; + const float ext_factor = ((float *) tensor->op_params)[7]; + const float attn_factor = ((float *) tensor->op_params)[8]; + const float beta_fast = ((float *) tensor->op_params)[9]; + const float beta_slow = ((float *) tensor->op_params)[10]; + tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else if (tensor->op == GGML_OP_UNARY) { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_SILU: + tensor_clone = ggml_silu(ggml_ctx, src0_clone); + break; + case GGML_UNARY_OP_GELU: + tensor_clone = ggml_gelu(ggml_ctx, src0_clone); + break; + case GGML_UNARY_OP_GELU_QUICK: + tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone); + break; + case GGML_UNARY_OP_RELU: + tensor_clone = ggml_relu(ggml_ctx, src0_clone); + break; + case GGML_UNARY_OP_TANH: + tensor_clone = ggml_tanh(ggml_ctx, src0_clone); + break; + default: + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { + if (src1 == nullptr) { + tensor_clone = ggml_dup(ggml_ctx, src0_clone); + tensor_clone->type = tensor->type; + } else { + tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone); + } + } else if (tensor->op == GGML_OP_CONT) { + tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], 
tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_RESHAPE) { + tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_VIEW) { + tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); + } else if (tensor->op == GGML_OP_PERMUTE) { + int32_t * params = (int32_t *)tensor->op_params; + tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]); + } else if (tensor->op == GGML_OP_TRANSPOSE) { + tensor_clone = ggml_transpose(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_GET_ROWS) { + tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_ARGSORT) { + tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_SUM_ROWS) { + tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_IM2COL) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + + const bool is_2D = tensor->op_params[6] == 1; + tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { + const int32_t dim = tensor->op_params[0]; + const int32_t max_period = tensor->op_params[1]; + tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); + } else if (tensor->op == GGML_OP_POOL_2D) { + enum ggml_op_pool op = static_cast<enum ggml_op_pool>(tensor->op_params[0]); + const int32_t k0 = tensor->op_params[1]; + const int32_t k1 = tensor->op_params[2]; + const int32_t s0 = tensor->op_params[3]; + const int32_t s1 = tensor->op_params[4]; + const int32_t p0 = tensor->op_params[5]; + const int32_t p1 = tensor->op_params[6]; + + tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1); + } else if (tensor->op == GGML_OP_LEAKY_RELU) { + const float * op_params = (const float *)tensor->op_params; + tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); + } else if (tensor->op == GGML_OP_RWKV_WKV6) { + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], + tensor->src[4], tensor->src[5]); + } + else { + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_clone); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8); + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(tensor_clone, "tensor_clone"); + } + + comp_size = ggml_nbytes(tensor_clone); + + comp_result = malloc(comp_size); + memcpy(comp_result, tensor_clone->data, comp_size); + memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); + + if (src0 != nullptr) { + free(src0_buffer); + } + if (src1 != nullptr) { + free(src1_buffer); + } + + ggml_free(ggml_ctx); + + VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); +} + +static void ggml_vk_check_results_1(ggml_tensor * tensor) { + if (tensor->op == GGML_OP_TRANSPOSE) { +
return; + } + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { + return; + } + + VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")"); + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + ggml_tensor * src2 = tensor->src[2]; + + void * tensor_data = tensor->data; + + if (ggml_backend_buffer_is_vk(tensor->buffer)) { + size_t tensor_size = ggml_nbytes(tensor); + tensor_data = malloc(tensor_size); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs; + if (offset + tensor_size >= buffer_gpu->size) { + tensor_size = buffer_gpu->size - offset; + } + + ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size); + } + + float first_error_result = -1.0f; + float first_error_correct = -1.0f; + std::array<int, 4> first_error = { -1, -1, -1, -1 }; + double avg_err = 0.0; + size_t counter = 0; + + for (int i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size; + float correct = 0.0f; + float result = 0.0f; + + if (buffer_size_fit) { + if (tensor->type == GGML_TYPE_F32) { + correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); + result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_I32) { + correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else { + std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; + } + } else { + std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl; + GGML_ABORT("fatal error"); + } + + if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) { + std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" <<
src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3); + std::cerr << std::endl; + std::vector<const ggml_tensor *> done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } + if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) { + first_error[0] = i0; + first_error[1] = i1; + first_error[2] = i2; + first_error[3] = i3; + first_error_result = result; + first_error_correct = correct; + } + + // Special case, value is infinite, avoid NaN result in avg_err + // NaN also appears in results, if both are nan error is 0 + if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) { + avg_err += std::fabs(correct - result); + } + counter++; + } + } + } + } + + avg_err /= counter; + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" <<
src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0); + std::cerr << std::endl; + std::vector<const ggml_tensor *> done; + ggml_vk_print_graph_origin(tensor, done); + } + + if (avg_err > 0.05 || std::isnan(avg_err)) { + std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]); + std::cerr << std::endl; + std::vector<const ggml_tensor *> done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } else { + std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl; + }
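Aside: the acceptance rule applied above reduces to "sum the absolute error over elements whose reference and GPU values are both finite, divide by the total element count, and fail the op if that mean exceeds 0.05 or is NaN". A minimal standalone sketch of that criterion, with a hypothetical helper name and no dependence on the patch's own functions, could look like this:

#include <cmath>
#include <cstddef>

// Illustrative sketch only (assumed helper, not from the patch itself): mirrors
// the checker's avg_err rule -- accumulate |correct - result| for finite pairs,
// divide by the total element count, accept only if the mean stays <= 0.05.
static bool vk_results_within_tolerance(const float * correct, const float * result,
                                        size_t n, double tol = 0.05) {
    double err = 0.0;
    for (size_t k = 0; k < n; ++k) {
        if (std::isfinite(correct[k]) && std::isfinite(result[k])) {
            err += std::fabs((double) correct[k] - (double) result[k]);
        }
    }
    return n > 0 && (err / n) <= tol;
}

Dividing by the full element count rather than only the finite pairs matches the loop above, where counter is incremented unconditionally while non-finite pairs contribute zero error.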
+ + free(comp_result); + comp_result = nullptr; + comp_size = 0; + + if (ggml_backend_buffer_is_vk(tensor->buffer)) { + free(tensor_data); + } + + VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")"); +} +#endif + +GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt new file mode 100644 index 000000000..bd0c74cb1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -0,0 +1,9 @@ +find_package (Threads REQUIRED) +find_package(Vulkan COMPONENTS glslc REQUIRED) + +set(TARGET vulkan-shaders-gen) +add_executable(${TARGET} vulkan-shaders-gen.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_compile_features(${TARGET} PRIVATE cxx_std_17) +target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) +target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp new file mode 100644 index 000000000..d896f1ef0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + const uint offset = p.param3; + const uint src1_i = idx - offset; + const uint oz = src1_i / p.nb02; + const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; + const uint ox = src1_i % p.nb01; + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + } else { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); + } +} + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp new file mode 100644 index 000000000..2b4085c4f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -0,0 +1,29 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp new file mode 100644 index 000000000..d4fa45b1e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -0,0 +1,69 @@ +#version 450 + +#include 
"types.comp" + +#define BLOCK_SIZE 1024 +#define ASC 0 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint ncols_pad; + uint order; +} p; + +shared int dst_row[BLOCK_SIZE]; + +void swap(uint idx0, uint idx1) { + int tmp = dst_row[idx0]; + dst_row[idx0] = dst_row[idx1]; + dst_row[idx1] = tmp; +} + +void main() { + // bitonic sort + const int col = int(gl_LocalInvocationID.x); + const uint row = gl_WorkGroupID.y; + + const uint row_offset = row * p.ncols; + + // initialize indices + if (col < p.ncols_pad) { + dst_row[col] = col; + } + barrier(); + + for (uint k = 2; k <= p.ncols_pad; k *= 2) { + for (uint j = k / 2; j > 0; j /= 2) { + const uint ixj = col ^ j; + if (col < p.ncols_pad && ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= p.ncols || + (dst_row[ixj] < p.ncols && (p.order == ASC ? + data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] : + data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]])) + ) { + swap(col, ixj); + } + } else { + if (dst_row[ixj] >= p.ncols || + (dst_row[col] < p.ncols && (p.order == ASC ? + data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] : + data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]])) + ) { + swap(col, ixj); + } + } + } + barrier(); + } + } + + if (col < p.ncols) { + data_d[row_offset + col] = dst_row[col]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp new file mode 100644 index 000000000..1e5cb8dae --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp new file mode 100644 index 000000000..9ee2f1fae --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + const int dim = p.param3; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne22*p.ne21*p.ne20); + const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20; + const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20); + const uint i2_offset = i2*p.ne21*p.ne20; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne20; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20; + + uint o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? 
p.ne02 : p.ne03)); + + const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10; + const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20; + + const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]); +#else + if (is_src0) { + data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx]; + } else { + data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx]; + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp new file mode 100644 index 000000000..dd828c232 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +#extension GL_EXT_control_flow_attributes : require + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + // fast path for when all four iterations are in-bounds + if (idx + (num_iter-1)*num_threads < p.ne) { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } else { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp new file mode 100644 index 000000000..29c906494 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); +#else + data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp new file mode 100644 index 000000000..0b8d02f58 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); +} diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp new file mode 100644 index 000000000..a4d3fca55 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x * 16; + + if (i >= p.nel) { + return; + } + + [[unroll]] for (uint l = 0; l < 16; l++) { + data_b[i + l] = D_TYPE(data_a[i + l]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp new file mode 100644 index 000000000..91bb8f8db --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -0,0 +1,118 @@ +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#endif + +#include "types.comp" + +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + +#if defined(DATA_A_F32) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_F16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_Q4_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2(vui & 0xF, vui >> 4) - 8.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); +} +#endif + +#if defined(DATA_A_Q4_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(vui & 0xF, vui >> 4); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); +} +#endif + +#if defined(DATA_A_Q5_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0]; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f); +} +#endif + +#if defined(DATA_A_Q5_1) +vec2 dequantize(uint ib, uint iqs, uint 
a_offset) { + const uint uint_qh = data_a[a_offset + ib].qh; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = data_a_packed16[a_offset + ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y); +} +#endif + +#if defined(DATA_A_Q8_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2]; + uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1]; + return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8)); +} +#endif + +#if defined(DATA_A_IQ4_NL) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]); +} +#endif + +#if defined(DATA_A_F32) || defined(DATA_A_F16) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(0, 0); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), 0); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp new file mode 100644 index 000000000..94b78598e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -0,0 +1,325 @@ + +#include "types.comp" + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { + block_q4_0_packed16 block; +}; + +float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]); + qs >>= shift; + qs &= 0x0F0F; + qs = unpack8(qs)[idx & 1]; + float16_t ret = (float16_t(qs) - float16_t(8)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { + block_q4_1 block; +}; + +float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = 
float16_t(qs) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { + block_q5_0 block; +}; + +float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0]; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { + block_q5_1 block; +}; + +float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = bl.block.qh; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = float16_t(qs | qh) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { + block_q8_0_packed16 block; +}; + +float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + // Load 16b and select the byte for this element + int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1]; + float16_t ret = float16_t(qs) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { + block_q2_K block; +}; + +float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const f16vec2 d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31 + const uint scalesi = iqs / 16; // 0..15 + const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6 + + uint32_t qs = bl.block.qs[qsi]; + const uint scales = bl.block.scales[scalesi]; + float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4); + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { + block_q3_K block; +}; + +float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + const uint n = iqs / 128; // 0,1 + const uint qsi = n * 32 + (iqs % 32); // 0..63 + const uint hmi = (iqs % 32); // 0..31 + const uint j = (iqs % 128) / 8; // 0..15 + const uint is = iqs / 16; // 0..15 + const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + uint32_t scaleidx0 = (is < 8) ? is : (is-8); + uint32_t scaleidx0shift = (is < 8) ? 
0 : 4; + uint32_t scaleidx1 = is + 8 - (is/4)*4; + uint32_t scaleidx1shift = (is/4)*2; + + const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); + + const float16_t dl = bl.block.d * float16_t(us - 32); + + float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4)); + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { + block_q4_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 { + block_q4_K_packed16 block; +}; + +float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + + const f16vec2 loadd = bl.block.d; + + uint32_t sc; + uint32_t mbyte; + + uint32_t scidx0 = (is < 4) ? is : (is + 4); + uint32_t scidx1 = (is < 4) ? is : (is - 4); + uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint32_t scidxshift1 = (is < 4) ? 0 : 2; + uint32_t mbidx0 = is + 4; + uint32_t mbidx1 = (is < 4) ? is + 4 : is; + uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint32_t mbidxshift0 = (is < 4) ? 0 : 4; + uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint32_t mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4)) & 0x0F0F; + qs = unpack8(qs)[idx & 1]; + + float16_t ret = d * float16_t(qs) - m; + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { + block_q5_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 { + block_q5_K_packed16 block; +}; + +float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + + const uint32_t hm = 0x0101 << is; + + const f16vec2 loadd = bl.block.d; + + uint32_t sc; + uint32_t mbyte; + + uint32_t scidx0 = (is < 4) ? is : (is + 4); + uint32_t scidx1 = (is < 4) ? is : (is - 4); + uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint32_t scidxshift1 = (is < 4) ? 0 : 2; + uint32_t mbidx0 = is + 4; + uint32_t mbidx1 = (is < 4) ? is + 4 : is; + uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint32_t mbidxshift0 = (is < 4) ? 0 : 4; + uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint32_t mbidxshift1 = (is < 4) ? 
0 : 2; + + sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); + + uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); + qh = qh & hm; + qh = unpack8(qh)[idx & 1]; + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4)) & 0x0F0F; + qs = unpack8(qs)[idx & 1]; + + float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m; + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { + block_q6_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 { + block_q6_K_packed16 block; +}; + +float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x40) >> 6; // 0,1 + const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 + const uint is = (idx & 0xF0) >> 4; // 0..15 + + const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); + + uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]); + ql = (ql >> (b * 4)) & 0x0F0F; + + uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qh = ((qh >> qhshift) & 0x0303) << 4; + + int q = unpack8(ql | qh)[idx & 1]; + + float16_t ret = dscale * float16_t(q - 32); + + return ret; +} + +#if defined(DATA_A_IQ4_NL) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { + block_iq4_nl block; +}; + +float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; + return ret; +} +#endif + +#if defined(DATA_A_Q4_0) +#define dequantFuncA dequantFuncQ4_0 +#elif defined(DATA_A_Q4_1) +#define dequantFuncA dequantFuncQ4_1 +#elif defined(DATA_A_Q5_0) +#define dequantFuncA dequantFuncQ5_0 +#elif defined(DATA_A_Q5_1) +#define dequantFuncA dequantFuncQ5_1 +#elif defined(DATA_A_Q8_0) +#define dequantFuncA dequantFuncQ8_0 +#elif defined(DATA_A_Q2_K) +#define dequantFuncA dequantFuncQ2_K +#elif defined(DATA_A_Q3_K) +#define dequantFuncA dequantFuncQ3_K +#elif defined(DATA_A_Q4_K) +#define dequantFuncA dequantFuncQ4_K +#elif defined(DATA_A_Q5_K) +#define dequantFuncA dequantFuncQ5_K +#elif defined(DATA_A_Q6_K) +#define dequantFuncA dequantFuncQ6_K +#elif defined(DATA_A_IQ4_NL) +#define dequantFuncA dequantFuncIQ4_NL +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp new file mode 100644 index 000000000..8d806435b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp @@ -0,0 +1,13 @@ +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint M; + uint K; + uint stride_a; + uint stride_b; + uint nel; +} p; + +#include 
"types.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp new file mode 100644 index 000000000..8de14fc03 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq4nl_shmem(); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = float(data_a[ib].d); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp new file mode 100644 index 000000000..157154af3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; + if (i >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; + + const uint y_idx = i * QUANT_K + 128 * ip + il; + + const uint ql_idx = 32 * ip + il; + const uint8_t qs = data_a[i].qs[32 * ip + il]; + + FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); + FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); + data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); + data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); + data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); + data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp new file mode 100644 index 000000000..c17dd0d99 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = uint(gl_WorkGroupID.x * 256 + wgy); 
+ if (i >= p.M * p.K / QUANT_K) { + return; + } + + const uint r = gl_LocalInvocationID.x / 4; + const uint tid = r / 2; + const uint is0 = r % 2; + const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4); + const uint n = tid / 4; + const uint j = tid - 4*n; + + const uint8_t m = uint8_t(1 << (4*n + j)); + const uint is = 8*n + 2*j + is0; + const uint shift = 2*j; + + const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) : + (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4)); + const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); + const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); + + const uint y_idx = i * QUANT_K + 128 * n + 32 * j; + const uint qs_idx = 32*n; + + for (uint l = l0; l < l0 + 4; ++l) { + data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4))); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp new file mode 100644 index 000000000..408185327 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp @@ -0,0 +1,30 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = float(data_a[ib].d); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp new file mode 100644 index 000000000..2f27eee68 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp new file mode 100644 index 000000000..987f113a3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -0,0 +1,68 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 8; + const uint ir = tid % 8; + const uint is = 2 * il; + const uint n = 4; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + n * ir; + const uint qs_idx = 32*il + n * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 
0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + [[unroll]] for (uint l = 0; l < n; ++l) { + data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1); + data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp new file mode 100644 index 000000000..b20b80529 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp new file mode 100644 index 000000000..dc59fe3b7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint qh = data_a[ib].qh; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp new file mode 100644 index 000000000..6db5403b6 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -0,0 +1,70 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 
64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 16; + const uint ir = tid % 16; + const uint is = 2 * il; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; + const uint qs_idx = 32*il + 2 * ir; + const uint qh_idx = 2 * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + const uint8_t hm1 = uint8_t(1 << (2 * il )); + const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); + data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); + data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 
16 : 0)) - m2); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp new file mode 100644 index 000000000..0b9131755 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -0,0 +1,33 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; + if (i >= p.M * p.K / QUANT_K) { + return; + } + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; + + const uint y_idx = i * QUANT_K + 128 * ip + il; + + const uint ql_idx = 64 * ip + il; + const uint8_t qh = data_a[i].qh[32 * ip + il]; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); + + data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); + data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); + data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); + data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp new file mode 100644 index 000000000..bd1344a88 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp @@ -0,0 +1,31 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q8_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 16*il; + + const float d = float(data_a[ib].d); + + const uint q_idx = 16*il; + + [[unroll]] for (uint l = 0; l < 16; l += 2) { + data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]); + data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp new file mode 100644 index 000000000..4e68742b5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -0,0 +1,34 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint ncols; + uint rows_per_channel; + uint n_past; +} p; + +#include "types.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + 
const uint col = gl_GlobalInvocationID.y; + const uint row = gl_GlobalInvocationID.x; + + if (col >= p.ncols) { + return; + } + + const uint i = row*p.ncols + col; + if (col > p.n_past + row % p.rows_per_channel) { + data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000)); + } else { + data_d[i] = D_TYPE(data_a[i]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp new file mode 100644 index 000000000..9fb69c6c1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp new file mode 100644 index 000000000..c5be8131b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -0,0 +1,289 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable +#extension GL_EXT_null_initializer : enable + +#include "types.comp" +#include "dequant_funcs_cm2.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint32_t Br = 32; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; +layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb02; + uint32_t nb03; + uint32_t nb12; + uint32_t nb13; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; +} p; + +layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; +layout (binding = 1) readonly buffer K {uint8_t data_k[];}; +layout (binding = 2) readonly buffer V {uint8_t data_v[];}; +layout (binding = 3) readonly buffer M {uint8_t data_m[];}; +layout (binding = 4) writeonly buffer O {D_TYPE 
data_o[];}; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return max(x, y); +} + +ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return x; +} + +// Replace matrix elements >= numRows or numCols with 'replace' +ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) { + if (row >= numRows || col >= numCols) { + return replace; + } + return elem; +} + +ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) +{ + return exp(elem); +} + +ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) +{ + return max(elem0, elem1); +} + +#if defined(BLOCK_SIZE) +#define DECODEFUNC , DEQUANTFUNC +#else +#define DECODEFUNC +#endif + +void main() { +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); +#endif + + const uint32_t N = p.N; + const uint32_t KV = p.KV; + + const uint32_t Tr = CEIL_DIV(N, Br); + const uint32_t Tc = CEIL_DIV(KV, Bc); + + const uint32_t i = gl_WorkGroupID.x; + + const uint32_t iq2 = gl_WorkGroupID.y; + const uint32_t iq3 = gl_WorkGroupID.z; + + // broadcast factors + const uint32_t rk2 = p.neq2/p.nek2; + const uint32_t rk3 = p.neq3/p.nek3; + + const uint32_t rv2 = p.neq2/p.nev2; + const uint32_t rv3 = p.neq3/p.nev3; + + // k indices + const uint32_t ik3 = iq3 / rk3; + const uint32_t ik2 = iq2 / rk2; + + // v indices + const uint32_t iv3 = iq3 / rv3; + const uint32_t iv2 = iq2 / rv2; + + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); + tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if defined(BLOCK_SIZE) + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); +#endif + + tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); + tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); + tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); + + coopmat Q; + coopmat Qf16; + + uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); + + Qf16 = coopmat(Q); + Qf16 *= float16_t(p.scale); + + coopmat O = coopmat(0); + + coopmat L, M; + + L = coopmat(0); + M = coopmat(-1.0/0.0); + + ACC_TYPE slope = ACC_TYPE(1.0); + + // ALiBi + if (p.max_bias > 0.0f) { + const uint32_t h = iq2; + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); + + slope = pow(base, ACC_TYPE(exph)); + } + + [[dont_unroll]] + for (uint32_t j = 0; j < Tc; ++j) { + + coopmat S = coopmat(0); + + coopmat K_T; + + uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); + S = coopMatMulAdd(Qf16, K_T, S); + + if (p.logit_softcap != 0.0f) { + [[unroll]] + for (int k = 0; k < S.length(); ++k) { + S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); + } + } + + if (p.mask != 0) { + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + + coopmat mv; + + coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slope*coopmat(mv); + } + + // Clear padding elements to -inf, so they don't contribute to rowmax + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); + } + + coopmat rowmax, P, rowsum, eM; + + coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); + + coopmat Mold = M; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + coopMatPerElementNV(M, rowmax, Max, Mold); + coopMatPerElementNV(P, S - M, Exp); + coopMatPerElementNV(eM, Mold - M, Exp); + + // Clear padding elements to 0, so they don't contribute to rowsum + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); + } + + coopmat P_A = coopmat(P); + + // compute rowsum by multiplying by matrix of all ones. 
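+        // rowsum(P) = P * 1; together with eM below this performs the streaming softmax
+        // update L = e^(Mold - M) * L + rowsum(P) without materializing the full row.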
+ coopmat One = coopmat(1.0); + + rowsum = coopmat(0.0); + rowsum = coopMatMulAdd(P_A, One, rowsum); + + coopmat V; + uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); + + L = eM*L + rowsum; + + // This is the "diagonal" matrix in the paper, but since we do componentwise + // multiply rather than matrix multiply it has the diagonal element smeared + // across the row + coopmat eMdiag; + + // resize eM by using smear/reduce + coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); + + O = eMdiag * O; + + O = coopMatMulAdd(P_A, V, O); + } + + coopmat Ldiag; + + // resize L by using smear/reduce + coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); + + [[unroll]] + for (int k = 0; k < Ldiag.length(); ++k) { + Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; + } + + O = Ldiag*O; + + tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + + // permute dimensions + tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); + uint32_t o_offset = iq3*p.ne2*p.ne1; + + coopmat O_D = coopmat(O); + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp new file mode 100644 index 000000000..4cc7a68ca --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp @@ -0,0 +1,25 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float xi = float(data_a[i]); + const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); + data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp new file mode 100644 index 000000000..e6e6fcfd2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_QUICK_COEF = -1.702f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x)))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp new file mode 100644 index 000000000..062e2a4cd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp @@ -0,0 +1,64 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; + uint misalign_offsets; + float param1; float param2; int param3; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +// true if src0/src1 are the same shape and the indices can be reused without additional modulus +layout(constant_id = 0) const bool norepeat = false; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; } +uint get_doffset() { return p.misalign_offsets & 0xFF; } + +// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1 +uint fastmod(uint a, uint b) { + if ((b & (b-1)) == 0) { + return a & (b-1); + } + return a % b; +} + +uint fastdiv(uint a, uint b) { + return (a < b) ? 0 : (a / b); +} + +void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) { + i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00)); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00)); + const uint i02_offset = i02*p.ne01*p.ne00; + i01 = (idx - i03_offset - i02_offset) / p.ne00; + i00 = idx - i03_offset - i02_offset - i01*p.ne00; +} + +uint src0_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint src1_idx(uint i00, uint i01, uint i02, uint i03) { + if (norepeat) { + return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10; + } else { + return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10; + } +} + +uint dst_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp new file mode 100644 index 000000000..66e46ae67 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp @@ -0,0 +1,9 @@ +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + float param1; + float param2; +} p; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp new file mode 100644 index 000000000..68d1bc9f1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp @@ -0,0 +1,56 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint 
ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + float param1; float param2; + + uint ne0_012mp; uint ne0_012L; + uint ne0_01mp; uint ne0_01L; + uint ne0_0mp; uint ne0_0L; + uint ne1_012mp; uint ne1_012L; + uint ne1_01mp; uint ne1_01L; + uint ne1_0mp; uint ne1_0L; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +uint src0_idx(uint idx) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint dst_idx(uint idx) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp new file mode 100644 index 000000000..e877ed779 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -0,0 +1,28 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = gl_GlobalInvocationID.x; + const uint i10 = gl_GlobalInvocationID.y; + const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; + const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; + + if (i00 >= p.ne00) { + return; + } + + const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]); +#else + data_d[d_offset + i00] = data_a[a_offset + i00]; +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp new file mode 100644 index 000000000..1426fde65 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" +#include "dequant_funcs.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = (gl_GlobalInvocationID.x)*2; + const 
uint i10 = gl_GlobalInvocationID.y; + const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; + const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; + +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); +#endif + + if (i00 >= p.ne00) { + return; + } + + const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + + const uint ib = a_offset + i00/QUANT_K; // block index + const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index + const uint iybs = i00 - i00%QUANT_K; // dst block start index + const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + + vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + v = v * dm.x + dm.y; + + data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); + data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp new file mode 100644 index 000000000..b6a0d5645 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp @@ -0,0 +1,66 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared float tmp[BLOCK_SIZE]; + +void main() { + const uint group_size = p.KX; + const float eps = p.param1; + + const uint tid = gl_LocalInvocationID.x; + const uint start = gl_WorkGroupID.x * group_size + tid; + const uint end = (gl_WorkGroupID.x + 1) * group_size; + + tmp[tid] = 0.0f; + + // Calculate mean + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + tmp[tid] += float(data_a[col]); + } + + // tmp up partial tmps and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float mean = tmp[0] / group_size; + barrier(); + tmp[tid] = 0.0f; + + // Calculate variance + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + const float xi = float(data_a[col]) - mean; + data_d[col] = D_TYPE(xi); + tmp[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float variance = tmp[0] / group_size; + const float scale = inversesqrt(variance + eps); + + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + data_d[col] *= D_TYPE(scale); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp new file mode 100644 index 000000000..122b1e93f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -0,0 +1,87 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_spirv_intrinsics: enable +#extension GL_EXT_control_flow_attributes : require + +#if RTE16 +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif + +layout (push_constant) uniform parameter +{ + uint batch_offset; uint offset_delta; + uint IC; + uint IW; uint IH; + uint OW; uint OH; + uint KW; uint KH; + uint pelements; + uint CHW; + int 
s0; int s1; + int p0; int p1; + int d0; int d1; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +const uint NUM_ITER = 512 / BLOCK_SIZE; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint gidx = gl_GlobalInvocationID.x; + + const uint oh = gl_GlobalInvocationID.y; + const uint batch = gl_GlobalInvocationID.z / p.IC; + const uint ic = gl_GlobalInvocationID.z % p.IC; + + A_TYPE values[NUM_ITER]; + uint offset_dst[NUM_ITER]; + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + values[idx] = A_TYPE(0); + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint i = gidx * NUM_ITER + idx; + + const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); + const uint kx = i / ksize; + const uint kd = kx * ksize; + const uint ky = (i - kd) / p.OW; + const uint ix = i % p.OW; + + const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; + const uint iih = oh * p.s1 + ky * p.d1 - p.p1; + + offset_dst[idx] = + ((batch * p.OH + oh) * p.OW + ix) * p.CHW + + (ic * (p.KW * p.KH) + ky * p.KW + kx); + + if (i >= p.pelements) { + continue; + } + + if (iih < p.IH && iiw < p.IW) { + const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; + values[idx] = data_a[offset_src + iih * p.IW + iiw]; + } + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint i = gidx * NUM_ITER + idx; + + if (i >= p.pelements) { + continue; + } + + data_d[offset_dst[idx]] = D_TYPE(values[idx]); + } + +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp new file mode 100644 index 000000000..d90a99aea --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float val = float(data_a[i]); + data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp new file mode 100644 index 000000000..43de19df8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp new file mode 100644 index 000000000..4c64fd47a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp @@ -0,0 +1,48 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 0) readonly buffer A4 {vec4 data_a4[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; +layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];}; + +layout (push_constant) uniform parameter { + uint ne; + uint k_num; +} p; + +void main() { + // Each invocation handles four consecutive components + const uint idx = gl_GlobalInvocationID.x * 4; + + if (idx >= p.ne) { + return; + } + + // Check if all four components are in bounds and aligned, + // then use vector loads + if (idx + 3 < p.ne && (p.ne % 4) == 0) { + vec4 result = vec4(0.0f); + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a4[(i * p.ne + idx) / 4]; + } + + data_d4[idx / 4] = result; + } else { + [[unroll]] for (uint j = 0; j < 4; ++j) { + if (idx + j < p.ne) { + float result = 0.0f; + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a[i * p.ne + idx + j]; + } + + data_d[idx + j] = result; + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp new file mode 100644 index 000000000..24875cdcf --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -0,0 +1,152 @@ +#version 450 + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#define K_PER_ITER 8 +#else +#define K_PER_ITER 2 +#endif + + +uint a_offset, b_offset, d_offset, y_offset; + +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = (col%QUANT_K)/QUANT_R; // quant index + const uint iybs = col - col%QUANT_K; // y block start index + +#if K_PER_ITER == 8 +#if QUANT_R == 2 + const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; + const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]; + const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); + const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); +#else + const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); +#endif +#else + // Check if the second of the pair of elements is OOB, and don't fetch B or + // accumulate it. We still fetch a pair of elements for A, which is fine for + // quantized formats since they'll be within the same block. We should + // probably skip fetching the second element for F16/F32, but as of now we + // still do. 
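+    // OOB is set on the final iteration when the second element of the pair would read past ncols.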
+ const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); + + FLOAT_TYPE b0 = 0, b1 = 0; + b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); + if (!OOB) { + b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); + } +#endif + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + +#if K_PER_ITER == 8 + vec4 v = dequantize4(ib, iqs, a_offset); + vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); + + const vec2 dm = get_dm(ib, a_offset); + if (dm.y != 0) { // quant has min component + v = v * dm.x + dm.y; + v2 = v2 * dm.x + dm.y; + } + + // matrix multiplication + FLOAT_TYPE rowtmp = dot(bv0, v); + rowtmp += dot(bv1, v2); + + if (dm.y == 0) + rowtmp *= dm.x; + + temp[j][n] += rowtmp; +#else + const vec2 v = dequantize(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); + if (!OOB) { + temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + } +#endif + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; + + get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; + + y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); + if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + num_iters++; + } + int unroll_count = 4; + uint unrolled_iters = num_iters & ~(unroll_count - 1); + + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + unroll_count = 2; + unrolled_iters = num_iters & ~(unroll_count - 1); + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); + i++; + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); +#endif + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp new file mode 100644 index 000000000..903753c7e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -0,0 +1,118 @@ +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_8bit_storage : require + +#ifdef MUL_MAT_ID +#define EXPERT_COUNT 8 +#endif + +#include "types.comp" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 
data_b_v2[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +#include "dequant_funcs.comp" + +layout (push_constant) uniform parameter +{ + uint ncols; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint ne11; +#else + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.y; +#else + const uint batch_idx = gl_GlobalInvocationID.y; +#endif + +#ifndef MUL_MAT_ID + uint batch_idx_a = 0; + if (batch_idx != 0) { + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + batch_idx_a = i03 * p.ne02 + i02; + } +#else + const uint expert_id = data_ids[expert_idx]; +#endif + + a_offset = +#ifdef MUL_MAT_ID + expert_id * p.batch_stride_a; +#else + batch_idx_a * p.batch_stride_a; +#endif + b_offset = +#ifdef MUL_MAT_ID + (expert_idx % p.ne11) * p.stride_b; +#else + batch_idx * p.batch_stride_b; +#endif + d_offset = +#ifdef MUL_MAT_ID + expert_idx * p.stride_d; +#else + batch_idx * p.batch_stride_d; +#endif +} + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (constant_id = 1) const uint NUM_ROWS = 1; +layout (constant_id = 2) const uint NUM_COLS = 1; + +shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; + +void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + // sum up partial sums and write back result + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] = temp[j][n]; + } + } + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] += tmpsh[j][n][tid + s]; + } + } + } + barrier(); + } + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp new file mode 100644 index 000000000..1cc4996d3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -0,0 +1,71 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#define BLOCK_SIZE 32 +#define FLOAT_TYPE float + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; + +layout (push_constant) uniform parameter +{ + uint ncols_x; + uint nrows_x; + uint row_stride_x; + uint channel_stride_x; + uint channel_x_divisor; + uint b_offset; + uint d_offset; +} p; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint tid = 
gl_LocalInvocationID.x; + const uint row_x = gl_GlobalInvocationID.y; + const uint channel = gl_GlobalInvocationID.z; + const uint channel_x = channel / p.channel_x_divisor; + + const uint nrows_y = p.ncols_x; + const uint nrows_dst = p.nrows_x; + const uint row_dst = row_x; + + const uint idst = channel*nrows_dst + row_dst; + + tmp[tid] = 0.0f; + + for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel*nrows_y + row_y; + + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + if (tid == 0) { + dst[idst] = tmp[0]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp new file mode 100644 index 000000000..9b443807d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp @@ -0,0 +1,73 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#define BLOCK_SIZE 32 +#define FLOAT_TYPE float + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; + +layout (push_constant) uniform parameter +{ + uint ncols_x; + uint nrows_x; + uint nchannels_x; + uint nchannels_y; + uint b_offset; + uint d_offset; +} p; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint row_x = gl_GlobalInvocationID.y; + const uint channel = gl_GlobalInvocationID.z; + const uint channel_x = channel / (p.nchannels_y / p.nchannels_x); + + const uint nrows_y = p.ncols_x; + const uint nrows_dst = p.nrows_x; + const uint row_dst = row_x; + + tmp[tid] = FLOAT_TYPE(0.0f); + + for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + const uint row_y = col_x; + + // y is not transposed but permuted + const uint iy = channel*nrows_y + row_y; + + tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); + } + + // dst is not transposed and not permuted + const uint idst = channel*nrows_dst + row_dst; + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + if (tid == 0) { + dst[idst] = tmp[0]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp new file mode 100644 index 000000000..934213446 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -0,0 +1,115 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" 
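+// Matrix-vector multiply for Q2_K-quantized A: each superblock is processed by 16 cooperating
+// threads, accumulating into temp[NUM_COLS][NUM_ROWS] and combining via reduce_result().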
+ +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; + + const uint step = 8; + + const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - step*v_im; // 0...15 or 0...7 + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint s_offset = 8*v_im; + const uint y_offset = 128*v_im + l0; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + f16vec2 d = data_a[ib0 + i].d; + const FLOAT_TYPE dall = d.x; + const FLOAT_TYPE dmin = d.y; + + uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0]; + uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; + + uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; + uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F; + uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F; + uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F; + + uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32)); + uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32)); + uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); + uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); + + uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0]; + uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; + uvec2 qs0 = uvec2(unpack8(qs0_u16)); + uvec2 qs16 = uvec2(unpack8(qs16_u16)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; + B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; + B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; + B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; + B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; + B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; + B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; + B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), + fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), + fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), + fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), + fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), + fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), + fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3), + fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); + 
sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), + fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), + fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), + fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), + fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), + fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), + fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), + fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp new file mode 100644 index 000000000..86b0159d9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -0,0 +1,103 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; + + const uint step = 8; + + const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
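+    // v_im selects the lower (0..127) or upper (128..255) half of the superblock; v_in indexes within that half.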
+ const uint v_in = itid - step*v_im; // 0...15 or 0...7 + + const uint8_t m = uint8_t(1 << (4 * v_im)); + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint s_shift = 4 * v_im; + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0]; + uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1]; + uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2]; + uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3]; + uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4]; + uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5]; + u8vec2 s0 = unpack8(s0_16); + u8vec2 s2 = unpack8(s2_16); + u8vec2 s4 = unpack8(s4_16); + u8vec2 s6 = unpack8(s6_16); + u8vec2 s8 = unpack8(s8_16); + u8vec2 s10 = unpack8(s10_16); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; + B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; + B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; + B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; + B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; + B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; + B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; + B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 
0 : 4)), + fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp new file mode 100644 index 000000000..cd1dd8e89 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -0,0 +1,133 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; + + const uint step = 4; + + const uint il = itid/step; // 0...3 + const uint ir = itid - step*il; // 0...7 or 0...3 + const uint n = 4; + + const uint v_im = il / 2; // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = n * (2 * ir + v_in); // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + f16vec2 d = data_a[ib0 + i].d; + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + uvec4 scale0 = uvec4(unpack8(scale0_u32)); + uvec4 scale4 = uvec4(unpack8(scale4_u32)); + uvec4 scale8 = uvec4(unpack8(scale8_u32)); + + const uint32_t sc0 = ( scale0.x & 0x3f); + const uint32_t sc1 = ( scale0.y & 0x3f); + const uint32_t sc2 = ( scale4.x & 0x3f); + const uint32_t sc3 = ( scale4.y & 0x3f); + const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); + const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); + const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); + const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); + + uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; + + uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4)); + uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4)); + uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4)); + uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4)); + + const uint32_t q4_0 = qs0_lo4.x; + const uint32_t q4_1 = qs0_lo4.y; + const uint32_t q4_2 = qs0_lo4.z; + const uint32_t q4_3 = qs0_lo4.w; + const uint32_t q4_4 = qs0_hi4.x; + const uint32_t q4_5 = qs0_hi4.y; + const uint32_t q4_6 = qs0_hi4.z; + const uint32_t q4_7 = qs0_hi4.w; + const uint32_t q4_8 = qs64_lo4.x; + const uint32_t q4_9 = qs64_lo4.y; + const uint32_t q4_10 = qs64_lo4.z; + const uint32_t q4_11 = qs64_lo4.w; + const uint32_t q4_12 = qs64_hi4.x; + const uint32_t q4_13 = qs64_hi4.y; + const uint32_t q4_14 = qs64_hi4.z; + const uint32_t q4_15 = qs64_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4]; + B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]; + B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4]; + B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]; + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = 
fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp new file mode 100644 index 000000000..0a68891c3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -0,0 +1,162 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; + + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...7 or 0...3 + + const uint v_im = il / 2; // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = 4*ir + 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + f16vec2 d = data_a[ib0 + i].d; + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + uvec4 scale0 = uvec4(unpack8(scale0_u32)); + uvec4 scale4 = uvec4(unpack8(scale4_u32)); + uvec4 scale8 = uvec4(unpack8(scale8_u32)); + + const uint32_t sc0 = ( scale0.x & 0x3f); + const uint32_t sc1 = ( scale0.y & 0x3f); + const uint32_t sc2 = ( scale4.x & 0x3f); + const uint32_t sc3 = ( scale4.y & 0x3f); + const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); + const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); + const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); + const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); + + uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + + uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; + uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; + uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + + uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0; + uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + + uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); + uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); + uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); + uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4)); + + const uint32_t q4_0 = qs0_16_lo4.x; + const uint32_t q4_1 = qs0_16_lo4.y; + const uint32_t q4_2 = qs0_16_lo4.z; + const uint32_t q4_3 = qs0_16_lo4.w; + const uint32_t q4_4 = qs0_16_hi4.x; + const uint32_t q4_5 = qs0_16_hi4.y; + const uint32_t q4_6 = qs0_16_hi4.z; + const uint32_t q4_7 = qs0_16_hi4.w; + const uint32_t q4_8 = qs64_80_lo4.x; + const uint32_t q4_9 = qs64_80_lo4.y; + const uint32_t q4_10 = qs64_80_lo4.z; + const uint32_t q4_11 = qs64_80_lo4.w; + const uint32_t q4_12 = qs64_80_hi4.x; + const uint32_t q4_13 = qs64_80_hi4.y; + const uint32_t q4_14 = qs64_80_hi4.z; + const uint32_t q4_15 = qs64_80_hi4.w; + + [[unroll]] for 
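// Q5_K adds a fifth bit per weight on top of the Q4_K layout: each bit of the packed qh
// words selects a +16 offset, and the masked shifts above fold that offset into the
// corresponding 4-bit nibbles before the dot products run. The 6-bit sub-block scales and
// mins are unpacked exactly as in the Q4_K shader.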
(uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2]; + B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]; + B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]; + B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]; + B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2]; + B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]; + B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]; + B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]; + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp new file mode 100644 index 000000000..70e13a56b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -0,0 +1,112 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; + + const uint step = 8; + + const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
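// Q6_K recap: a 256-weight superblock holds 16 sub-blocks of 16 weights, one signed 8-bit
// scale per sub-block and a single fp16 super-scale d. Each weight is rebuilt from a 4-bit
// low nibble in ql plus two high bits from qh and recentred by -32, i.e.
// w = d * scale * (((ql & 0xF) | ((qh & 3) << 4)) - 32), which is what the fma loop below
// evaluates four lanes at a time.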
+ const uint v_in = itid - step*v_im; // 0...15 or 0...7 + + const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 + const uint is = v_in / 4; + + const uint ql_offset = 64*v_im + l0; + const uint qh_offset = 32*v_im + l0; + const uint s_offset = 8*v_im + is; + const uint y_offset = 128*v_im + l0; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + FLOAT_TYPE scales[4]; + scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); + scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); + scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); + scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); + + uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; + uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + uvec4 q0 = uvec4(unpack8(q0_u32)); + uvec4 q1 = uvec4(unpack8(q1_u32)); + uvec4 q2 = uvec4(unpack8(q2_u32)); + uvec4 q3 = uvec4(unpack8(q3_u32)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4]; + B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]; + B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]; + B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]; + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 4; ++l) { + sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), + fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), + fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), + fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); + } + temp[j][n] += sum * d; + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp 
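// mul_mm.comp, added below, is the block-tiled matrix multiply: each workgroup dequantizes
// a BM x BK tile of A and loads a BN x BK tile of B into shared memory, then accumulates
// either TM x TN register tiles per thread or, when COOPMAT is defined, KHR cooperative
// matrix products. MUL_MAT_ID builds a per-expert row list for mixture-of-experts matmuls,
// and k_split lets several workgroups share the K dimension of one output tile.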
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp new file mode 100644 index 000000000..48122cbef --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -0,0 +1,631 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.comp" + +#ifndef LOAD_VEC_A +#define LOAD_VEC_A 1 +#endif +#ifndef LOAD_VEC_B +#define LOAD_VEC_B 1 +#endif + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK + 8) +#else +#define SHMEM_STRIDE (BK + 1) +#endif + +shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[3072]; +#endif // MUL_MAT_ID + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +void main() { +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = 
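// Work decomposition: the workgroup is split into BLOCK_SIZE / WARP warps, each warp owns a
// WM x WN sub-tile of the BM x BN block, iterated as WMITER x WNITER passes over
// WSUBM x WSUBN pieces, and every thread accumulates a TM x TN register tile. Shared rows
// are padded to SHMEM_STRIDE (BK + 1, or BK + 8 for the coopmat path), which helps avoid
// shared-memory bank conflicts.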
gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; + +#ifdef MUL_MAT_ID + uint _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + row_ids[_ne1] = u16vec2(ii0, ii1); + _ne1++; + } + } + } + + barrier(); + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; +#ifdef MUL_MAT_ID + uint pos_b = 0; +#else + uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; +#endif + +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + FLOAT_TYPE cache_a[WMITER * TM]; + FLOAT_TYPE cache_b[WNITER * TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { + +#if defined(DATA_A_F32) || defined(DATA_A_F16) +#if LOAD_VEC_A == 8 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); + buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); + buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); + buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); + buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); + buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); + buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); + buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); +#elif LOAD_VEC_A == 4 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); + buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); + buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); + buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); +#else + if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); + } else { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); + } +#endif +#elif defined(DATA_A_Q4_0) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + 
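// Each DATA_A_* branch below decodes two weights per thread straight into buf_a, so the
// inner product only ever sees dequantized FLOAT_TYPE values. The simple block formats use
// the usual rules: Q4_0 gives w = d * (nibble - 8), Q4_1 gives w = d * nibble + m, and
// Q8_0 gives w = d * int8.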
const float d = float(data_a[ib].d); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q4_1) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q5_0) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q5_1) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint uint_qh = data_a[ib].qh; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q8_0) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 16; + const uint iqs = (idx & 0xF) * 2; + + const float d = float(data_a[ib].d); + const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q2_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const uint scales = data_a[ib].scales[scalesi]; + const vec2 d = vec2(data_a[ib].d); + + const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q3_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 
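// K-quant branches: for Q2_K each scale byte packs a 4-bit scale in the low nibble and a
// 4-bit min in the high nibble, so a weight is d.x * scale * q2 - d.y * min with q2 the
// 2-bit quant. For Q3_K the two low bits come from qs, the high bit from hmask (its absence
// subtracts 4), and the 6-bit sub-block scales are reassembled from the 12 scale bytes and
// recentred by -32 before multiplying by d.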
0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) : + (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4)); + const float dl = float(data_a[ib].d) * float(us - 32); + + buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); + buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); +#elif defined(DATA_A_Q4_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); + buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +#elif defined(DATA_A_Q5_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 
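// The scidx/mbidx selectors are a branch-free form of the same 12-byte 6-bit scale/min
// packing used by the mat-vec shaders: sub-blocks 0..3 take the low 6 bits of scales[is],
// sub-blocks 4..7 glue a nibble from scales[is + 4] onto the top two bits of scales[is - 4],
// and the mins follow the analogous pattern. Q5_K additionally folds the per-weight high
// bit from qh in as a +16 offset, exactly as in mul_mat_vec_q5_k.comp.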
0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); + buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); +#elif defined(DATA_A_Q6_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); + + buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); + buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ4_NL) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); +#endif + } + [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { +#if LOAD_VEC_B == 8 +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +#else + const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; +#endif + const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; + buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); + buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); + buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); + buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); + buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); + buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); + buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); + buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); +#elif LOAD_VEC_B == 4 +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +#else + const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; +#endif + const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; + buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); + buf_b[buf_idx + 1] = 
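// Once both tiles are staged (and after the barrier), each thread copies its TM rows of
// buf_a and TN columns of buf_b into registers for every one of the BK k-slices and
// accumulates the TM x TN outer product with fma, repeating WMITER x WNITER times to cover
// the warp tile; the COOPMAT variant instead issues coopMatLoad / coopMatMulAdd over
// TK-wide slices of the same shared buffers.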
FLOAT_TYPE(data_b[idx].y); + buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); + buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); +#elif !MUL_MAT_ID + if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); + } else { + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); + } +#else + const uint row_i = ic * BN + loadc_b + l; + if (row_i < _ne1) { + const u16vec2 row_idx = row_ids[row_i]; + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); + } else { + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); + } +#endif + } + + barrier(); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + +#ifdef COOPMAT + [[unroll]] for (uint i = 0; i < BK; i += TK) { + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + // Load from shared into cache + coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } +#else + [[unroll]] for (uint i = 0; i < BK; i++) { + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint j = 0; j < TM; j++) { + cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; + } + } + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint j = 0; j < TN; j++) { + cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; + } + } + + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]); + } + } + } + } + } +#endif + + barrier(); + } + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < BN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; + + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool 
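// Write-back: the cooperative-matrix path picks one of three stores depending on bounds and
// alignment: a direct coopMatStore when the whole tile fits and stride_d is aligned, a trip
// through the coopmat_stage shared buffer when it fits but is misaligned, and per-element
// bounds-checked stores otherwise. The scalar path always bounds-checks, and MUL_MAT_ID
// remaps each output row through row_ids before storing.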
is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp new file mode 100644 index 000000000..cbfa5dce1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -0,0 +1,328 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id 
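// mul_mm_cm2.comp is the GL_NV_cooperative_matrix2 variant of the tiled matmul: rather than
// staging tiles in shared memory it describes A, B and D as 2-D tensor layouts and lets
// coopMatLoadTensorNV / coopMatStoreTensorNV perform the (optionally clamped) tiled
// accesses, with dequantFuncA from dequant_funcs_cm2.comp decoding quantized A blocks on
// the fly and decodeFuncB remapping expert rows in the MUL_MAT_ID case.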
= 3) const uint BK = 16; // Assumed to be 32 if working with a quant + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#if QUANT_K > 1 +#define DECODEFUNCA , dequantFuncA +#define MAT_A_TYPE float16_t + +#include "dequant_funcs_cm2.comp" + +#else +#define DECODEFUNCA +#define MAT_A_TYPE A_TYPE +#endif + +#define MAT_B_TYPE B_TYPE + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; + +shared u16vec4 row_ids[3072]; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { + B_TYPE b[]; +}; + +uint _ne1; +shared uint _ne1_sh; + +B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint row_i = blockCoords[0]; + + if (row_i >= _ne1) { + return B_TYPE(0.0); + } + + const u16vec4 row_idx = row_ids[row_i]; + B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; + + return ret; +} + +D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) +{ + uint dr = ir * BM + r; + uint dc = ic * BN + c; + + if (dr < p.M && dc < _ne1) { + uint row_i = dc; + const u16vec4 row_idx = row_ids[row_i]; + data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; + } + return elem; +} + +#endif + +void main() { +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + +#ifdef MUL_MAT_ID + // Spread the search across all elements in the first subgroup + if (gl_SubgroupID == 0) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + + for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { + bool in_range = i < num_elements; + uint ii0 = i % p.nei0; + uint ii1 = i / p.nei0; + uint id = in_range ? 
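// Expert-row gathering: the first subgroup scans every (i0, i1) pair, and a subgroup ballot
// plus an exclusive bit count hands each matching row a compact slot in row_ids without
// atomics; _ne1 ends up holding the number of rows routed to this expert before it is
// broadcast to the rest of the workgroup through _ne1_sh.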
data_ids[ii1*p.nbi1 + ii0] : 0; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + uint idx = subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx) { + row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); + } + _ne1 += subgroupBallotBitCount(ballot); + } + _ne1_sh = _ne1; + } + + barrier(); + + _ne1 = _ne1_sh; + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + uint start_k = 0; + const uint end_k = p.K; +#else + uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + coopmat sum; + sum = coopmat(0.0); + +#ifdef MUL_MAT_ID + uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; + uint pos_b = 0; +#else + uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; + uint pos_b = batch_idx * p.batch_stride_b; +#endif + + uint stride_a = p.stride_a / QUANT_K; + uint stride_b = p.stride_b; + + // Hint to the compiler that values are aligned (want 16B alignment). + // Quants are always block-aligned, no alignment needed. +#if ALIGNED +#if QUANT_K == 1 + stride_a &= ~7; +#endif + stride_b &= ~7; +#endif + + // Create layouts for both clamped and unclamped accesses + tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + +#if QUANT_K > 1 + tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); + tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); +#endif + + // Use end_k rather than p.K as the dimension because that's what + // we need to bound check against when using split_k + tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); + tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); + tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); + tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if !defined(MUL_MAT_ID) + // Detect a fast path where all loads are entirely in bounds and no clamping is required + if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 && +#if QUANT_K == 1 + (stride_a % 8) == 0 && +#endif + (stride_b % 8) == 0 && (start_k % 8) == 0) { + // Hint to the compiler that values are aligned (want 16B alignment) + start_k &= ~7; + stride_b &= ~7; +#if QUANT_K == 1 + stride_a &= ~7; +#endif + + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + uint k_iters = (end_k - start_k + BK - 1) / BK; + + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopmat mat_a_ft = coopmat(mat_a); + + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * 
BN, BN, block_k, BK), tensorViewTranspose); + coopmat mat_b_ft = coopmat(mat_b); + + sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + } + } else +#endif // !defined(MUL_MAT_ID) + { + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + + tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1); + + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); + + [[dont_unroll]] + for (uint block_k = start_k; block_k < end_k; block_k += BK) { + + coopmat mat_a; + coopmat mat_b; + coopmat mat_a_ft; + coopmat mat_b_ft; + + // Clamping is expensive, so detect different code paths for each combination + // of A and B needing clamping. + bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0; +#ifdef MUL_MAT_ID + bool unclampedB = true; +#else + bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0; +#endif + if (unclampedA && unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); +#endif + mat_a_ft = coopmat(mat_a); + mat_b_ft = coopmat(mat_b); + sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + } else if (unclampedA && !unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); + + mat_a_ft = coopmat(mat_a); + mat_b_ft = coopmat(mat_b); + sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + } else if (!unclampedA && unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); +#endif + mat_a_ft = coopmat(mat_a); + mat_b_ft = coopmat(mat_b); + sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + } else if (!unclampedA && !unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); + + mat_a_ft = coopmat(mat_a); + mat_b_ft = coopmat(mat_b); + sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + } + } + } + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + +#ifdef MUL_MAT_ID + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); +#else + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); + + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); +#endif +} diff --git 
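// norm.comp, next, is the plain normalization kernel: a shared-memory tree reduction
// produces sum(x) and sum(x^2) per row, from which it applies
// y = (x - mean) * inversesqrt(variance + eps). rms_norm.comp further below drops the mean
// subtraction and uses y = x * inversesqrt(mean(x^2) + eps) instead.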
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp new file mode 100644 index 000000000..6627a50bd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared vec2 sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = vec2(0.0f, 0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const float xi = float(data_a[row*p.KX + col]); + sum[tid].x += xi; + sum[tid].y += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const float mean = sum[0].x / p.KX; + const float var = sum[0].y / p.KX - mean * mean; + const float inv_std = inversesqrt(var + p.param1); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp new file mode 100644 index 000000000..450b67fc5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -0,0 +1,28 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne12*p.ne11*p.ne10); + const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; + const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10); + const uint i2_offset = i2*p.ne11*p.ne10; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; + + const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; + + const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? 
data_a[get_aoffset() + src0_idx] : 0.0f); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp new file mode 100644 index 000000000..b6124411a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp @@ -0,0 +1,74 @@ +#version 450 + +#include "types.comp" + +#extension GL_EXT_shader_16bit_storage : require + +layout(push_constant) uniform parameter { + uint IW; uint IH; + uint OW; uint OH; + uint OC; + uint pelements; + uint op; + int k0; int k1; + int s0; int s1; + int p0; int p1; +} p; + +#define BLOCK_SIZE 512 +#define FLT_MAX 3.402823466e+38F +#define OP_POOL_MAX 0u +#define OP_POOL_AVG 1u + +layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout(binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.pelements) { + return; + } + + const uint O_HW = p.OW * p.OH; + + const uint nc = idx / O_HW; + const uint cur_oh = (idx % O_HW) / p.OW; + const uint cur_ow = (idx % O_HW) % p.OW; + + const int start_h = int(cur_oh) * p.s0 - p.p0; + const uint bh = max(start_h, 0); + const uint eh = min(start_h + p.k0, p.IH); + + const int start_w = int(cur_ow) * p.s1 - p.p1; + const uint bw = max(start_w, 0); + const uint ew = min(start_w + p.k1, p.IW); + + const float scale = 1.0 / float(p.k0 * p.k1); + float res; + + if (p.op == OP_POOL_AVG) { + res = 0.0; + } else if (p.op == OP_POOL_MAX) { + res = -FLT_MAX; + } else { + return; + } + + #pragma unroll + for (uint i = bh; i < eh; i++) { + #pragma unroll + for (uint j = bw; j < ew; j++) { + const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]); + + if (p.op == OP_POOL_AVG) { + res += cur * scale; + } else if (p.op == OP_POOL_MAX) { + res = max(res, cur); + } + } + } + + data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp new file mode 100644 index 000000000..52a19b62a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + data_d[i] = max(float(data_a[i]), 0); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp new file mode 100644 index 000000000..1568b141d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +uint src0_idx_mod(uint idx) { + const uint i13 = idx / (p.ne12*p.ne11*p.ne10); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; + const uint i10 = idx - 
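// repeat.comp implements broadcasting by folding every destination index back into the
// source tensor with a per-dimension modulo (i % ne0x), so the source is tiled as many
// times as needed along each axis.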
i13_offset - i12_offset - i11*p.ne10; + return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00; +} + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp new file mode 100644 index 000000000..b554400ba --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp new file mode 100644 index 000000000..574b51ca5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -0,0 +1,49 @@ +#include "types.comp" + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_spirv_intrinsics: enable + +#if RTE16 +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif + +layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {int data_pos[];}; +layout (binding = 2) readonly buffer Z {float data_ff[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint n_dims; + float freq_scale; + uint p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint has_ff; +} p; + +float rope_yarn_ramp(const float low, const float high, const uint i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { + float mscale = p.attn_factor; + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = p.freq_scale * theta_extrap; + float theta = theta_interp; + if (p.ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap 
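// YaRN frequency correction: inside the corr_dims ramp the final angle is a linear blend of
// the interpolated (freq_scale * theta) and extrapolated angles, weighted by
// ramp_mix = rope_yarn_ramp(...) * ext_factor, and the attention magnitude is rescaled by
// 1 + 0.1 * ln(1 / freq_scale) to compensate for the interpolation.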
* ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); + } + cos_theta = cos(theta) * mscale; + sin_theta = sin(theta) * mscale; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp new file mode 100644 index 000000000..83b46b69b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint col = gl_GlobalInvocationID.y * 2; + const uint row = gl_GlobalInvocationID.x; + + if (col >= p.ncols) { + return; + } + + if (col >= p.n_dims) { + const uint i = row*p.ncols + col; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + + const uint i = row*p.ncols + col/2; + const uint i2 = row/p.p_delta_rows; + + const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); + + const float x0 = float(data_a[i + 0]); + const float x1 = float(data_a[i + p.n_dims/2]); + + data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp new file mode 100644 index 000000000..e416ad938 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint col = gl_GlobalInvocationID.y * 2; + const uint row = gl_GlobalInvocationID.x; + + if (col >= p.ncols) { + return; + } + + if (col >= p.n_dims) { + const uint i = row*p.ncols + col; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + + const uint i = row*p.ncols + col; + const uint i2 = row/p.p_delta_rows; + + const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + + const float freq_factor = p.has_ff != 0 ? 
data_ff[col/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); + + const float x0 = float(data_a[i + 0]); + const float x1 = float(data_a[i + 1]); + + data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp new file mode 100644 index 000000000..4663428de --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -0,0 +1,24 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp new file mode 100644 index 000000000..4d36f88e0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float xi = float(data_a[i]); + data_d[i] = D_TYPE(xi / (1.0f + exp(-xi))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp new file mode 100644 index 000000000..d7c15a169 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp new file mode 100644 index 000000000..a25808e16 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -0,0 +1,174 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + uint nrows_x; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE 
data_b[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate +// over all the columns. The main function tries to pass a constant here, +// as if it were a template function, to allow unrolling. +void soft_max(uint num_iters) { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; + + if (rowx >= p.nrows_x) { + return; + } + + float slope = 1.0f; + + // ALiBi + if (p.max_bias > 0.0f) { + const uint h = rowx/p.KY; // head index + + const float base = h < p.n_head_log2 ? p.m0 : p.m1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + + slope = pow(base, exp); + } + + // Find max + FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); + + // Cache values while we compute the max, so we don't need to read them + // again when we're ready to compute exp(x-max). + const uint DATA_CACHE_SIZE = 16; + FLOAT_TYPE data_cache[DATA_CACHE_SIZE]; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + FLOAT_TYPE a = FLOAT_TYPE(0); + if (col < p.KX) { + a = data_a[rowx * p.KX + col]; + } + + FLOAT_TYPE b = FLOAT_TYPE(0); + if (p.KY > 0 && col < p.KX) { + b = data_b[rowy * p.KX + col]; + } + + FLOAT_TYPE v = a * p.scale + slope * b; + + if (col < p.KX) { + max_val = max(max_val, v); + } + + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = v; + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + max_val = vals[0]; + barrier(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + // compute exp(a*scale+b*slope), add it to sum, and cache the new value + // in data_cache if possible. + const uint i = rowx * p.KX + col; + FLOAT_TYPE val; + if (idx < DATA_CACHE_SIZE) { + val = exp(data_cache[idx] - max_val); + } else { + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? 
slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + } + sum += val; + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = val; + } else { + data_d[i] = D_TYPE(val); + } + } + + // reduce across the workgroup + vals[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + sum = vals[0]; + + FLOAT_TYPE rcpdivisor = 1.0/sum; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + continue; + } + + if (idx < DATA_CACHE_SIZE) { + data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor); + } else { + data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); + } + } +} + +void main() { + // instantiate the soft_max function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + soft_max(num_blocks); + } else if (num_blocks > 16) { + soft_max(32); + } else if (num_blocks > 8) { + soft_max(16); + } else if (num_blocks > 4) { + soft_max(8); + } else if (num_blocks == 4) { + soft_max(4); + } else if (num_blocks == 3) { + soft_max(3); + } else if (num_blocks == 2) { + soft_max(2); + } else if (num_blocks == 1) { + soft_max(1); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp new file mode 100644 index 000000000..ef43598ba --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp new file mode 100644 index 000000000..961e5ffa1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + tmp[col] = FLOAT_TYPE(0.0f); + + for (uint i = col; i < p.KX; i += BLOCK_SIZE) { + tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]); + } + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s) { + tmp[col] += tmp[col + s]; + } + barrier(); + } + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp new file mode 100644 index 000000000..495f966bd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.comp" +#include 
"types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp new file mode 100644 index 000000000..28eb24e11 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_NV_cooperative_matrix2 : require + +void main() +{ +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp new file mode 100644 index 000000000..79e065a93 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -0,0 +1,41 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint nb1; + uint dim; + uint max_period; +} p; + +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 256 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_WorkGroupID.y; + const uint j = gl_GlobalInvocationID.x; + const uint d_offset = i * p.nb1; + + if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) { + data_d[d_offset + p.dim] = 0.f; + } + + const uint half_dim = p.dim / 2; + if (j >= half_dim) { + return; + } + + const float timestep = float(data_a[i]); + const float freq = float(exp(-log(p.max_period) * j / half_dim)); + const float arg = timestep * freq; + data_d[d_offset + j] = D_TYPE(cos(arg)); + data_d[d_offset + j + half_dim] = D_TYPE(sin(arg)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp new file mode 100644 index 000000000..eecc47f3a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -0,0 +1,323 @@ + +#if !defined(GGML_TYPES_COMP) +#define GGML_TYPES_COMP + +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#if defined(DATA_A_F32) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE float +#elif LOAD_VEC_A == 4 +#define A_TYPE vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE mat2x4 +#endif +#endif + +#if defined(DATA_A_F16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE float16_t +#elif LOAD_VEC_A == 4 +#define A_TYPE f16vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE f16mat2x4 +#endif +#endif + +#define QUANT_K_Q4_0 32 +#define QUANT_R_Q4_0 2 + +struct block_q4_0 +{ + float16_t d; + uint8_t qs[16]; +}; +struct block_q4_0_packed16 +{ + float16_t d; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q4_0) +#define QUANT_K QUANT_K_Q4_0 +#define QUANT_R QUANT_R_Q4_0 +#define A_TYPE block_q4_0 +#define A_TYPE_PACKED16 block_q4_0_packed16 +#endif + +#define QUANT_K_Q4_1 32 +#define QUANT_R_Q4_1 2 + +struct 
block_q4_1 +{ + float16_t d; + float16_t m; + uint8_t qs[16]; +}; + +struct block_q4_1_packed16 +{ + float16_t d; + float16_t m; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q4_1) +#define QUANT_K QUANT_K_Q4_1 +#define QUANT_R QUANT_R_Q4_1 +#define A_TYPE block_q4_1 +#define A_TYPE_PACKED16 block_q4_1_packed16 +#endif + +#define QUANT_K_Q5_0 32 +#define QUANT_R_Q5_0 2 + +struct block_q5_0 +{ + float16_t d; + uint16_t qh[2]; + uint8_t qs[16]; +}; + +struct block_q5_0_packed16 +{ + float16_t d; + uint16_t qh[2]; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q5_0) +#define QUANT_K QUANT_K_Q5_0 +#define QUANT_R QUANT_R_Q5_0 +#define A_TYPE block_q5_0 +#define A_TYPE_PACKED16 block_q5_0_packed16 +#endif + +#define QUANT_K_Q5_1 32 +#define QUANT_R_Q5_1 2 + +struct block_q5_1 +{ + float16_t d; + float16_t m; + uint qh; + uint8_t qs[16]; +}; + +struct block_q5_1_packed16 +{ + float16_t d; + float16_t m; + uint qh; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q5_1) +#define QUANT_K QUANT_K_Q5_1 +#define QUANT_R QUANT_R_Q5_1 +#define A_TYPE block_q5_1 +#define A_TYPE_PACKED16 block_q5_1_packed16 +#endif + +#define QUANT_K_Q8_0 32 +#define QUANT_R_Q8_0 1 + +struct block_q8_0 +{ + float16_t d; + int8_t qs[32]; +}; +struct block_q8_0_packed16 +{ + float16_t d; + uint16_t qs[32/2]; +}; + +#if defined(DATA_A_Q8_0) +#define QUANT_K QUANT_K_Q8_0 +#define QUANT_R QUANT_R_Q8_0 +#define A_TYPE block_q8_0 +#define A_TYPE_PACKED16 block_q8_0_packed16 +#endif + +// K-quants +#define QUANT_K_Q2_K 256 + +struct block_q2_K +{ + uint8_t scales[QUANT_K_Q2_K/16]; + uint8_t qs[QUANT_K_Q2_K/4]; + f16vec2 d; +}; + +struct block_q2_K_packed16 +{ + uint16_t scales[QUANT_K_Q2_K/16/2]; + uint16_t qs[QUANT_K_Q2_K/4/2]; + f16vec2 d; +}; + +struct block_q2_K_packed32 +{ + uint32_t scales[QUANT_K_Q2_K/16/4]; + uint32_t qs[QUANT_K_Q2_K/4/4]; + f16vec2 d; +}; + +#if defined(DATA_A_Q2_K) +#define QUANT_K QUANT_K_Q2_K +#define A_TYPE block_q2_K +#define A_TYPE_PACKED16 block_q2_K_packed16 +#define A_TYPE_PACKED32 block_q2_K_packed32 +#endif + +#define QUANT_K_Q3_K 256 + +struct block_q3_K +{ + uint8_t hmask[QUANT_K_Q3_K/8]; + uint8_t qs[QUANT_K_Q3_K/4]; + uint8_t scales[12]; + float16_t d; +}; + +struct block_q3_K_packed16 +{ + uint16_t hmask[QUANT_K_Q3_K/8/2]; + uint16_t qs[QUANT_K_Q3_K/4/2]; + uint16_t scales[12/2]; + float16_t d; +}; + +#if defined(DATA_A_Q3_K) +#define QUANT_K QUANT_K_Q3_K +#define A_TYPE block_q3_K +#define A_TYPE_PACKED16 block_q3_K_packed16 +#endif + +#define QUANT_K_Q4_K 256 + +struct block_q4_K +{ + f16vec2 d; + uint8_t scales[3*QUANT_K_Q4_K/64]; + uint8_t qs[QUANT_K_Q4_K/2]; +}; + +struct block_q4_K_packed16 +{ + f16vec2 d; + uint16_t scales[3*QUANT_K_Q4_K/64/2]; + uint16_t qs[QUANT_K_Q4_K/2/2]; +}; + +struct block_q4_K_packed32 +{ + f16vec2 d; + uint32_t scales[3*QUANT_K_Q4_K/64/4]; + uint32_t qs[QUANT_K_Q4_K/2/4]; +}; + +#if defined(DATA_A_Q4_K) +#define QUANT_K QUANT_K_Q4_K +#define A_TYPE block_q4_K +#define A_TYPE_PACKED16 block_q4_K_packed16 +#define A_TYPE_PACKED32 block_q4_K_packed32 +#endif + +#define QUANT_K_Q5_K 256 + +struct block_q5_K +{ + f16vec2 d; + uint8_t scales[12]; + uint8_t qh[QUANT_K_Q5_K/8]; + uint8_t qs[QUANT_K_Q5_K/2]; +}; + +struct block_q5_K_packed16 +{ + f16vec2 d; + uint16_t scales[12/2]; + uint16_t qh[QUANT_K_Q5_K/8/2]; + uint16_t qs[QUANT_K_Q5_K/2/2]; +}; + +#if defined(DATA_A_Q5_K) +#define QUANT_K QUANT_K_Q5_K +#define A_TYPE block_q5_K +#define A_TYPE_PACKED16 block_q5_K_packed16 +#endif + +#define QUANT_K_Q6_K 256 + +struct block_q6_K +{ + uint8_t ql[QUANT_K_Q6_K/2]; 
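+    // upper 2 bits of each 6-bit quantized value; the low 4 bits are stored in ql above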
+ uint8_t qh[QUANT_K_Q6_K/4]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +struct block_q6_K_packed16 +{ + uint16_t ql[QUANT_K_Q6_K/2/2]; + uint16_t qh[QUANT_K_Q6_K/4/2]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +#if defined(DATA_A_Q6_K) +#define QUANT_K QUANT_K_Q6_K +#define A_TYPE block_q6_K +#define A_TYPE_PACKED16 block_q6_K_packed16 +#endif + +// IQuants + +#define QUANT_K_IQ4_NL 32 +#define QUANT_R_IQ4_NL 2 + +struct block_iq4_nl +{ + float16_t d; + uint8_t qs[QUANT_K_IQ4_NL/2]; +}; + +struct block_iq4_nl_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ4_NL/2/2]; +}; + +#if defined(DATA_A_IQ4_NL) + +const int8_t kvalues_iq4nl_const[16] = { + int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), + int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) +}; + +shared FLOAT_TYPE kvalues_iq4nl[16]; + +void init_iq4nl_shmem() +{ + // copy the table into shared memory and sync + if (gl_LocalInvocationIndex.x < 16) { + kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]); + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ4_NL +#define QUANT_R QUANT_R_IQ4_NL +#define A_TYPE block_iq4_nl +#define A_TYPE_PACKED16 block_iq4_nl_packed16 +#endif + +#endif // !defined(GGML_TYPES_COMP) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp new file mode 100644 index 000000000..6f607380d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -0,0 +1,36 @@ +#version 450 + +layout (push_constant) uniform parameter +{ + uint ne; uint a_offset; uint d_offset; + uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; + float sf0; float sf1; float sf2; float sf3; +} p; + +#include "types.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i10 = idx % p.ne10; + const uint i11 = (idx / p.ne10) % p.ne11; + const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; + const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; + + const uint i00 = uint(i10 / p.sf0); + const uint i01 = uint(i11 / p.sf1); + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + + data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp new file mode 100644 index 000000000..8111c0638 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -0,0 +1,594 @@ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 + #include + #include // For _mkdir on Windows + #include // For std::replace on w64devkit +#else + #include + #include + #include +#endif + +#include + +#define ASYNCIO_CONCURRENCY 64 + +std::mutex lock; +std::vector> shader_fnames; + +std::string GLSLC = 
"glslc"; +std::string input_dir = "vulkan-shaders"; +std::string output_dir = "/tmp"; +std::string target_hpp = "ggml-vulkan-shaders.hpp"; +std::string target_cpp = "ggml-vulkan-shaders.cpp"; +bool no_clean = false; + +const std::vector type_names = { + "f32", + "f16", + "q4_0", + "q4_1", + "q5_0", + "q5_1", + "q8_0", + "q2_k", + "q3_k", + "q4_k", + "q5_k", + "q6_k", + "iq4_nl" +}; + +namespace { +void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { +#ifdef _WIN32 + HANDLE stdout_read, stdout_write; + HANDLE stderr_read, stderr_write; + SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; + + if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) || + !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stdout pipe"); + } + + if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) || + !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stderr pipe"); + } + + PROCESS_INFORMATION pi; + STARTUPINFOA si = {}; + si.cb = sizeof(STARTUPINFOA); + si.dwFlags = STARTF_USESTDHANDLES; + si.hStdOutput = stdout_write; + si.hStdError = stderr_write; + + std::vector cmd(command.begin(), command.end()); + cmd.push_back('\0'); + + if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) { + throw std::runtime_error("Failed to create process"); + } + + CloseHandle(stdout_write); + CloseHandle(stderr_write); + + std::array buffer; + DWORD bytes_read; + + while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + CloseHandle(stdout_read); + CloseHandle(stderr_read); + WaitForSingleObject(pi.hProcess, INFINITE); + CloseHandle(pi.hProcess); + CloseHandle(pi.hThread); +#else +int stdout_pipe[2]; + int stderr_pipe[2]; + + if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { + throw std::runtime_error("Failed to create pipes"); + } + + pid_t pid = fork(); + if (pid < 0) { + throw std::runtime_error("Failed to fork process"); + } + + if (pid == 0) { + close(stdout_pipe[0]); + close(stderr_pipe[0]); + dup2(stdout_pipe[1], STDOUT_FILENO); + dup2(stderr_pipe[1], STDERR_FILENO); + close(stdout_pipe[1]); + close(stderr_pipe[1]); + execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr); + _exit(EXIT_FAILURE); + } else { + close(stdout_pipe[1]); + close(stderr_pipe[1]); + + std::array buffer; + ssize_t bytes_read; + + while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + close(stdout_pipe[0]); + close(stderr_pipe[0]); + waitpid(pid, nullptr, 0); + } +#endif +} + +bool directory_exists(const std::string& path) { + struct stat info; + if (stat(path.c_str(), &info) != 0) { + return false; // Path doesn't exist or can't be accessed + } + return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory +} + +bool create_directory(const std::string& path) { +#ifdef _WIN32 + return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists +#else + return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the 
directory permissions +#endif +} + +std::string to_uppercase(const std::string& input) { + std::string result = input; + for (char& c : result) { + c = std::toupper(c); + } + return result; +} + +bool string_ends_with(const std::string& str, const std::string& suffix) { + if (suffix.size() > str.size()) { + return false; + } + return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); +} + +static const char path_separator = '/'; + +std::string join_paths(const std::string& path1, const std::string& path2) { + return path1 + path_separator + path2; +} + +std::string basename(const std::string &path) { + return path.substr(path.find_last_of("/\\") + 1); +} + +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; + +void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string out_fname = join_paths(output_dir, name + ".spv"); + std::string in_path = join_paths(input_dir, in_fname); + + std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; + + // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 + std::string opt_level = coopmat ? "" : "-O"; + + #ifdef _WIN32 + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; + #else + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname}; + #endif + + #ifdef GGML_VULKAN_SHADER_DEBUG_INFO + cmd.push_back("-g"); + #endif + + for (const auto& define : defines) { + cmd.push_back("-D" + define.first + "=" + define.second); + } + + std::string command; + for (const auto& part : cmd) { + command += part + " "; + } + + std::string stdout_str, stderr_str; + try { + // std::cout << "Executing command: "; + // for (const auto& part : cmd) { + // std::cout << part << " "; + // } + // std::cout << std::endl; + + execute_command(command, stdout_str, stderr_str); + if (!stderr_str.empty()) { + std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl; + return; + } + + std::lock_guard guard(lock); + shader_fnames.push_back(std::make_pair(name, out_fname)); + } catch (const std::exception& e) { + std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; + } + { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + } + compile_count_cond.notify_all(); +} + +std::map merge_maps(const std::map& a, const std::map& b) { + std::map result = a; + result.insert(b.begin(), b.end()); + return result; +} + +static std::vector> compiles; +void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + { + // wait until fewer than N compiles are in progress. + // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. 
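+        // compile_count is decremented and compile_count_cond notified at the end of string_to_spv_func, releasing one waiter here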
+ uint32_t N = 16; + std::unique_lock guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); +} + +void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) { + std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; + std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; + std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; + + std::map base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}}; + std::string shader_name = "matmul"; + + if (matmul_id) { + base_dict["MUL_MAT_ID"] = "1"; + shader_name = "matmul_id"; + } + + if (fp16) { + base_dict["FLOAT16"] = "1"; + } + + base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + + if (coopmat) { + base_dict["COOPMAT"] = "1"; + } + + base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + + std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; + + // Shaders with f16 B_TYPE + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + + for (const auto& tname : type_names) { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + // For unaligned, load one at a time for f32/f16, or two at a time for quants + std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2"; + // For aligned matmul loads + std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? 
load_vec : "2"; + + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + + if (tname != "f16" && tname != "f32") { + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + } +} + +void process_shaders() { + std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; + std::map base_dict = {{"FLOAT_TYPE", "float"}}; + + // matmul + for (const auto& matmul_id : {false, true}) { + // No coopmats + // fp32 + matmul_shaders(false, matmul_id, false, false, false); + + // fp16, fp32acc and fp16acc + matmul_shaders(true, matmul_id, false, false, false); + matmul_shaders(true, matmul_id, false, false, true); + + // Coopmat, fp32acc and fp16acc + matmul_shaders(true, matmul_id, true, false, false); + matmul_shaders(true, matmul_id, true, false, true); + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + // Coopmat2, fp32acc and fp16acc + matmul_shaders(true, matmul_id, false, true, false); + matmul_shaders(true, matmul_id, false, true, true); +#endif + } + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + // flash attention + for (const auto& f16acc : {false, true}) { + std::string acctype = f16acc ? "float16_t" : "float"; + + for (const auto& tname : type_names) { + if (tname == "f32") { + continue; + } + + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); + } else { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + } + } + } +#endif + + for (const auto& tname : type_names) { + // mul mat vec + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + std::string shader = (string_ends_with(tname, "_k")) ? 
"mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + + string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + + // Dequant shaders + if (tname != "f16") { + string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); + } + + if (!string_ends_with(tname, "_k")) { + shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp"; + + if (tname == "f16") { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); + } else { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); + } + string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); + } + } + + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + // Norms + string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + + string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); + + string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", 
"float"}}); + + string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); + + string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + + string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); + string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + + string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + 
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + for (auto &c : compiles) { + c.wait(); + } +} + +void write_output_files() { + FILE* hdr = fopen(target_hpp.c_str(), "w"); + FILE* src = fopen(target_cpp.c_str(), "w"); + + fprintf(hdr, "#include \n\n"); + fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + + for (const auto& pair : shader_fnames) { + const std::string& name = pair.first; + #ifdef _WIN32 + std::string path = pair.second; + std::replace(path.begin(), path.end(), '/', '\\' ); + #else + const std::string& path = pair.second; + #endif + + FILE* spv = fopen(path.c_str(), "rb"); + if (!spv) { + std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; + continue; + } + + fseek(spv, 0, SEEK_END); + size_t size = ftell(spv); + fseek(spv, 0, SEEK_SET); + + std::vector data(size); + size_t read_size = fread(data.data(), 1, size, spv); + fclose(spv); + if (read_size != size) { + std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; + continue; + } + + fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size); + fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size); + + fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size); + for (size_t i = 0; i < size; ++i) { + fprintf(src, "0x%02x,", data[i]); + if ((i + 1) % 12 == 0) fprintf(src, "\n"); + } + fprintf(src, "\n};\n\n"); + + if (!no_clean) { + std::remove(path.c_str()); + } + } + + fclose(hdr); + fclose(src); +} +} + +int main(int argc, char** argv) { + std::map args; + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg.rfind("--", 0) == 0) { + if (i + 1 < argc && argv[i + 1][0] != '-') { + args[arg] = argv[i + 1]; + ++i; + } else { + args[arg] = ""; + } + } + } + + if (args.find("--glslc") != args.end()) { + GLSLC = args["--glslc"]; // Path to glslc + } + if (args.find("--input-dir") != args.end()) { + input_dir = args["--input-dir"]; // Directory containing shader sources + } + if (args.find("--output-dir") != args.end()) { + output_dir = args["--output-dir"]; // Directory for containing SPIR-V output + } + if (args.find("--target-hpp") != args.end()) { + target_hpp = args["--target-hpp"]; // Path to generated header file + } + if (args.find("--target-cpp") != args.end()) { + target_cpp = args["--target-cpp"]; // Path to generated cpp file + } + if (args.find("--no-clean") != args.end()) { + no_clean = true; // Keep temporary SPIR-V files in output-dir after build + } + + if (!directory_exists(input_dir)) { + std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl; + return EXIT_FAILURE; + } + + if (!directory_exists(output_dir)) { + if (!create_directory(output_dir)) { + std::cerr << "Error creating output directory: " << output_dir << "\n"; + return EXIT_FAILURE; + } + } + + process_shaders(); + + write_output_files(); + + return EXIT_SUCCESS; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp new file mode 100644 index 000000000..35cc6c45f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp @@ -0,0 +1,87 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, 
local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; }; +layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; }; +layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 6) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + barrier(); + _tf[tid] = tf[head_id * head_size + tid]; + barrier(); + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + barrier(); + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + + vec4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + s_vec = s_vec * td_vec + kv; + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} From 189cbb40a6f7f3ea55d1986827f5912a10f485ea Mon Sep 17 00:00:00 2001 From: Vadim Grinco Date: Sat, 8 Mar 2025 19:40:53 +0100 Subject: [PATCH 022/172] Updated dockerfile https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871 Signed-off-by: Vadim Grinco --- Dockerfile | 181 ++++++++++++++++++++--------------------------------- 1 file changed, 69 insertions(+), 112 deletions(-) diff --git a/Dockerfile b/Dockerfile index 4136fca71..4e87ba43c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,131 +1,88 @@ -# vim: filetype=dockerfile +FROM --platform=linux/amd64 library/ubuntu:noble as builder -ARG FLAVOR=${TARGETARCH} +ENV DEBIAN_FRONTEND="noninteractive" -ARG ROCMVERSION=6.3.3 -ARG JETPACK5VERSION=r35.4.1 -ARG JETPACK6VERSION=r36.4.0 -ARG CMAKEVERSION=3.31.2 +ENV VULKAN_VER_BASE="1.3.296" +ENV VULKAN_VER="${VULKAN_VER_BASE}.0" +ENV UBUNTU_VERSION="noble" -# CUDA v11 requires gcc v10. 
v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version -FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 -RUN yum install -y yum-utils \ - && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \ - && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \ - && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \ - && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo -ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH - -FROM --platform=linux/arm64 almalinux:8 AS base-arm64 -# install epel-release for ccache -RUN yum install -y yum-utils epel-release \ - && dnf install -y clang ccache \ - && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo -ENV CC=clang CXX=clang++ - -FROM base-${TARGETARCH} AS base -ARG CMAKEVERSION -RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 -COPY CMakeLists.txt CMakePresets.json . -COPY ml/backend/ggml/ggml ml/backend/ggml/ggml +ENV GOLANG_VERSION="1.22.8" +ENV GOARCH="amd64" +ENV CGO_ENABLED=1 ENV LDFLAGS=-s -FROM base AS cpu -RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ -ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH -RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'CPU' \ - && cmake --build --parallel --preset 'CPU' \ - && cmake --install build --component CPU --strip --parallel 8 +# Default mirror was very slow +RUN \ + sed -i 's/archive.ubuntu.com/gb.archive.ubuntu.com/g' /etc/apt/sources.list.d/ubuntu.sources -FROM base AS cuda-11 -ARG CUDA11VERSION=11.3 -RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} -ENV PATH=/usr/local/cuda-11/bin:$PATH -RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'CUDA 11' \ - && cmake --build --parallel --preset 'CUDA 11' \ - && cmake --install build --component CUDA --strip --parallel 8 +RUN \ + apt-get update && \ + apt-get install -y ca-certificates build-essential ccache cmake wget git curl rsync xz-utils libcap-dev -FROM base AS cuda-12 -ARG CUDA12VERSION=12.8 -RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} -ENV PATH=/usr/local/cuda-12/bin:$PATH -RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'CUDA 12' \ - && cmake --build --parallel --preset 'CUDA 12' \ - && cmake --install build --component CUDA --strip --parallel 8 +RUN \ + mkdir -p /usr/local 2>/dev/null || true && \ + curl -s -L https://dl.google.com/go/go${GOLANG_VERSION}.linux-${GOARCH}.tar.gz | tar -xz -C /usr/local && \ + ln -s /usr/local/go/bin/go /usr/local/bin/go && \ + ln -s /usr/local/go/bin/gofmt /usr/local/bin/gofmt -FROM base AS rocm-6 -ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH -RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'ROCm 6' \ - && cmake --build --parallel --preset 'ROCm 6' \ - && cmake --install build --component HIP --strip --parallel 8 -FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5 -ARG CMAKEVERSION -RUN apt-get update && apt-get install -y curl ccache \ - && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local 
--strip-components 1 -COPY CMakeLists.txt CMakePresets.json . -COPY ml/backend/ggml/ggml ml/backend/ggml/ggml -RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'JetPack 5' \ - && cmake --build --parallel --preset 'JetPack 5' \ - && cmake --install build --component CUDA --strip --parallel 8 +RUN \ + wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | gpg --dearmor -o /etc/apt/trusted.gpg.d/lunarg-signing-key-pub.gpg && \ + wget -qO /etc/apt/sources.list.d/lunarg-vulkan-${UBUNTU_VERSION}.list https://packages.lunarg.com/vulkan/${VULKAN_VER_BASE}/lunarg-vulkan-${VULKAN_VER_BASE}-${UBUNTU_VERSION}.list && \ + apt update && apt install -y vulkan-sdk -FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6 -ARG CMAKEVERSION -RUN apt-get update && apt-get install -y curl ccache \ - && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 -COPY CMakeLists.txt CMakePresets.json . -COPY ml/backend/ggml/ggml ml/backend/ggml/ggml -RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'JetPack 6' \ - && cmake --build --parallel --preset 'JetPack 6' \ - && cmake --install build --component CUDA --strip --parallel 8 +# Last testet ollama-vulkan commit: +# 2d443b3dd660a1fd2760d64538512df93648b4bb +COPY patches/ /tmp/patches/ +RUN \ + git clone https://github.com/pufferffish/ollama-vulkan.git "/tmp/ollama-vulkan-git" && \ + cd "/tmp/ollama-vulkan-git" && \ + git checkout 2d443b3dd660a1fd2760d64538512df93648b4bb && git checkout -b ollama_vulkan_stable && \ + git config user.name "Builder" && git config user.email "builder@local" && \ + git remote add ollama_vanilla https://github.com/ollama/ollama.git && \ + git fetch ollama_vanilla --tags && git checkout v0.5.11 && git checkout -b ollama_vanilla_stable && \ + git checkout ollama_vulkan_stable && git merge ollama_vanilla_stable --allow-unrelated-histories --no-edit && \ + for p in /tmp/patches/00-fix-vulkan-building.patch; do patch -p1 < $p; done -FROM base AS build -WORKDIR /go/src/github.com/ollama/ollama -COPY go.mod go.sum . -RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local -ENV PATH=/usr/local/go/bin:$PATH -RUN go mod download -COPY . . -ARG GOFLAGS="'-ldflags=-w -s'" -ENV CGO_ENABLED=1 -RUN --mount=type=cache,target=/root/.cache/go-build \ - go build -trimpath -buildmode=pie -o /bin/ollama . 
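+# Re-sync the vendored GGML sources before the per-backend builds below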
+RUN \ + cd "/tmp/ollama-vulkan-git" && \ + make -f Makefile.sync clean sync -FROM --platform=linux/amd64 scratch AS amd64 -COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11 -COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12 -FROM --platform=linux/arm64 scratch AS arm64 -COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11 -COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12 -COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5 -COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6 +FROM builder AS cpu-build +RUN \ + cd "/tmp/ollama-vulkan-git" && \ + cmake --preset CPU && cmake --build --parallel --preset CPU && \ + cmake --install build --component CPU --strip -FROM scratch AS rocm -COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm +FROM builder AS vulkan-build +RUN \ + cd "/tmp/ollama-vulkan-git" && \ + cmake --preset Vulkan && \ + cmake --build --parallel --preset Vulkan && \ + cmake --install build --component Vulkan --strip -FROM ${FLAVOR} AS archive -COPY --from=cpu dist/lib/ollama /lib/ollama -COPY --from=build /bin/ollama /bin/ollama +FROM builder AS binary-build +RUN \ + cd "/tmp/ollama-vulkan-git" && \ + . scripts/env.sh && \ + mkdir -p dist/bin && \ + go build -trimpath -buildmode=pie -o dist/bin/ollama . + + +FROM --platform=linux/amd64 library/ubuntu:noble +RUN \ + apt-get update && \ + apt-get install -y ca-certificates libcap2 libvulkan1 && \ + apt-get clean && rm -rf /var/lib/apt/lists/* +COPY --from=cpu-build /tmp/ollama-vulkan-git/dist/lib/ollama/ /lib/ollama/ +COPY --from=vulkan-build /tmp/ollama-vulkan-git/dist/lib/ollama/vulkan/ /lib/ollama/vulkan/ +COPY --from=binary-build /tmp/ollama-vulkan-git/dist/bin/ /bin/ + +RUN find /lib/ollama && find /bin/ollama -FROM ubuntu:20.04 -RUN apt-get update \ - && apt-get install -y ca-certificates \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* -COPY --from=archive /bin /usr/bin -ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin -COPY --from=archive /lib/ollama /usr/lib/ollama -ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 -ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility -ENV NVIDIA_VISIBLE_DEVICES=all -ENV OLLAMA_HOST=0.0.0.0:11434 EXPOSE 11434 +ENV OLLAMA_HOST 0.0.0.0 + ENTRYPOINT ["/bin/ollama"] CMD ["serve"] From 81465ca37494217f0d6f074eb47306835d26638b Mon Sep 17 00:00:00 2001 From: Vadim Grinco Date: Sun, 9 Mar 2025 20:42:32 +0100 Subject: [PATCH 023/172] Installing rocm library Signed-off-by: Vadim Grinco --- Dockerfile | 105 ++++++++++++++++++++++++++++------------------------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/Dockerfile b/Dockerfile index 4e87ba43c..25d8ddd3e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,88 +1,93 @@ -FROM --platform=linux/amd64 library/ubuntu:noble as builder +# Base Image +FROM --platform=linux/amd64 library/ubuntu:noble AS builder +# Set Environment Variables ENV DEBIAN_FRONTEND="noninteractive" - ENV VULKAN_VER_BASE="1.3.296" ENV VULKAN_VER="${VULKAN_VER_BASE}.0" ENV UBUNTU_VERSION="noble" - ENV GOLANG_VERSION="1.22.8" ENV GOARCH="amd64" ENV CGO_ENABLED=1 ENV LDFLAGS=-s -# Default mirror was very slow -RUN \ - sed -i 's/archive.ubuntu.com/gb.archive.ubuntu.com/g' /etc/apt/sources.list.d/ubuntu.sources +# Set up faster package mirrors +RUN sed -i 's/archive.ubuntu.com/gb.archive.ubuntu.com/g' /etc/apt/sources.list.d/ubuntu.sources -RUN \ - apt-get update && \ - apt-get install -y ca-certificates build-essential ccache cmake 
wget git curl rsync xz-utils libcap-dev +# Install Required Dependencies +RUN apt-get update && apt-get install -y \ + ca-certificates build-essential ccache cmake wget git curl rsync xz-utils libcap-dev \ + && apt-get clean && rm -rf /var/lib/apt/lists/* -RUN \ - mkdir -p /usr/local 2>/dev/null || true && \ +# Install Go +RUN mkdir -p /usr/local && \ curl -s -L https://dl.google.com/go/go${GOLANG_VERSION}.linux-${GOARCH}.tar.gz | tar -xz -C /usr/local && \ ln -s /usr/local/go/bin/go /usr/local/bin/go && \ ln -s /usr/local/go/bin/gofmt /usr/local/bin/gofmt - -RUN \ - wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | gpg --dearmor -o /etc/apt/trusted.gpg.d/lunarg-signing-key-pub.gpg && \ +# Install Vulkan SDK +RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | gpg --dearmor -o /etc/apt/trusted.gpg.d/lunarg-signing-key-pub.gpg && \ wget -qO /etc/apt/sources.list.d/lunarg-vulkan-${UBUNTU_VERSION}.list https://packages.lunarg.com/vulkan/${VULKAN_VER_BASE}/lunarg-vulkan-${VULKAN_VER_BASE}-${UBUNTU_VERSION}.list && \ - apt update && apt install -y vulkan-sdk + apt update && apt install -y vulkan-sdk && \ + apt-get clean && rm -rf /var/lib/apt/lists/* -# Last testet ollama-vulkan commit: -# 2d443b3dd660a1fd2760d64538512df93648b4bb -COPY patches/ /tmp/patches/ -RUN \ - git clone https://github.com/pufferffish/ollama-vulkan.git "/tmp/ollama-vulkan-git" && \ - cd "/tmp/ollama-vulkan-git" && \ +# Install AMDVLK (Optional: If you want to use AMDVLK instead of RADV) +RUN wget -qO - http://repo.radeon.com/amdvlk/apt/debian/amdvlk.gpg.key | apt-key add && \ + echo "deb [arch=amd64,i386] http://repo.radeon.com/amdvlk/apt/debian/ bionic main" > /etc/apt/sources.list.d/amdvlk.list && \ + apt update && apt install -y amdvlk && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + +# Set AMDVLK as the default Vulkan driver +ENV VK_ICD_FILENAMES=/usr/share/vulkan/icd.d/amd_icd64.json + +# Clone Ollama Vulkan Fork +WORKDIR /opt +RUN git clone https://github.com/pufferffish/ollama-vulkan.git ollama-vulkan + +# Download and Apply Patches Automatically +WORKDIR /opt/ollama-vulkan +RUN mkdir -p patches && \ + wget -O patches/00-fix-vulkan-building.patch https://github.com/user-attachments/files/18783263/0002-fix-fix-vulkan-building.patch && \ git checkout 2d443b3dd660a1fd2760d64538512df93648b4bb && git checkout -b ollama_vulkan_stable && \ git config user.name "Builder" && git config user.email "builder@local" && \ git remote add ollama_vanilla https://github.com/ollama/ollama.git && \ - git fetch ollama_vanilla --tags && git checkout v0.5.11 && git checkout -b ollama_vanilla_stable && \ + git fetch ollama_vanilla --tags && git checkout v0.5.13 && git checkout -b ollama_vanilla_stable && \ git checkout ollama_vulkan_stable && git merge ollama_vanilla_stable --allow-unrelated-histories --no-edit && \ - for p in /tmp/patches/00-fix-vulkan-building.patch; do patch -p1 < $p; done + for p in patches/*.patch; do patch -p1 < $p; done -RUN \ - cd "/tmp/ollama-vulkan-git" && \ - make -f Makefile.sync clean sync +# Build Shared Libraries (CPU & Vulkan) +WORKDIR /opt/ollama-vulkan +RUN cmake -S . 
-B build -DCMAKE_BUILD_TYPE=Release +RUN cmake --build build --parallel +RUN cmake --install build --component CPU --strip +RUN cmake --install build --component Vulkan --strip +# Install rocm +RUN apt update +RUN apt install -y wget "linux-headers-$(uname -r)" "linux-modules-extra-$(uname -r)" +RUN apt install -y python3-setuptools python3-wheel +RUN wget https://repo.radeon.com/amdgpu-install/6.3.3/ubuntu/noble/amdgpu-install_6.3.60303-1_all.deb -O /tmp/amdgpu-install_6.3.60303-1_all.deb +RUN apt install -y /tmp/amdgpu-install_6.3.60303-1_all.deb +RUN apt update && apt install -y rocm -FROM builder AS cpu-build -RUN \ - cd "/tmp/ollama-vulkan-git" && \ - cmake --preset CPU && cmake --build --parallel --preset CPU && \ - cmake --install build --component CPU --strip - -FROM builder AS vulkan-build -RUN \ - cd "/tmp/ollama-vulkan-git" && \ - cmake --preset Vulkan && \ - cmake --build --parallel --preset Vulkan && \ - cmake --install build --component Vulkan --strip - -FROM builder AS binary-build -RUN \ - cd "/tmp/ollama-vulkan-git" && \ +# Build Final Binary +RUN cd /opt/ollama-vulkan && \ . scripts/env.sh && \ mkdir -p dist/bin && \ go build -trimpath -buildmode=pie -o dist/bin/ollama . - +# Final Image FROM --platform=linux/amd64 library/ubuntu:noble -RUN \ - apt-get update && \ - apt-get install -y ca-certificates libcap2 libvulkan1 && \ +RUN apt-get update && apt-get install -y ca-certificates libcap2 libvulkan1 && \ apt-get clean && rm -rf /var/lib/apt/lists/* -COPY --from=cpu-build /tmp/ollama-vulkan-git/dist/lib/ollama/ /lib/ollama/ -COPY --from=vulkan-build /tmp/ollama-vulkan-git/dist/lib/ollama/vulkan/ /lib/ollama/vulkan/ -COPY --from=binary-build /tmp/ollama-vulkan-git/dist/bin/ /bin/ -RUN find /lib/ollama && find /bin/ollama +# Copy Built Components +COPY --from=builder /opt/ollama-vulkan/dist/bin/ollama /bin/ollama +# Expose Ollama Server Port EXPOSE 11434 ENV OLLAMA_HOST 0.0.0.0 +# Run Ollama Server ENTRYPOINT ["/bin/ollama"] CMD ["serve"] From 42bac5caddfb95c40c5d34d0fe123c4e3c0f1f2a Mon Sep 17 00:00:00 2001 From: Vadim Grinco Date: Sun, 9 Mar 2025 23:21:57 +0100 Subject: [PATCH 024/172] This version works well built based on this: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871 Signed-off-by: Vadim Grinco --- Dockerfile | 116 +- patches/00-fix-vulkan-building.patch | 15297 +++++++++++++++++++++++++ 2 files changed, 15358 insertions(+), 55 deletions(-) create mode 100644 patches/00-fix-vulkan-building.patch diff --git a/Dockerfile b/Dockerfile index 25d8ddd3e..9e2928108 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,93 +1,99 @@ -# Base Image -FROM --platform=linux/amd64 library/ubuntu:noble AS builder +FROM --platform=linux/amd64 library/ubuntu:noble as builder -# Set Environment Variables ENV DEBIAN_FRONTEND="noninteractive" + ENV VULKAN_VER_BASE="1.3.296" ENV VULKAN_VER="${VULKAN_VER_BASE}.0" ENV UBUNTU_VERSION="noble" + ENV GOLANG_VERSION="1.22.8" ENV GOARCH="amd64" ENV CGO_ENABLED=1 ENV LDFLAGS=-s -# Set up faster package mirrors -RUN sed -i 's/archive.ubuntu.com/gb.archive.ubuntu.com/g' /etc/apt/sources.list.d/ubuntu.sources +# Default mirror was very slow +RUN \ + sed -i 's/archive.ubuntu.com/gb.archive.ubuntu.com/g' /etc/apt/sources.list.d/ubuntu.sources -# Install Required Dependencies -RUN apt-get update && apt-get install -y \ - ca-certificates build-essential ccache cmake wget git curl rsync xz-utils libcap-dev \ - && apt-get clean && rm -rf /var/lib/apt/lists/* +RUN \ + apt-get update && \ + apt-get install -y ca-certificates 
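
Patch 023 pins the AMDVLK ICD through VK_ICD_FILENAMES, while the follow-up patch below drops that install again and leaves the choice to whichever ICD the loader finds at runtime. Which driver actually services a device can be confirmed from code; a short sketch using VkPhysicalDeviceDriverProperties (core in Vulkan 1.2), assuming only a working loader:

#include <vulkan/vulkan.h>
#include <cstdio>
#include <vector>

int main() {
    VkApplicationInfo app{};
    app.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
    app.apiVersion = VK_API_VERSION_1_2;
    VkInstanceCreateInfo ici{};
    ici.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
    ici.pApplicationInfo = &app;
    VkInstance inst;
    if (vkCreateInstance(&ici, nullptr, &inst) != VK_SUCCESS) {
        return 1;
    }

    uint32_t count = 0;
    vkEnumeratePhysicalDevices(inst, &count, nullptr);
    std::vector<VkPhysicalDevice> devs(count);
    vkEnumeratePhysicalDevices(inst, &count, devs.data());

    for (VkPhysicalDevice d : devs) {
        VkPhysicalDeviceDriverProperties drv{};
        drv.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES;
        VkPhysicalDeviceProperties2 props{};
        props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
        props.pNext = &drv;
        vkGetPhysicalDeviceProperties2(d, &props);
        // driverName distinguishes the implementations, e.g. AMDVLK vs. Mesa RADV.
        printf("%s: %s (%s)\n", props.properties.deviceName, drv.driverName, drv.driverInfo);
    }
    vkDestroyInstance(inst, nullptr);
    return 0;
}
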
build-essential ccache cmake wget git curl rsync xz-utils libcap-dev -# Install Go -RUN mkdir -p /usr/local && \ +RUN \ + mkdir -p /usr/local 2>/dev/null || true && \ curl -s -L https://dl.google.com/go/go${GOLANG_VERSION}.linux-${GOARCH}.tar.gz | tar -xz -C /usr/local && \ ln -s /usr/local/go/bin/go /usr/local/bin/go && \ ln -s /usr/local/go/bin/gofmt /usr/local/bin/gofmt -# Install Vulkan SDK -RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | gpg --dearmor -o /etc/apt/trusted.gpg.d/lunarg-signing-key-pub.gpg && \ + +RUN \ + wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | gpg --dearmor -o /etc/apt/trusted.gpg.d/lunarg-signing-key-pub.gpg && \ wget -qO /etc/apt/sources.list.d/lunarg-vulkan-${UBUNTU_VERSION}.list https://packages.lunarg.com/vulkan/${VULKAN_VER_BASE}/lunarg-vulkan-${VULKAN_VER_BASE}-${UBUNTU_VERSION}.list && \ - apt update && apt install -y vulkan-sdk && \ - apt-get clean && rm -rf /var/lib/apt/lists/* + apt update && apt install -y vulkan-sdk -# Install AMDVLK (Optional: If you want to use AMDVLK instead of RADV) -RUN wget -qO - http://repo.radeon.com/amdvlk/apt/debian/amdvlk.gpg.key | apt-key add && \ - echo "deb [arch=amd64,i386] http://repo.radeon.com/amdvlk/apt/debian/ bionic main" > /etc/apt/sources.list.d/amdvlk.list && \ - apt update && apt install -y amdvlk && \ - apt-get clean && rm -rf /var/lib/apt/lists/* - -# Set AMDVLK as the default Vulkan driver -ENV VK_ICD_FILENAMES=/usr/share/vulkan/icd.d/amd_icd64.json - -# Clone Ollama Vulkan Fork -WORKDIR /opt -RUN git clone https://github.com/pufferffish/ollama-vulkan.git ollama-vulkan - -# Download and Apply Patches Automatically -WORKDIR /opt/ollama-vulkan -RUN mkdir -p patches && \ - wget -O patches/00-fix-vulkan-building.patch https://github.com/user-attachments/files/18783263/0002-fix-fix-vulkan-building.patch && \ +# Last testet ollama-vulkan commit: +# 2d443b3dd660a1fd2760d64538512df93648b4bb +COPY patches/ /tmp/patches/ +RUN \ + git clone https://github.com/pufferffish/ollama-vulkan.git "/tmp/ollama-vulkan-git" && \ + cd "/tmp/ollama-vulkan-git" && \ git checkout 2d443b3dd660a1fd2760d64538512df93648b4bb && git checkout -b ollama_vulkan_stable && \ git config user.name "Builder" && git config user.email "builder@local" && \ git remote add ollama_vanilla https://github.com/ollama/ollama.git && \ - git fetch ollama_vanilla --tags && git checkout v0.5.13 && git checkout -b ollama_vanilla_stable && \ + git fetch ollama_vanilla --tags && git checkout v0.5.14-rc0 && git checkout -b ollama_vanilla_stable && \ git checkout ollama_vulkan_stable && git merge ollama_vanilla_stable --allow-unrelated-histories --no-edit && \ - for p in patches/*.patch; do patch -p1 < $p; done + for p in /tmp/patches/00-fix-vulkan-building.patch; do patch -p1 < $p; done -# Build Shared Libraries (CPU & Vulkan) -WORKDIR /opt/ollama-vulkan -RUN cmake -S . 
-B build -DCMAKE_BUILD_TYPE=Release -RUN cmake --build build --parallel -RUN cmake --install build --component CPU --strip -RUN cmake --install build --component Vulkan --strip +RUN \ + cd "/tmp/ollama-vulkan-git" && \ + make -f Makefile.sync clean sync -# Install rocm -RUN apt update -RUN apt install -y wget "linux-headers-$(uname -r)" "linux-modules-extra-$(uname -r)" -RUN apt install -y python3-setuptools python3-wheel -RUN wget https://repo.radeon.com/amdgpu-install/6.3.3/ubuntu/noble/amdgpu-install_6.3.60303-1_all.deb -O /tmp/amdgpu-install_6.3.60303-1_all.deb -RUN apt install -y /tmp/amdgpu-install_6.3.60303-1_all.deb -RUN apt update && apt install -y rocm -# Build Final Binary -RUN cd /opt/ollama-vulkan && \ +FROM builder AS cpu-build +RUN \ + cd "/tmp/ollama-vulkan-git" && \ + cmake --preset CPU && cmake --build --parallel --preset CPU && \ + cmake --install build --component CPU --strip + +FROM builder AS vulkan-build +RUN \ + cd "/tmp/ollama-vulkan-git" && \ + cmake --preset Vulkan && \ + cmake --build --parallel --preset Vulkan && \ + cmake --install build --component Vulkan --strip + +FROM builder AS binary-build +RUN \ + cd "/tmp/ollama-vulkan-git" && \ . scripts/env.sh && \ mkdir -p dist/bin && \ go build -trimpath -buildmode=pie -o dist/bin/ollama . -# Final Image + FROM --platform=linux/amd64 library/ubuntu:noble -RUN apt-get update && apt-get install -y ca-certificates libcap2 libvulkan1 && \ +RUN \ + apt-get update && apt -y dist-upgrade && \ + apt-get install -y ca-certificates libcap2 libvulkan1 && \ apt-get clean && rm -rf /var/lib/apt/lists/* -# Copy Built Components -COPY --from=builder /opt/ollama-vulkan/dist/bin/ollama /bin/ollama +# Install ROCm +RUN \ + apt update && \ + apt install -y wget python3-setuptools python3-wheel && \ + wget https://repo.radeon.com/amdgpu-install/6.3.3/ubuntu/noble/amdgpu-install_6.3.60303-1_all.deb -O /tmp/amdgpu-install_6.3.60303-1_all.deb && \ + apt install -y /tmp/amdgpu-install_6.3.60303-1_all.deb && \ + apt update && apt install -y rocm && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + + +COPY --from=cpu-build /tmp/ollama-vulkan-git/dist/lib/ollama/ /lib/ollama/ +COPY --from=vulkan-build /tmp/ollama-vulkan-git/dist/lib/ollama/vulkan/ /lib/ollama/vulkan/ +COPY --from=binary-build /tmp/ollama-vulkan-git/dist/bin/ /bin/ + +RUN find /lib/ollama && find /bin/ollama -# Expose Ollama Server Port EXPOSE 11434 ENV OLLAMA_HOST 0.0.0.0 -# Run Ollama Server ENTRYPOINT ["/bin/ollama"] CMD ["serve"] diff --git a/patches/00-fix-vulkan-building.patch b/patches/00-fix-vulkan-building.patch new file mode 100644 index 000000000..52e498ee2 --- /dev/null +++ b/patches/00-fix-vulkan-building.patch @@ -0,0 +1,15297 @@ +From 7c5f98c4cbfaf472a0d05baa3cc61afdcaeee7de Mon Sep 17 00:00:00 2001 +From: dream +Date: Thu, 13 Feb 2025 18:58:59 +0800 +Subject: [PATCH 2/2] fix: fix vulkan building + +1. Add preset for vulkan. +2. Add backend ggml-vulkan. +3. Add some log info. 
+--- + CMakePresets.json | 13 +- + discover/gpu.go | 7 +- + .../ggml/ggml/src/ggml-vulkan/CMakeLists.txt | 92 + + .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8745 +++++++++++++++++ + .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 9 + + .../src/ggml-vulkan/vulkan-shaders/acc.comp | 29 + + .../src/ggml-vulkan/vulkan-shaders/add.comp | 29 + + .../ggml-vulkan/vulkan-shaders/argsort.comp | 69 + + .../src/ggml-vulkan/vulkan-shaders/clamp.comp | 17 + + .../ggml-vulkan/vulkan-shaders/concat.comp | 41 + + .../vulkan-shaders/contig_copy.comp | 42 + + .../src/ggml-vulkan/vulkan-shaders/copy.comp | 20 + + .../src/ggml-vulkan/vulkan-shaders/cos.comp | 17 + + .../vulkan-shaders/dequant_f32.comp | 20 + + .../vulkan-shaders/dequant_funcs.comp | 118 + + .../vulkan-shaders/dequant_funcs_cm2.comp | 325 + + .../vulkan-shaders/dequant_head.comp | 13 + + .../vulkan-shaders/dequant_iq4_nl.comp | 32 + + .../vulkan-shaders/dequant_q2_k.comp | 34 + + .../vulkan-shaders/dequant_q3_k.comp | 42 + + .../vulkan-shaders/dequant_q4_0.comp | 30 + + .../vulkan-shaders/dequant_q4_1.comp | 32 + + .../vulkan-shaders/dequant_q4_k.comp | 68 + + .../vulkan-shaders/dequant_q5_0.comp | 34 + + .../vulkan-shaders/dequant_q5_1.comp | 35 + + .../vulkan-shaders/dequant_q5_k.comp | 70 + + .../vulkan-shaders/dequant_q6_k.comp | 33 + + .../vulkan-shaders/dequant_q8_0.comp | 31 + + .../vulkan-shaders/diag_mask_inf.comp | 34 + + .../src/ggml-vulkan/vulkan-shaders/div.comp | 27 + + .../vulkan-shaders/flash_attn_cm2.comp | 289 + + .../src/ggml-vulkan/vulkan-shaders/gelu.comp | 25 + + .../vulkan-shaders/gelu_quick.comp | 23 + + .../vulkan-shaders/generic_binary_head.comp | 64 + + .../vulkan-shaders/generic_head.comp | 9 + + .../vulkan-shaders/generic_unary_head.comp | 56 + + .../ggml-vulkan/vulkan-shaders/get_rows.comp | 28 + + .../vulkan-shaders/get_rows_quant.comp | 39 + + .../vulkan-shaders/group_norm.comp | 66 + + .../ggml-vulkan/vulkan-shaders/im2col.comp | 87 + + .../vulkan-shaders/leaky_relu.comp | 22 + + .../src/ggml-vulkan/vulkan-shaders/mul.comp | 27 + + .../mul_mat_split_k_reduce.comp | 48 + + .../vulkan-shaders/mul_mat_vec.comp | 152 + + .../vulkan-shaders/mul_mat_vec_base.comp | 118 + + .../vulkan-shaders/mul_mat_vec_nc.comp | 71 + + .../vulkan-shaders/mul_mat_vec_p021.comp | 73 + + .../vulkan-shaders/mul_mat_vec_q2_k.comp | 115 + + .../vulkan-shaders/mul_mat_vec_q3_k.comp | 103 + + .../vulkan-shaders/mul_mat_vec_q4_k.comp | 133 + + .../vulkan-shaders/mul_mat_vec_q5_k.comp | 162 + + .../vulkan-shaders/mul_mat_vec_q6_k.comp | 112 + + .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 631 ++ + .../vulkan-shaders/mul_mm_cm2.comp | 328 + + .../src/ggml-vulkan/vulkan-shaders/norm.comp | 44 + + .../src/ggml-vulkan/vulkan-shaders/pad.comp | 28 + + .../ggml-vulkan/vulkan-shaders/pool2d.comp | 74 + + .../src/ggml-vulkan/vulkan-shaders/relu.comp | 21 + + .../ggml-vulkan/vulkan-shaders/repeat.comp | 26 + + .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 42 + + .../ggml-vulkan/vulkan-shaders/rope_head.comp | 49 + + .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 37 + + .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 37 + + .../src/ggml-vulkan/vulkan-shaders/scale.comp | 24 + + .../src/ggml-vulkan/vulkan-shaders/silu.comp | 22 + + .../src/ggml-vulkan/vulkan-shaders/sin.comp | 17 + + .../ggml-vulkan/vulkan-shaders/soft_max.comp | 174 + + .../ggml-vulkan/vulkan-shaders/square.comp | 17 + + .../ggml-vulkan/vulkan-shaders/sum_rows.comp | 37 + + .../src/ggml-vulkan/vulkan-shaders/tanh.comp | 20 + + .../vulkan-shaders/test_coopmat2_support.comp | 7 + + 
.../vulkan-shaders/timestep_embedding.comp | 41 + + .../src/ggml-vulkan/vulkan-shaders/types.comp | 323 + + .../ggml-vulkan/vulkan-shaders/upscale.comp | 36 + + .../vulkan-shaders/vulkan-shaders-gen.cpp | 594 ++ + .../src/ggml-vulkan/vulkan-shaders/wkv6.comp | 87 + + 76 files changed, 14642 insertions(+), 4 deletions(-) + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp + create mode 100644 
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp + create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp + +diff --git a/CMakePresets.json b/CMakePresets.json +index 3ecb0a8f..a77f15ba 100644 +--- a/CMakePresets.json ++++ b/CMakePresets.json +@@ -58,7 +58,11 @@ + "cacheVariables": { + "AMDGPU_TARGETS": 
"gfx803;gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" + } +- } ++ }, ++ { ++ "name": "Vulkan", ++ "inherits": [ "Default" ] ++ } + ], + "buildPresets": [ + { +@@ -105,6 +109,11 @@ + "name": "ROCm 6", + "inherits": [ "ROCm" ], + "configurePreset": "ROCm 6" +- } ++ }, ++ { ++ "name": "Vulkan", ++ "targets": [ "ggml-vulkan" ], ++ "configurePreset": "Vulkan" ++ } + ] + } +diff --git a/discover/gpu.go b/discover/gpu.go +index ec96f5d4..8079be99 100644 +--- a/discover/gpu.go ++++ b/discover/gpu.go +@@ -197,7 +197,10 @@ func initVulkanHandles() *vulkanHandles { + libcapPaths := FindLibCapLibs() + + if len(vulkanPaths) > 0 && len(libcapPaths) > 0 { ++ slog.Info("vulkan: load libvulkan and libcap ok") + vHandles.deviceCount, vHandles.vulkan, vulkanLibPath, libcapLibPath = LoadVulkanMgmt(vulkanPaths, libcapPaths) ++ } else { ++ slog.Info("vulkan: failed to load libvulkan or libcap") + } + + return vHandles +@@ -426,7 +429,7 @@ func GetGPUInfo() GpuInfoList { + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) + gpuInfo.MinimumMemory = 0 +- gpuInfo.DependencyPath = depPaths ++ gpuInfo.DependencyPath = []string{LibOllamaPath} + gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) + gpuInfo.DriverMajor = int(memInfo.major) + gpuInfo.DriverMinor = int(memInfo.minor) +@@ -768,7 +771,7 @@ func LoadVulkanMgmt(vulkanLibPaths []string, capLibPaths []string) (int, *C.vk_h + + C.vk_init(vkLib, capLib, &resp) + if resp.err != nil { +- slog.Debug("Unable to load vulkan", "library", vkLibPath, capLibPath, "error", C.GoString(resp.err)) ++ slog.Error("Unable to load vulkan", "library", vkLibPath, capLibPath, "error", C.GoString(resp.err)) + C.free(unsafe.Pointer(resp.err)) + } else { + return int(resp.num_devices), &resp.ch, vkLibPath, capLibPath +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt +new file mode 100644 +index 00000000..9501de73 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt +@@ -0,0 +1,92 @@ ++find_package(Vulkan COMPONENTS glslc REQUIRED) ++ ++if (Vulkan_FOUND) ++ message(STATUS "Vulkan found") ++ ++ ggml_add_backend_library(ggml-vulkan ++ ggml-vulkan.cpp ++ ../../include/ggml-vulkan.h ++ ) ++ ++ # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. ++ # If it's not, there will be an error to stderr. 
++ # If it's supported, set a define to indicate that we should compile those shaders ++ execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" ++ OUTPUT_VARIABLE glslc_output ++ ERROR_VARIABLE glslc_error) ++ ++ if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") ++ message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") ++ else() ++ message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") ++ add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) ++ endif() ++ ++ target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) ++ target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) ++ ++ # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build ++ # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector ++ if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") ++ add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) ++ endif() ++ ++ if (GGML_VULKAN_CHECK_RESULTS) ++ add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) ++ endif() ++ ++ if (GGML_VULKAN_DEBUG) ++ add_compile_definitions(GGML_VULKAN_DEBUG) ++ endif() ++ ++ if (GGML_VULKAN_MEMORY_DEBUG) ++ add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG) ++ endif() ++ ++ if (GGML_VULKAN_SHADER_DEBUG_INFO) ++ add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) ++ endif() ++ ++ if (GGML_VULKAN_PERF) ++ add_compile_definitions(GGML_VULKAN_PERF) ++ endif() ++ ++ if (GGML_VULKAN_VALIDATE) ++ add_compile_definitions(GGML_VULKAN_VALIDATE) ++ endif() ++ ++ if (GGML_VULKAN_RUN_TESTS) ++ add_compile_definitions(GGML_VULKAN_RUN_TESTS) ++ endif() ++ ++ add_subdirectory(vulkan-shaders) ++ ++ set (_ggml_vk_genshaders_cmd vulkan-shaders-gen) ++ set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) ++ set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) ++ set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) ++ set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) ++ ++ file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") ++ ++ add_custom_command( ++ OUTPUT ${_ggml_vk_header} ++ ${_ggml_vk_source} ++ ++ COMMAND "$/${_ggml_vk_genshaders_cmd}" ++ --glslc ${Vulkan_GLSLC_EXECUTABLE} ++ --input-dir ${_ggml_vk_input_dir} ++ --output-dir ${_ggml_vk_output_dir} ++ --target-hpp ${_ggml_vk_header} ++ --target-cpp ${_ggml_vk_source} ++ --no-clean ++ ++ DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd} ++ COMMENT "Generate vulkan shaders" ++ ) ++ ++ target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header}) ++ ++else() ++ message(WARNING "Vulkan not found") ++endif() +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +new file mode 100644 +index 00000000..d75cd6d6 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -0,0 +1,8745 @@ ++#include "ggml-vulkan.h" ++#include ++#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS) ++#include ++#include "ggml-cpu.h" ++#endif ++ ++#include ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++#include "ggml-impl.h" ++#include "ggml-backend-impl.h" ++ ++#include 
"ggml-vulkan-shaders.hpp" ++ ++#define VK_API_VERSION VK_API_VERSION_1_2 ++ ++#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) ++ ++#define VK_VENDOR_ID_AMD 0x1002 ++#define VK_VENDOR_ID_APPLE 0x106b ++#define VK_VENDOR_ID_INTEL 0x8086 ++#define VK_VENDOR_ID_NVIDIA 0x10de ++ ++#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32 ++ ++#define GGML_VK_MAX_NODES 8192 ++ ++#define MAX_VK_BUFFERS 256 ++ ++#define VK_CHECK(err, msg) \ ++ do { \ ++ vk::Result err_ = (err); \ ++ if (err_ != vk::Result::eSuccess) { \ ++ fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \ ++ #err, to_string(err_).c_str(), __FILE__, __LINE__); \ ++ exit(1); \ ++ } \ ++ } while (0) ++ ++#ifdef GGML_VULKAN_DEBUG ++#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl ++#else ++#define VK_LOG_DEBUG(msg) ((void) 0) ++#endif // GGML_VULKAN_DEBUG ++ ++struct ggml_backend_vk_context; ++ ++struct vk_queue { ++ uint32_t queue_family_index; ++ vk::Queue queue; ++ vk::CommandPool pool; ++ uint32_t cmd_buffer_idx; ++ std::vector cmd_buffers; ++ ++ vk::PipelineStageFlags stage_flags; ++ ++ bool transfer_only; ++}; ++ ++struct vk_pipeline_struct { ++ std::string name; ++ vk::ShaderModule shader_module; ++ vk::DescriptorSetLayout dsl; ++ std::vector descriptor_pools; ++ std::vector descriptor_sets; ++ uint32_t descriptor_set_idx; ++ vk::PipelineLayout layout; ++ vk::Pipeline pipeline; ++ uint32_t push_constant_size; ++ uint32_t parameter_count; ++ std::array wg_denoms; ++ uint32_t align; ++}; ++ ++typedef std::shared_ptr vk_pipeline; ++typedef std::weak_ptr vk_pipeline_ref; ++ ++static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); ++ ++struct vk_matmul_pipeline_struct { ++ vk_pipeline l, m, s; ++ vk_pipeline a_l, a_m, a_s; ++}; ++ ++typedef std::shared_ptr vk_matmul_pipeline; ++ ++struct vk_matmul_pipeline2 { ++ vk_matmul_pipeline2() { ++ f16acc = std::make_shared(); ++ f32acc = std::make_shared(); ++ } ++ vk_matmul_pipeline f32acc; ++ vk_matmul_pipeline f16acc; ++}; ++ ++struct vk_device_struct; ++typedef std::shared_ptr vk_device; ++typedef std::weak_ptr vk_device_ref; ++ ++struct vk_buffer_struct; ++typedef std::shared_ptr vk_buffer; ++typedef std::weak_ptr vk_buffer_ref; ++ ++struct ggml_backend_vk_buffer_type_context { ++ std::string name; ++ vk_device device; ++}; ++ ++static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); ++static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); ++static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); ++static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); ++static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); ++static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { ++ /* .get_name = */ ggml_backend_vk_buffer_type_name, ++ /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer, ++ /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment, ++ /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size, ++ /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size, ++ /* .is_host = */ NULL, ++}; ++ ++#ifdef GGML_VULKAN_MEMORY_DEBUG ++class vk_memory_logger; ++#endif ++#ifdef GGML_VULKAN_PERF ++class vk_perf_logger; ++#endif ++static void ggml_vk_destroy_buffer(vk_buffer& buf); ++ ++static constexpr uint32_t mul_mat_vec_max_cols = 8; ++ ++struct vk_device_struct { ++ std::mutex mutex; ++ ++ 
vk::PhysicalDevice physical_device; ++ vk::PhysicalDeviceProperties properties; ++ std::string name; ++ uint64_t max_memory_allocation_size; ++ bool fp16; ++ bool pipeline_robustness; ++ vk::Device device; ++ uint32_t vendor_id; ++ vk_queue compute_queue; ++ vk_queue transfer_queue; ++ bool single_queue; ++ uint32_t subgroup_size; ++ uint32_t shader_core_count; ++ bool uma; ++ bool float_controls_rte_fp16; ++ ++ bool subgroup_size_control; ++ uint32_t subgroup_min_size; ++ uint32_t subgroup_max_size; ++ bool subgroup_require_full_support; ++ ++ bool coopmat_support; ++ bool coopmat_acc_f32_support; ++ bool coopmat_acc_f16_support; ++ uint32_t coopmat_m; ++ uint32_t coopmat_n; ++ uint32_t coopmat_k; ++ bool coopmat2; ++ ++ size_t idx; ++ ++ bool mul_mat_l; ++ bool mul_mat_m; ++ bool mul_mat_s; ++ bool mul_mat_id_l; ++ bool mul_mat_id_m; ++ bool mul_mat_id_s; ++ ++ vk_matmul_pipeline pipeline_matmul_f32; ++ vk_matmul_pipeline pipeline_matmul_f32_f16; ++ vk_matmul_pipeline2 pipeline_matmul_f16; ++ vk_matmul_pipeline2 pipeline_matmul_f16_f32; ++ vk_pipeline pipeline_matmul_split_k_reduce; ++ ++ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; ++ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; ++ ++ vk_matmul_pipeline pipeline_matmul_id_f32; ++ vk_matmul_pipeline2 pipeline_matmul_id_f16; ++ vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; ++ ++ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; ++ ++ vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; ++ vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; ++ vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; ++ vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; ++ ++ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; ++ vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; ++ vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; ++ vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; ++ vk_pipeline pipeline_acc_f32; ++ vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; ++ vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; ++ vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; ++ vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; ++ vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; ++ vk_pipeline pipeline_upscale_f32; ++ vk_pipeline pipeline_scale_f32; ++ vk_pipeline pipeline_sqr_f32; ++ vk_pipeline pipeline_sin_f32; ++ vk_pipeline pipeline_cos_f32; ++ vk_pipeline pipeline_clamp_f32; ++ vk_pipeline pipeline_pad_f32; ++ vk_pipeline pipeline_repeat_f32; ++ vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; ++ vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; ++ vk_pipeline pipeline_norm_f32; ++ vk_pipeline pipeline_group_norm_f32; ++ vk_pipeline pipeline_rms_norm_f32; ++ vk_pipeline pipeline_gelu_f32; ++ vk_pipeline pipeline_gelu_quick_f32; ++ vk_pipeline pipeline_silu_f32; ++ vk_pipeline pipeline_relu_f32; ++ vk_pipeline pipeline_leaky_relu_f32; ++ vk_pipeline pipeline_tanh_f32; ++ vk_pipeline pipeline_diag_mask_inf_f32; ++ vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; ++ vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; ++ vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; ++ vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; ++ vk_pipeline pipeline_argsort_f32; ++ vk_pipeline pipeline_sum_rows_f32; ++ vk_pipeline 
pipeline_im2col_f32, pipeline_im2col_f32_f16; ++ vk_pipeline pipeline_timestep_embedding_f32; ++ vk_pipeline pipeline_pool2d_f32; ++ vk_pipeline pipeline_rwkv_wkv6_f32; ++ ++ // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} ++ vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; ++ vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; ++ vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; ++ vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; ++ vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; ++ vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; ++ ++ std::unordered_map pipelines; ++ std::unordered_map pipeline_descriptor_set_requirements; ++ ++ std::vector> pinned_memory; ++ ++ vk::Fence fence; ++ vk_buffer sync_staging; ++ ++ ggml_backend_buffer_type buffer_type; ++ ++#ifdef GGML_VULKAN_MEMORY_DEBUG ++ std::unique_ptr memory_logger; ++#endif ++#ifdef GGML_VULKAN_PERF ++ std::unique_ptr perf_logger; ++#endif ++ ++ ~vk_device_struct() { ++ VK_LOG_DEBUG("destroy device " << name); ++ ++ device.destroyFence(fence); ++ ++ ggml_vk_destroy_buffer(sync_staging); ++ ++ device.destroyCommandPool(compute_queue.pool); ++ if (!single_queue) { ++ device.destroyCommandPool(transfer_queue.pool); ++ } ++ ++ for (auto& pipeline : pipelines) { ++ if (pipeline.second.expired()) { ++ continue; ++ } ++ ++ vk_pipeline pl = pipeline.second.lock(); ++ ggml_vk_destroy_pipeline(device, pl); ++ } ++ pipelines.clear(); ++ ++ device.destroy(); ++ } ++}; ++ ++struct vk_buffer_struct { ++ vk::Buffer buffer = VK_NULL_HANDLE; ++ vk::DeviceMemory device_memory = VK_NULL_HANDLE; ++ vk::MemoryPropertyFlags memory_property_flags; ++ void * ptr; ++ size_t size = 0; ++ ++ vk_device device; ++ ++ ~vk_buffer_struct() { ++ if (size == 0) { ++ return; ++ } ++ VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")"); ++ ++ device->device.freeMemory(device_memory); ++ device->device.destroyBuffer(buffer); ++ } ++}; ++ ++struct vk_subbuffer { ++ vk_buffer buffer; ++ uint64_t offset; ++ uint64_t size; ++ ++ operator vk::DescriptorBufferInfo() const { ++ return { buffer->buffer, offset, size }; ++ } ++}; ++ ++struct vk_semaphore { ++ vk::Semaphore s; ++ uint64_t value; ++}; ++ ++struct vk_submission { ++ vk::CommandBuffer buffer; ++ std::vector wait_semaphores; ++ std::vector signal_semaphores; ++}; ++ ++typedef std::vector vk_sequence; ++ ++struct vk_mat_mat_push_constants { ++ uint32_t M; uint32_t N; uint32_t K; ++ uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; ++ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; ++ uint32_t k_split; ++ uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; ++}; ++struct vk_mat_vec_push_constants { ++ uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; ++ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; ++ uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; ++}; ++ ++struct vk_mat_mat_id_push_constants { ++ uint32_t M; uint32_t N; uint32_t K; ++ uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; ++ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; ++ uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; ++}; ++struct vk_mat_vec_id_push_constants { ++ uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; ++ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t 
batch_stride_d; ++ uint32_t nei0; uint32_t ne11; ++}; ++ ++struct vk_flash_attn_push_constants { ++ uint32_t N; ++ uint32_t KV; ++ ++ uint32_t ne1; ++ uint32_t ne2; ++ uint32_t ne3; ++ ++ uint32_t neq2; ++ uint32_t neq3; ++ uint32_t nek2; ++ uint32_t nek3; ++ uint32_t nev2; ++ uint32_t nev3; ++ uint32_t nem1; ++ ++ uint32_t nb02; ++ uint32_t nb03; ++ uint32_t nb12; ++ uint32_t nb13; ++ uint32_t nb22; ++ uint32_t nb23; ++ uint32_t nb31; ++ ++ float scale; ++ float max_bias; ++ float logit_softcap; ++ ++ uint32_t mask; ++ uint32_t n_head_log2; ++ float m0; ++ float m1; ++}; ++ ++struct vk_op_push_constants { ++ uint32_t KX; ++ uint32_t KY; ++ float param1; ++ float param2; ++}; ++ ++struct vk_op_unary_push_constants { ++ uint32_t ne; ++ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; ++ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; ++ uint32_t misalign_offsets; ++ float param1; float param2; ++ uint32_t ne0_012mp; uint32_t ne0_012L; ++ uint32_t ne0_01mp; uint32_t ne0_01L; ++ uint32_t ne0_0mp; uint32_t ne0_0L; ++ uint32_t ne1_012mp; uint32_t ne1_012L; ++ uint32_t ne1_01mp; uint32_t ne1_01L; ++ uint32_t ne1_0mp; uint32_t ne1_0L; ++}; ++static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); ++ ++// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. ++// Precompute mp (m' in the paper) and L such that division ++// can be computed using a multiply (high 32b of 64b result) ++// and a shift: ++// ++// n/d = (mulhi(n, mp) + n) >> L; ++static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) ++{ ++ // compute L = ceil(log2(d)); ++ L = 0; ++ while (L < 32 && (uint32_t{1} << L) < d) { ++ L++; ++ } ++ ++ mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1); ++} ++ ++template void init_pushconst_fastdiv(T &p) { ++ GGML_UNUSED(p); ++ static_assert(!std::is_const::value, "unexpected type"); ++} ++ ++template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) { ++ // Compute magic values to divide by these six numbers. 
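
A worked instance of the divide-by-multiply scheme implemented by init_fastdiv_values above, using d = 6 as a hypothetical divisor: L = ceil(log2(6)) = 3 and mp = floor(2^32 * (2^3 - 6) / 6) + 1 = 1431655766. For n = 100, mulhi(100, mp) = (100 * 1431655766) >> 32 = 33, and (33 + 100) >> 3 = 16, which matches 100 / 6. The same check on the host, with the intermediate sum widened to 64 bits:

#include <cstdint>
#include <cstdio>

// Host-side mirror of n/d = (mulhi(n, mp) + n) >> L as described above.
// The sum is widened to 64 bits here so the sketch itself cannot overflow.
static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
    uint32_t hi = (uint32_t)(((uint64_t)n * mp) >> 32); // mulhi(n, mp)
    return (uint32_t)(((uint64_t)hi + n) >> L);
}

int main() {
    // init_fastdiv_values(6, mp, L) yields mp = 1431655766, L = 3.
    const uint32_t mp = 1431655766u, L = 3;
    for (uint32_t n : {0u, 35u, 100u, 4096u}) {
        printf("%u / 6 = %u (fastdiv: %u)\n", n, n / 6u, fastdiv(n, mp, L));
    }
    return 0;
}
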
++ init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L); ++ init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L); ++ init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L); ++ init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L); ++ init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L); ++ init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L); ++} ++ ++struct vk_op_binary_push_constants { ++ uint32_t ne; ++ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; ++ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; ++ uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23; ++ uint32_t misalign_offsets; ++ float param1; float param2; int32_t param3; ++}; ++ ++struct vk_op_diag_mask_push_constants { ++ uint32_t ncols; ++ uint32_t rows_per_channel; ++ int32_t n_past; ++}; ++ ++struct vk_op_rope_push_constants { ++ uint32_t ncols; ++ uint32_t n_dims; ++ float freq_scale; ++ uint32_t p_delta_rows; ++ float freq_base; ++ float ext_factor; ++ float attn_factor; ++ float corr_dims[2]; ++ float theta_scale; ++ uint32_t has_ff; ++}; ++ ++struct vk_op_soft_max_push_constants { ++ uint32_t KX; ++ uint32_t KY; ++ float scale; ++ float max_bias; ++ float m0; ++ float m1; ++ uint32_t n_head_log2; ++ uint32_t nrows_x; ++}; ++ ++struct vk_op_argsort_push_constants { ++ uint32_t ncols; ++ uint32_t ncols_pad; ++ int32_t order; ++}; ++ ++struct vk_op_im2col_push_constants { ++ uint32_t batch_offset; uint32_t offset_delta; ++ uint32_t IC; ++ uint32_t IW; uint32_t IH; ++ uint32_t OW; uint32_t OH; ++ uint32_t KW; uint32_t KH; ++ uint32_t pelements; ++ uint32_t CHW; ++ int32_t s0; int32_t s1; ++ int32_t p0; int32_t p1; ++ int32_t d0; int32_t d1; ++}; ++ ++struct vk_op_timestep_embedding_push_constants { ++ uint32_t nb1; ++ uint32_t dim; ++ uint32_t max_period; ++}; ++ ++struct vk_op_pool2d_push_constants { ++ uint32_t IW; uint32_t IH; ++ uint32_t OW; uint32_t OH; ++ uint32_t OC; ++ uint32_t pelements; ++ uint32_t op; ++ int32_t k0; int32_t k1; ++ int32_t s0; int32_t s1; ++ int32_t p0; int32_t p1; ++}; ++ ++struct vk_op_rwkv_wkv6_push_constants { ++ uint32_t B; ++ uint32_t T; ++ uint32_t C; ++ uint32_t H; ++}; ++ ++// Allow pre-recording command buffers ++struct vk_staging_memcpy { ++ vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} ++ ++ void * dst; ++ const void * src; ++ size_t n; ++}; ++ ++struct vk_op_upscale_push_constants { ++ uint32_t ne; uint32_t a_offset; uint32_t d_offset; ++ uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; ++ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; ++ float sf0; float sf1; float sf2; float sf3; ++}; ++ ++struct vk_context_struct { ++ vk_submission * s; ++ std::vector seqs; ++ ++ int exit_tensor_idx; ++ ++ std::vector in_memcpys; ++ std::vector out_memcpys; ++ ++ vk_queue * q; ++}; ++typedef std::shared_ptr vk_context; ++typedef std::weak_ptr vk_context_ref; ++ ++struct ggml_vk_garbage_collector { ++ std::vector tl_semaphores; ++ std::vector semaphores; ++ std::vector events; ++ std::vector temp_buffers; ++ std::vector contexts; ++}; ++ ++#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG) ++#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl ++ ++static std::string format_size(size_t size) { ++ const size_t kib = 1024; ++ const size_t mib = kib * 
1024; ++ const size_t gib = mib * 1024; ++ ++ std::ostringstream oss; ++ oss << std::fixed << std::setprecision(2); ++ ++ if (size >= gib) { ++ oss << static_cast(size) / gib << " GiB"; ++ } else if (size >= mib) { ++ oss << static_cast(size) / mib << " MiB"; ++ } else if (size >= kib) { ++ oss << static_cast(size) / kib << " KiB"; ++ } else { ++ oss << size << " B"; ++ } ++ ++ return oss.str(); ++} ++ ++static std::mutex log_mutex; ++ ++class vk_memory_logger { ++public: ++ vk_memory_logger(): total_device(0), total_host(0) {} ++ void log_allocation(vk_buffer_ref buf_ref, size_t size); ++ void log_deallocation(vk_buffer_ref buf_ref); ++ ++private: ++ std::map allocations; // Track allocations ++ size_t total_device; ++ size_t total_host; ++}; ++#else ++#define VK_LOG_MEMORY(msg) ((void) 0) ++#endif // GGML_VULKAN_MEMORY_DEBUG ++ ++#if defined(GGML_VULKAN_PERF) ++ ++class vk_perf_logger { ++public: ++ void print_timings() { ++ std::cerr << "----------------\nVulkan Timings:" << std::endl; ++ for (const auto& t : timings) { ++ uint64_t total = 0; ++ for (const auto& time : t.second) { ++ total += time; ++ } ++ std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl; ++ } ++ ++ timings.clear(); ++ } ++ ++ void log_timing(const ggml_tensor * node, uint64_t time) { ++ if (node->op == GGML_OP_UNARY) { ++ timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); ++ return; ++ } ++ if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { ++ const uint64_t m = node->src[0]->ne[1]; ++ const uint64_t n = node->src[1]->ne[1]; ++ const uint64_t k = node->src[1]->ne[0]; ++ std::string name = ggml_op_name(node->op); ++ if (n == 1) { ++ name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k); ++ } else { ++ name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); ++ } ++ timings[name].push_back(time); ++ return; ++ } ++ timings[ggml_op_name(node->op)].push_back(time); ++ } ++private: ++ std::map> timings; ++}; ++#endif // GGML_VULKAN_PERF ++ ++struct ggml_backend_vk_context { ++ std::string name; ++ ++ vk_device device; ++ ++ size_t semaphore_idx, event_idx; ++ ggml_vk_garbage_collector gc; ++ size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; ++ vk_buffer prealloc_x, prealloc_y, prealloc_split_k; ++ vk::Fence fence; ++ ++ vk_buffer buffer_pool[MAX_VK_BUFFERS]; ++ ++ vk_context_ref compute_ctx; ++ vk_context_ref transfer_ctx; ++ ++ std::vector tensor_ctxs; ++}; ++ ++static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT ++ ++static uint64_t vk_tensor_offset(const ggml_tensor * tensor) { ++ if (tensor->view_src) { ++ return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base; ++ } ++ return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; ++} ++ ++struct ggml_backend_vk_buffer_context { ++ vk_device_ref device; ++ vk_buffer dev_buffer; ++ std::string name; ++ ++ ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) : ++ device(device), ++ dev_buffer(dev_buffer), ++ name(name) { ++ } ++ ++ ~ggml_backend_vk_buffer_context() { ++ ggml_vk_destroy_buffer(dev_buffer); ++ } ++}; ++ ++#ifdef GGML_VULKAN_MEMORY_DEBUG ++void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { ++ std::lock_guard guard(log_mutex); ++ vk_buffer buf = buf_ref.lock(); ++ const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); ++ const std::string type = device ? 
"device" : "host"; ++ allocations[buf->buffer] = size; ++ total_device += device ? size : 0; ++ total_host += device ? 0 : size; ++ VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); ++} ++ ++void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { ++ if (buf_ref.expired() || buf_ref.lock()->size == 0) { ++ return; ++ } ++ ++ std::lock_guard guard(log_mutex); ++ vk_buffer buf = buf_ref.lock(); ++ const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); ++ std::string type = device ? "device" : "host"; ++ auto it = allocations.find(buf->buffer); ++ total_device -= device ? it->second : 0; ++ total_host -= device ? 0 : it->second; ++ if (it != allocations.end()) { ++ VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); ++ allocations.erase(it); ++ } else { ++ VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer); ++ } ++} ++#endif // GGML_VULKAN_MEMORY_DEBUG ++ ++struct vk_instance_t { ++ vk::Instance instance; ++ ++ std::vector device_indices; ++ vk_device devices[GGML_VK_MAX_DEVICES]; ++}; ++ ++static bool vk_instance_initialized = false; ++static vk_instance_t vk_instance; ++ ++#ifdef GGML_VULKAN_CHECK_RESULTS ++static size_t vk_skip_checks; ++static size_t vk_output_tensor; ++ ++static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); ++static void ggml_vk_check_results_0(ggml_tensor * tensor); ++static void ggml_vk_check_results_1(ggml_tensor * tensor); ++#endif ++ ++typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); ++ ++static void ggml_backend_vk_free(ggml_backend_t backend); ++ ++// variables to track number of compiles in progress ++static uint32_t compile_count = 0; ++static std::mutex compile_count_mutex; ++static std::condition_variable compile_count_cond; ++ ++static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, ++ uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector specialization_constants, ++ uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { ++ VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ++ ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ++ ", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); ++ GGML_ASSERT(parameter_count > 0); ++ GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT ++ ++ pipeline = std::make_shared(); ++ pipeline->name = name; ++ pipeline->parameter_count = parameter_count; ++ pipeline->push_constant_size = push_constant_size; ++ pipeline->wg_denoms = wg_denoms; ++ pipeline->align = align; ++ ++ vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); ++ pipeline->shader_module = 
device->device.createShaderModule(shader_module_create_info); ++ ++ std::vector dsl_binding; ++ std::vector dsl_binding_flags; ++ for (uint32_t i = 0; i < parameter_count; i++) { ++ dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); ++ dsl_binding_flags.push_back({}); ++ } ++ ++ vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; ++ ++ vk::PushConstantRange pcr( ++ vk::ShaderStageFlagBits::eCompute, ++ 0, ++ pipeline->push_constant_size ++ ); ++ ++ vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( ++ {}, ++ dsl_binding); ++ descriptor_set_layout_create_info.setPNext(&dslbfci); ++ pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); ++ ++ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); ++ vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); ++ pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); ++ ++ pipeline->descriptor_set_idx = 0; ++ ++ vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr); ++ pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info); ++ ++ std::vector specialization_entries(specialization_constants.size()); ++ ++ for (size_t i = 0; i < specialization_constants.size(); i++) { ++ specialization_entries[i].constantID = i; ++ specialization_entries[i].offset = i * sizeof(uint32_t); ++ specialization_entries[i].size = sizeof(uint32_t); ++ } ++ ++ vk::SpecializationInfo specialization_info( ++ specialization_entries.size(), ++ specialization_entries.data(), ++ specialization_constants.size() * sizeof(uint32_t), ++ specialization_constants.data() ++ ); ++ ++ vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{}; ++ ++ if (device->subgroup_require_full_support && require_full_subgroups) { ++ pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT; ++ } ++ ++ vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( ++ pipeline_shader_stage_create_flags, ++ vk::ShaderStageFlagBits::eCompute, ++ pipeline->shader_module, ++ entrypoint.c_str(), ++ &specialization_info); ++ ++ vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info; ++ pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size; ++ if (device->subgroup_size_control && required_subgroup_size > 0) { ++ GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size); ++ pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info); ++ } ++ ++ vk::ComputePipelineCreateInfo compute_pipeline_create_info( ++ vk::PipelineCreateFlags{}, ++ pipeline_shader_create_info, ++ pipeline->layout); ++ ++ vk::PipelineRobustnessCreateInfoEXT rci; ++ ++ if (device->pipeline_robustness && disable_robustness) { ++ rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; ++ rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; ++ compute_pipeline_create_info.setPNext(&rci); ++ } ++ ++ pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; ++ ++ { ++ std::lock_guard 
guard(device->mutex); ++ device->pipelines.insert({ pipeline->name, pipeline }); ++ } ++ ++ { ++ std::lock_guard guard(compile_count_mutex); ++ assert(compile_count > 0); ++ compile_count--; ++ ++ // "Progress bar" for shader compiles ++ static uint32_t total_compile_count = 0; ++ if ((total_compile_count++ % 10) == 0) { ++ std::cerr << "."; ++ } ++ } ++ compile_count_cond.notify_all(); ++} ++ ++static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { ++ VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")"); ++ for (auto& pool : pipeline->descriptor_pools) { ++ device.destroyDescriptorPool(pool); ++ } ++ pipeline->descriptor_pools.clear(); ++ pipeline->descriptor_sets.clear(); ++ pipeline->descriptor_set_idx = 0; ++ ++ device.destroyDescriptorSetLayout(pipeline->dsl); ++ ++ device.destroyPipelineLayout(pipeline->layout); ++ ++ device.destroyShaderModule(pipeline->shader_module); ++ ++ device.destroyPipeline(pipeline->pipeline); ++} ++ ++static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { ++ VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); ++ device->pipeline_descriptor_set_requirements[pipeline->name] += n; ++} ++ ++static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { ++ std::lock_guard guard(device->mutex); ++ ++ for (auto& pair : device->pipeline_descriptor_set_requirements) { ++ vk_pipeline pipeline = device->pipelines.at(pair.first).lock(); ++ const uint64_t n = pair.second; ++ ++ VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")"); ++ ++ if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) { ++ // Enough descriptors are available ++ continue; ++ } ++ ++ uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size(); ++ uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; ++ uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; ++ ++ while (to_alloc > 0) { ++ const uint32_t alloc_count = std::min(pool_remaining, to_alloc); ++ to_alloc -= alloc_count; ++ pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; ++ ++ if (pool_idx >= pipeline->descriptor_pools.size()) { ++ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); ++ vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); ++ pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); ++ } ++ ++ std::vector layouts(alloc_count); ++ for (uint32_t i = 0; i < alloc_count; i++) { ++ layouts[i] = pipeline->dsl; ++ } ++ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data()); ++ std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); ++ pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end()); ++ ++ pool_idx++; ++ } ++ } ++} ++ ++static void ggml_pipeline_cleanup(vk_pipeline& pipeline) { ++ VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")"); ++ pipeline->descriptor_set_idx = 0; ++} ++ ++static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) { ++ VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); ++ std::lock_guard guard(device->mutex); ++ ++ 
if (q.cmd_buffers.size() > q.cmd_buffer_idx) { ++ // Reuse command buffer ++ return q.cmd_buffers[q.cmd_buffer_idx++]; ++ } ++ ++ vk::CommandBufferAllocateInfo command_buffer_alloc_info( ++ q.pool, ++ vk::CommandBufferLevel::ePrimary, ++ 1); ++ const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); ++ auto buf = cmd_buffers.front(); ++ ++ q.cmd_buffers.push_back(buf); ++ q.cmd_buffer_idx++; ++ ++ return buf; ++} ++ ++static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector wait_semaphores, std::vector signal_semaphores) { ++ VK_LOG_DEBUG("ggml_vk_create_submission()"); ++ vk_submission s; ++ s.buffer = ggml_vk_create_cmd_buffer(device, q); ++ s.wait_semaphores = std::move(wait_semaphores); ++ s.signal_semaphores = std::move(signal_semaphores); ++ return s; ++} ++ ++static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { ++ if (ctx->seqs.empty()) { ++ if (fence) { ++ ctx->q->queue.submit({}, fence); ++ } ++ return; ++ } ++ VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")"); ++ ++ std::vector> tl_wait_vals; ++ std::vector> tl_signal_vals; ++ std::vector> tl_wait_semaphores; ++ std::vector> tl_signal_semaphores; ++ std::vector tl_submit_infos; ++ std::vector submit_infos; ++ int idx = -1; ++ std::vector> stage_flags; ++ ++ size_t reserve = 0; ++ ++ for (const auto& sequence : ctx->seqs) { ++ reserve += sequence.size(); ++ } ++ ++ // Pre-reserve vectors to prevent reallocation, which invalidates pointers ++ tl_wait_semaphores.reserve(reserve); ++ tl_wait_vals.reserve(reserve); ++ tl_signal_semaphores.reserve(reserve); ++ tl_signal_vals.reserve(reserve); ++ tl_submit_infos.reserve(reserve); ++ submit_infos.reserve(reserve); ++ stage_flags.reserve(reserve); ++ ++ for (const auto& sequence : ctx->seqs) { ++ for (const auto& submission : sequence) { ++ stage_flags.push_back({}); ++ idx++; ++ tl_wait_vals.push_back({}); ++ tl_wait_semaphores.push_back({}); ++ tl_signal_vals.push_back({}); ++ tl_signal_semaphores.push_back({}); ++ for (size_t i = 0; i < submission.wait_semaphores.size(); i++) { ++ stage_flags[idx].push_back(ctx->q->stage_flags); ++ tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value); ++ tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s); ++ } ++ for (size_t i = 0; i < submission.signal_semaphores.size(); i++) { ++ tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value); ++ tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s); ++ } ++ tl_submit_infos.push_back({ ++ (uint32_t) submission.wait_semaphores.size(), ++ tl_wait_vals[idx].data(), ++ (uint32_t) submission.signal_semaphores.size(), ++ tl_signal_vals[idx].data(), ++ }); ++ tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo; ++ tl_submit_infos[idx].pNext = nullptr; ++ vk::SubmitInfo si{ ++ (uint32_t) submission.wait_semaphores.size(), ++ tl_wait_semaphores[idx].data(), ++ stage_flags[idx].data(), ++ 1, ++ &submission.buffer, ++ (uint32_t) submission.signal_semaphores.size(), ++ tl_signal_semaphores[idx].data(), ++ }; ++ si.setPNext(&tl_submit_infos[idx]); ++ submit_infos.push_back(si); ++ } ++ } ++ ++ ctx->q->queue.submit(submit_infos, fence); ++ ++ ctx->seqs.clear(); ++} ++ ++static uint32_t ggml_vk_find_queue_family_index(std::vector& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) { ++ VK_LOG_DEBUG("ggml_vk_find_queue_family_index()"); ++ const uint32_t qfsize = 
queue_family_props.size(); ++ ++ // Try with avoid preferences first ++ for (uint32_t i = 0; i < qfsize; i++) { ++ if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) { ++ return i; ++ } ++ } ++ ++ // Fall back to only required ++ for (size_t i = 0; i < qfsize; i++) { ++ if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) { ++ return i; ++ } ++ } ++ ++ // Fall back to reusing compute queue ++ for (size_t i = 0; i < qfsize; i++) { ++ if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) { ++ return i; ++ } ++ } ++ ++ // Fall back to ignoring min_num_queries ++ for (size_t i = 0; i < qfsize; i++) { ++ if (queue_family_props[i].queueFlags & required) { ++ return i; ++ } ++ } ++ ++ // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations. ++ // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional. ++ if (compute_index >= 0) { ++ return compute_index; ++ } ++ ++ std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl; ++ ++ for(auto &q_family : queue_family_props) { ++ std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl; ++ } ++ abort(); ++} ++ ++static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) { ++ VK_LOG_DEBUG("ggml_vk_create_queue()"); ++ std::lock_guard guard(device->mutex); ++ ++ q.queue_family_index = queue_family_index; ++ q.transfer_only = transfer_only; ++ ++ vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index); ++ q.pool = device->device.createCommandPool(command_pool_create_info_compute); ++ ++ q.cmd_buffer_idx = 0; ++ ++ q.queue = device->device.getQueue(queue_family_index, queue_index); ++ ++ q.stage_flags = stage_flags; ++} ++ ++static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) { ++ vk_context result = std::make_shared(); ++ VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); ++ ctx->gc.contexts.emplace_back(result); ++ result->q = &q; ++ return result; ++} ++ ++static vk_context ggml_vk_create_temporary_context(vk_queue& q) { ++ vk_context result = std::make_shared(); ++ VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); ++ result->q = &q; ++ return result; ++} ++ ++static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) { ++ VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); ++ vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 }; ++ vk::SemaphoreCreateInfo ci{}; ++ ci.setPNext(&tci); ++ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); ++ ctx->gc.semaphores.push_back({ semaphore, 0 }); ++ return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1]; ++} ++ ++static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) { ++ VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); ++ if 
(ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) { ++ vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; ++ vk::SemaphoreCreateInfo ci{}; ++ ci.setPNext(&tci); ++ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); ++ ctx->gc.tl_semaphores.push_back({ semaphore, 0 }); ++ } ++ return &ctx->gc.tl_semaphores[ctx->semaphore_idx++]; ++} ++ ++static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { ++ if (ctx->event_idx >= ctx->gc.events.size()) { ++ ctx->gc.events.push_back(ctx->device->device.createEvent({})); ++ } ++ return ctx->gc.events[ctx->event_idx++]; ++} ++ ++static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) { ++ VK_LOG_DEBUG("ggml_vk_queue_cleanup()"); ++ std::lock_guard guard(device->mutex); ++ ++ // Requires command buffers to be done ++ device->device.resetCommandPool(q.pool); ++ q.cmd_buffer_idx = 0; ++} ++ ++static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { ++ for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { ++ vk::MemoryType memory_type = mem_props->memoryTypes[i]; ++ if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) && ++ (flags & memory_type.propertyFlags) == flags && ++ mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) { ++ return static_cast(i); ++ } ++ } ++ return UINT32_MAX; ++} ++ ++static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { ++ VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")"); ++ if (size > device->max_memory_allocation_size) { ++ throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); ++ } ++ ++ std::lock_guard guard(device->mutex); ++ ++ vk_buffer buf = std::make_shared(); ++ ++ if (size == 0) { ++ buf->size = 0; ++ return buf; ++ } ++ ++ vk::BufferCreateInfo buffer_create_info{ ++ vk::BufferCreateFlags(), ++ size, ++ vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst, ++ vk::SharingMode::eExclusive, ++ 0, ++ nullptr, ++ }; ++ ++ buf->buffer = device->device.createBuffer(buffer_create_info); ++ ++ vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); ++ ++ vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); ++ ++ uint32_t memory_type_index = UINT32_MAX; ++ ++ memory_type_index = find_properties(&mem_props, &mem_req, req_flags); ++ buf->memory_property_flags = req_flags; ++ ++ if (memory_type_index == UINT32_MAX && fallback_flags) { ++ memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); ++ buf->memory_property_flags = fallback_flags; ++ } ++ ++ if (memory_type_index == UINT32_MAX) { ++ device->device.destroyBuffer(buf->buffer); ++ throw vk::OutOfDeviceMemoryError("No suitable memory type found"); ++ } ++ ++ try { ++ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); ++ } catch (const vk::SystemError& e) { ++ if (buf->memory_property_flags != fallback_flags) { ++ // Try again with fallback flags ++ memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); ++ buf->memory_property_flags = fallback_flags; ++ ++ try { ++ buf->device_memory = device->device.allocateMemory({ mem_req.size, 
memory_type_index }); ++ } ++ catch (const vk::SystemError& e) { ++ device->device.destroyBuffer(buf->buffer); ++ throw e; ++ } ++ } else { ++ // Out of Host/Device memory, clean up buffer ++ device->device.destroyBuffer(buf->buffer); ++ throw e; ++ } ++ } ++ buf->ptr = nullptr; ++ ++ if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { ++ buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); ++ } ++ ++ device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); ++ ++ buf->device = device; ++ buf->size = size; ++ ++#ifdef GGML_VULKAN_MEMORY_DEBUG ++ device->memory_logger->log_allocation(buf, size); ++#endif ++ ++ return buf; ++} ++ ++static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { ++ try { ++ return ggml_vk_create_buffer(device, size, req_flags, fallback_flags); ++ } catch (const vk::SystemError& e) { ++ std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl; ++ std::cerr << "ggml_vulkan: " << e.what() << std::endl; ++ throw e; ++ } ++} ++ ++static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { ++ vk_buffer buf; ++ try { ++ if (device->uma) { ++ // Fall back to host memory type ++ buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); ++ } else { ++ // use rebar if available, otherwise fallback to device only visible memory ++ buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ } ++ } catch (const vk::SystemError& e) { ++ std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; ++ std::cerr << "ggml_vulkan: " << e.what() << std::endl; ++ throw e; ++ } ++ ++ return buf; ++} ++ ++static void ggml_vk_destroy_buffer(vk_buffer& buf) { ++ if (buf == nullptr) { ++ return; ++ } ++ ++#ifdef GGML_VULKAN_MEMORY_DEBUG ++ if (buf->device != nullptr) { ++ buf->device->memory_logger->log_deallocation(buf); ++ } ++#endif ++ ++ buf.reset(); ++} ++ ++static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { ++ return { buf, 0, VK_WHOLE_SIZE }; ++} ++ ++static void ggml_vk_sync_buffers(vk_context& ctx) { ++ VK_LOG_DEBUG("ggml_vk_sync_buffers()"); ++ ++ const bool transfer_queue = ctx->q->transfer_only; ++ ++ ctx->s->buffer.pipelineBarrier( ++ ctx->q->stage_flags, ++ ctx->q->stage_flags, ++ {}, ++ { { ++ { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, ++ { !transfer_queue ? 
(vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) } ++ } }, ++ {}, ++ {} ++ ); ++} ++ ++static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events) { ++ VK_LOG_DEBUG("ggml_vk_wait_events()"); ++ if (events.empty()) { ++ return; ++ } ++ ++ ctx->s->buffer.waitEvents( ++ events, ++ ctx->q->stage_flags, ++ ctx->q->stage_flags, ++ {}, ++ {}, ++ {} ++ ); ++} ++ ++// number of rows/cols for flash attention shader ++static constexpr uint32_t flash_attention_num_small_rows = 32; ++static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { ++ GGML_UNUSED(clamp); ++ ++ // small rows, large cols ++ if (small_rows) { ++ return {flash_attention_num_small_rows, 128}; ++ } ++ // small cols to reduce register count ++ if (ggml_is_quantized(type) || D == 256) { ++ return {64, 32}; ++ } ++ return {64, 64}; ++}; ++ ++static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id) { ++ // Needs to be kept up to date on shader changes ++ const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; ++ const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); ++ const uint32_t warps = warptile[0] / warptile[10]; ++ ++ const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; ++ const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; ++ const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; ++ ++ return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize; ++} ++ ++static void ggml_vk_load_shaders(vk_device& device) { ++ VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); ++ ++ std::cerr << "ggml_vulkan: Compiling shaders"; ++ ++ // some shaders have a minimum subgroup size ++ const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); ++ const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); ++ ++ // mulmat ++ std::vector l_warptile, m_warptile, s_warptile, ++ l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, ++ l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, ++ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; ++ std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, ++ l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, ++ l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, ++ l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; ++ ++ uint32_t l_align, m_align, s_align; ++ if (device->coopmat2) { ++ // spec constants and tile sizes for non-quant matmul/matmul_id ++ l_warptile = { 256, 128, 256, 64 }; ++ m_warptile = { 256, 128, 128, 64 }; ++ s_warptile = { 128, 64, 64, 64 }; ++ l_wg_denoms = {128, 256, 1 }; ++ m_wg_denoms = {128, 128, 1 }; ++ s_wg_denoms = { 64, 64, 1 }; ++ ++ // spec constants and tile sizes for quant matmul (non-Qi_K) ++ l_warptile_mmq = { 256, 128, 256, 64 }; ++ m_warptile_mmq = { 256, 128, 128, 64 }; ++ s_warptile_mmq = { 256, 128, 128, 64 }; ++ l_mmq_wg_denoms = { 128, 256, 1 }; ++ m_mmq_wg_denoms = { 128, 128, 1 }; ++ s_mmq_wg_denoms = { 128, 128, 1 }; ++ ++ // spec constants and tile sizes for quant matmul (Qi_K) ++ l_warptile_mmq_k = { 256, 128, 512, 16 }; ++ m_warptile_mmq_k = { 256, 128, 256, 16 }; ++ s_warptile_mmq_k = { 256, 32, 128, 64 }; ++ l_mmq_wg_denoms_k = { 128, 512, 1 }; ++ 
m_mmq_wg_denoms_k = { 128, 256, 1 }; ++ s_mmq_wg_denoms_k = { 32, 128, 1 }; ++ ++ // spec constants and tile sizes for quant matmul_id ++ l_warptile_mmqid = { 256, 128, 128, 16 }; ++ m_warptile_mmqid = { 256, 128, 64, 16 }; ++ s_warptile_mmqid = { 256, 64, 64, 16 }; ++ l_mmqid_wg_denoms = { 128, 128, 1 }; ++ m_mmqid_wg_denoms = { 128, 64, 1 }; ++ s_mmqid_wg_denoms = { 64, 64, 1 }; ++ ++ l_align = 128; ++ m_align = 64; ++ s_align = 32; ++ } else { ++ // Matrix cores require different warp group sizes ++ const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4; ++ const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4; ++ const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2; ++ const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4; ++ const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2; ++ const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2; ++ const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1; ++ const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; ++ const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; ++ ++ l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; ++ m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; ++ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; ++ ++ l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; ++ m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; ++ s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; ++ ++ l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; ++ m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; ++ s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; ++ l_align = 128; ++ m_align = 64; ++ s_align = 32; ++ ++ // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders ++ // and tile sizes, this should handle 16KB, 32KB, and 48KB+. ++ // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders. ++ // But the numbers happen to work out for 32KB shared memory size that when using the medium ++ // size there's enough room for everything, and we assert for this. ++ uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); ++ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { ++ l_warptile = m_warptile; ++ l_wg_denoms = m_wg_denoms; ++ shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); ++ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); ++ } ++ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { ++ // assert mul_mat_mat_id shaders will fit. 
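++ // (the row_ids need 3072 * sizeof(uint32_t) = 12 KiB of shared memory on top of the warptile buffers)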
++ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); ++ } ++ ++ shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); ++ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { ++ if (device->properties.limits.maxComputeSharedMemorySize == 32768) { ++ l_warptile_mmq = m_warptile_mmq; ++ l_mmq_wg_denoms = m_mmq_wg_denoms; ++ } else { ++ l_warptile_mmq = s_warptile_mmq; ++ l_mmq_wg_denoms = s_mmq_wg_denoms; ++ } ++ shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); ++ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); ++ } ++ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { ++ // assert mul_mat_mat_id shaders will fit. ++ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); ++ } ++ // Disable medium and large matrix multiplication if not enough shared memory is available ++ // Check mmq warptiles as the largest configuration ++ // Throw an error if not enough for any matrix multiplication is available ++ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) { ++ std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl; ++ throw std::runtime_error("Shared memory size too small for matrix multiplication."); ++ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) { ++ device->mul_mat_m = false; ++ device->mul_mat_l = false; ++ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) { ++ device->mul_mat_l = false; ++ } ++ ++ // Disable mul_mat_id if not enough shared memory is available ++ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) { ++ device->mul_mat_id_s = false; ++ device->mul_mat_id_m = false; ++ device->mul_mat_id_l = false; ++ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) { ++ device->mul_mat_id_m = false; ++ device->mul_mat_id_l = false; ++ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) { ++ device->mul_mat_id_l = false; ++ } ++ } ++ ++ device->pipeline_matmul_f32 = std::make_shared(); ++ device->pipeline_matmul_f32_f16 = std::make_shared(); ++ ++ device->pipeline_matmul_id_f32 = std::make_shared(); ++ ++ std::vector> compiles; ++ auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, ++ uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, ++ uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { ++ { ++ // wait until fewer than N compiles are in progress ++ uint32_t N = std::max(1u, std::thread::hardware_concurrency()); ++ std::unique_lock guard(compile_count_mutex); ++ while (compile_count >= N) { ++ compile_count_cond.wait(guard); ++ } ++ compile_count++; ++ } ++ compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, ++ parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size)); ++ }; ++ ++#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) ++ if (device->coopmat2) { ++ ++ auto const &fa_wg_denoms = [&](uint32_t D, 
uint32_t clamp, ggml_type type, bool small_rows) -> std::array { ++ return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; ++ }; ++ ++ auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { ++ // For large number of rows, 128 invocations seems to work best. ++ // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we ++ // can't use 256 for D==80. ++ uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128; ++ auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); ++ return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; ++ }; ++ ++#define CREATE_FA2(TYPE, NAMELC, D) \ ++ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ ++ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ ++ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ ++ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ ++ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ ++ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ ++ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ ++ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, 
"main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ ++ ++#define CREATE_FA(TYPE, NAMELC) \ ++ CREATE_FA2(TYPE, NAMELC, 64) \ ++ CREATE_FA2(TYPE, NAMELC, 80) \ ++ CREATE_FA2(TYPE, NAMELC, 96) \ ++ CREATE_FA2(TYPE, NAMELC, 112) \ ++ CREATE_FA2(TYPE, NAMELC, 128) \ ++ CREATE_FA2(TYPE, NAMELC, 256) ++ ++ CREATE_FA(GGML_TYPE_F16, f16) ++ CREATE_FA(GGML_TYPE_Q4_0, q4_0) ++ CREATE_FA(GGML_TYPE_Q4_1, q4_1) ++ CREATE_FA(GGML_TYPE_Q5_0, q5_0) ++ CREATE_FA(GGML_TYPE_Q5_1, q5_1) ++ CREATE_FA(GGML_TYPE_Q8_0, q8_0) ++ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently ++ //CREATE_FA(GGML_TYPE_Q2_K, q2_k) ++ //CREATE_FA(GGML_TYPE_Q3_K, q3_k) ++ //CREATE_FA(GGML_TYPE_Q4_K, q4_k) ++ //CREATE_FA(GGML_TYPE_Q5_K, q5_k) ++ //CREATE_FA(GGML_TYPE_Q6_K, q6_k) ++ CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) ++#undef CREATE_FA ++ ++ // Create 6 variants, {s,m,l}x{unaligned,aligned} ++#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ ++ ++ // Create 2 variants, {f16,f32} accumulator ++#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ ++ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ ++ CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ ++ ++ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) ++ ++ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) ++ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) ++ ++ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++ 
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) ++#undef CREATE_MM ++#undef CREATE_MM2 ++ } else ++#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) ++ if (device->coopmat_support) { ++ // Create 6 variants, {s,m,l}x{unaligned,aligned} ++#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ ++ if (device->mul_mat ## ID ## _l) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ ++ if (device->mul_mat ## ID ## _m) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ ++ if (device->mul_mat ## ID ## _s) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ ++ if (device->mul_mat ## ID ## _l) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ ++ if (device->mul_mat ## ID ## _m) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ ++ if (device->mul_mat ## ID ## _s) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ ++ ++ // Create 2 variants, {f16,f32} accumulator ++#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ ++ if (device->coopmat_acc_f16_support) { \ ++ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ ++ } \ ++ if (device->coopmat_acc_f32_support) { \ ++ CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ ++ } \ ++ ++ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ ++ if (device->coopmat_acc_f16_support) { ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ } else { ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ 
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ } ++ ++ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. ++ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { ++ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); ++ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); ++ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); ++ ++ if (device->coopmat_acc_f16_support) { ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ } else { ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, 
warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ } ++ } ++#undef CREATE_MM2 ++#undef CREATE_MM ++ } else if (device->fp16) { ++ // Create 6 variants, {s,m,l}x{unaligned,aligned} ++#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ ++ if (device->mul_mat ## ID ## _l) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ ++ if (device->mul_mat ## ID ## _m) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ ++ if (device->mul_mat ## ID ## _s) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ ++ if (device->mul_mat ## ID ## _l) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ ++ if (device->mul_mat ## ID ## _m) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ ++ if (device->mul_mat ## ID ## _s) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ ++ ++ // Create 2 variants, {f16,f32} accumulator ++#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ ++ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ ++ CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ ++ ++ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ ++ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. 
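++ // (mul_mat_id_{s,m,l} were cleared earlier by the ggml_vk_matmul_shmem_support() checks when the row_ids don't fit)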
++ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { ++ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); ++ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); ++ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ } ++#undef CREATE_MM2 ++#undef CREATE_MM ++ } else { ++ // Create 6 variants, {s,m,l}x{unaligned,aligned} ++#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ ++ if (device->mul_mat ## ID ## _l) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ ++ if (device->mul_mat ## ID ## _m) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ ++ if (device->mul_mat ## ID ## _s) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ ++ if (device->mul_mat ## ID ## _l) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 
l_align); \ ++ if (device->mul_mat ## ID ## _m) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ ++ if (device->mul_mat ## ID ## _s) \ ++ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ ++ ++ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); ++ ++ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. 
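++ // (same guard as the fp16 path above: the *_id pipelines are only built if shared memory can hold the row_ids)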
++ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { ++ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); ++ CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); ++ CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); ++ } ++#undef CREATE_MM ++ } ++ ++ // mul mat vec ++ ++ // the number of rows computed per shader depends on GPU model and quant ++ uint32_t rm_stdq = 1; ++ uint32_t rm_kq = 2; ++ if (device->vendor_id == VK_VENDOR_ID_AMD) { ++ if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN ++ rm_stdq = 2; ++ rm_kq = 4; ++ } ++ } else if (device->vendor_id == VK_VENDOR_ID_INTEL) ++ rm_stdq = 2; ++ ++ for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, 
{device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); 
++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ++ ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); ++ } ++ ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], 
"mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); ++ ++ // dequant shaders ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ++ ++ // get_rows ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, 
sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, 
group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", 
concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, 
"main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ++ if (device->float_controls_rte_fp16) { ++ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ } else { ++ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ++ } ++ ++ ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); ++ if (device->float_controls_rte_fp16) { ++ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); ++ } else { ++ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); ++ } ++ ++ ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); ++ ++ ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); ++ ++ for (auto &c : compiles) { ++ c.wait(); ++ } ++ std::cerr << "Done!" 
<< std::endl; ++} ++ ++static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); ++ ++static vk_device ggml_vk_get_device(size_t idx) { ++ VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); ++ ++ if (vk_instance.devices[idx] == nullptr) { ++ VK_LOG_DEBUG("Initializing new vk_device"); ++ vk_device device = std::make_shared<vk_device_struct>(); ++ vk_instance.devices[idx] = device; ++ ++#ifdef GGML_VULKAN_MEMORY_DEBUG ++ device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger()); ++#endif ++#ifdef GGML_VULKAN_PERF ++ device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger()); ++#endif ++ ++ size_t dev_num = vk_instance.device_indices[idx]; ++ ++ std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices(); ++ ++ if (dev_num >= physical_devices.size()) { ++ std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl; ++ throw std::runtime_error("Device not found"); ++ } ++ ++ device->physical_device = physical_devices[dev_num]; ++ const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties(); ++ ++ bool fp16_storage = false; ++ bool fp16_compute = false; ++ bool maintenance4_support = false; ++ bool sm_builtins = false; ++ bool amd_shader_core_properties2 = false; ++ bool pipeline_robustness = false; ++ bool coopmat2_support = false; ++ device->coopmat_support = false; ++ ++ // Check if maintenance4 is supported ++ for (const auto& properties : ext_props) { ++ if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { ++ maintenance4_support = true; ++ } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { ++ fp16_storage = true; ++ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { ++ fp16_compute = true; ++ } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { ++ sm_builtins = true; ++ } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) { ++ amd_shader_core_properties2 = true; ++ } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { ++ pipeline_robustness = true; ++ } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { ++ device->subgroup_size_control = true; ++ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && ++ !getenv("GGML_VK_DISABLE_COOPMAT")) { ++ device->coopmat_support = true; ++ device->coopmat_m = 0; ++ device->coopmat_n = 0; ++ device->coopmat_k = 0; ++ } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && ++ !getenv("GGML_VK_DISABLE_COOPMAT2")) { ++ coopmat2_support = true; ++ } ++ } ++ ++ vk::PhysicalDeviceProperties2 props2; ++ vk::PhysicalDeviceMaintenance3Properties props3; ++ vk::PhysicalDeviceMaintenance4Properties props4; ++ vk::PhysicalDeviceSubgroupProperties subgroup_props; ++ vk::PhysicalDeviceDriverProperties driver_props; ++ vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; ++ vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; ++ vk::PhysicalDeviceVulkan12Properties vk12_props; ++ vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; ++ ++ props2.pNext = &props3; ++ props3.pNext = &subgroup_props; ++ subgroup_props.pNext = &driver_props; ++ driver_props.pNext = &vk12_props; ++ ++ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; ++ ++ if (maintenance4_support) { ++ last_struct->pNext = 
(VkBaseOutStructure *)&props4; ++ last_struct = (VkBaseOutStructure *)&props4; ++ } ++ if (sm_builtins) { ++ last_struct->pNext = (VkBaseOutStructure *)&sm_props; ++ last_struct = (VkBaseOutStructure *)&sm_props; ++ } ++ if (amd_shader_core_properties2) { ++ last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; ++ last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; ++ } ++ if (device->subgroup_size_control) { ++ last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props; ++ last_struct = (VkBaseOutStructure *)&subgroup_size_control_props; ++ } ++ ++#if defined(VK_NV_cooperative_matrix2) ++ vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; ++ if (coopmat2_support) { ++ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props; ++ last_struct = (VkBaseOutStructure *)&coopmat2_props; ++ } ++#endif ++ ++ device->physical_device.getProperties2(&props2); ++ device->properties = props2.properties; ++ ++ const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); ++ ++ if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { ++ device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); ++ } else if (maintenance4_support) { ++ device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); ++ } else { ++ device->max_memory_allocation_size = props3.maxMemoryAllocationSize; ++ } ++ ++ device->vendor_id = device->properties.vendorID; ++ device->subgroup_size = subgroup_props.subgroupSize; ++ device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; ++ if (sm_builtins) { ++ device->shader_core_count = sm_props.shaderSMCount; ++ } else if (amd_shader_core_properties2) { ++ device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; ++ } else { ++ device->shader_core_count = 0; ++ } ++ device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; ++ ++ const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; ++ ++ device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; ++ ++ if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) { ++ device->coopmat_support = false; ++ } ++ ++ std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); ++ ++ // Try to find a non-graphics compute queue and transfer-focused queues ++ const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); ++ const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); ++ ++ const float priorities[] = { 1.0f, 1.0f }; ++ device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; ++ ++ std::vector device_queue_create_infos; ++ if (compute_queue_family_index != transfer_queue_family_index) { ++ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); ++ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1}); ++ } else if(!device->single_queue) { ++ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities}); ++ } else { ++ 
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); ++ } ++ vk::DeviceCreateInfo device_create_info; ++ std::vector device_extensions; ++ vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); ++ ++ VkPhysicalDeviceFeatures2 device_features2; ++ device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; ++ device_features2.pNext = nullptr; ++ device_features2.features = (VkPhysicalDeviceFeatures)device_features; ++ ++ VkPhysicalDeviceVulkan11Features vk11_features; ++ vk11_features.pNext = nullptr; ++ vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; ++ device_features2.pNext = &vk11_features; ++ ++ VkPhysicalDeviceVulkan12Features vk12_features; ++ vk12_features.pNext = nullptr; ++ vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; ++ vk11_features.pNext = &vk12_features; ++ ++ last_struct = (VkBaseOutStructure *)&vk12_features; ++ ++ VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; ++ pl_robustness_features.pNext = nullptr; ++ pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; ++ pl_robustness_features.pipelineRobustness = VK_FALSE; ++ ++ if (pipeline_robustness) { ++ last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features; ++ last_struct = (VkBaseOutStructure *)&pl_robustness_features; ++ device_extensions.push_back("VK_EXT_pipeline_robustness"); ++ } ++ ++ VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; ++ subgroup_size_control_features.pNext = nullptr; ++ subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; ++ subgroup_size_control_features.computeFullSubgroups = false; ++ subgroup_size_control_features.subgroupSizeControl = false; ++ ++ if (device->subgroup_size_control) { ++ last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features; ++ last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; ++ } ++ ++ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; ++ coopmat_features.pNext = nullptr; ++ coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; ++ coopmat_features.cooperativeMatrix = VK_FALSE; ++ ++ if (device->coopmat_support) { ++ last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; ++ last_struct = (VkBaseOutStructure *)&coopmat_features; ++ } ++ ++#if defined(VK_NV_cooperative_matrix2) ++ VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; ++ coopmat2_features.pNext = nullptr; ++ coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; ++ if (coopmat2_support) { ++ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; ++ last_struct = (VkBaseOutStructure *)&coopmat2_features; ++ device_extensions.push_back("VK_NV_cooperative_matrix2"); ++ } ++#endif ++ ++ vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); ++ ++ device->fp16 = device->fp16 && vk12_features.shaderFloat16; ++ ++ device->pipeline_robustness = pl_robustness_features.pipelineRobustness; ++ ++ if (device->subgroup_size_control) { ++ device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; ++ device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; ++ } ++ ++ device->subgroup_size_control = device->subgroup_size_control && ++ (subgroup_size_control_props.requiredSubgroupSizeStages & 
vk::ShaderStageFlagBits::eCompute) && ++ subgroup_size_control_features.subgroupSizeControl; ++ ++ if (device->subgroup_size_control) { ++ device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; ++ device_extensions.push_back("VK_EXT_subgroup_size_control"); ++ } ++ ++ device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; ++ ++ if (coopmat2_support) { ++#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) ++ if (coopmat2_features.cooperativeMatrixWorkgroupScope && ++ coopmat2_features.cooperativeMatrixFlexibleDimensions && ++ coopmat2_features.cooperativeMatrixReductions && ++ coopmat2_features.cooperativeMatrixConversions && ++ coopmat2_features.cooperativeMatrixPerElementOperations && ++ coopmat2_features.cooperativeMatrixTensorAddressing && ++ coopmat2_features.cooperativeMatrixBlockLoads && ++ vk12_features.bufferDeviceAddress) { ++ ++ std::vector flexible_dimensions; ++ uint32_t count = 0; ++ ++ PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV ++ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV = ++ (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV) ++ vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV"); ++ ++ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr); ++ ++ VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {}; ++ empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV; ++ flexible_dimensions.resize(count, empty_prop); ++ ++ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data()); ++ ++ bool found_fp16_128 = false, ++ found_fp16_256 = false, ++ found_fp32_128 = false, ++ found_fp32_256 = false; ++ // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 ++ // with 32x16x16 and 256 with 32x32x16. ++ for (auto &prop : flexible_dimensions) { ++ if (prop.saturatingAccumulation == VK_FALSE && ++ prop.scope == VK_SCOPE_WORKGROUP_KHR && ++ prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && ++ prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { ++ ++ if (prop.workgroupInvocations == 128 && ++ prop.MGranularity <= 32 && ++ prop.NGranularity <= 16 && ++ prop.KGranularity <= 16) { ++ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && ++ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { ++ found_fp16_128 = true; ++ } ++ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && ++ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { ++ found_fp32_128 = true; ++ } ++ } ++ if (prop.workgroupInvocations == 256 && ++ prop.MGranularity <= 32 && ++ prop.NGranularity <= 32 && ++ prop.KGranularity <= 16) { ++ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && ++ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { ++ found_fp16_256 = true; ++ } ++ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && ++ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { ++ found_fp32_256 = true; ++ } ++ } ++ } ++ } ++ if (found_fp16_128 && found_fp16_256 && ++ found_fp32_128 && found_fp32_256 && ++ coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { ++ device->coopmat2 = true; ++ } ++ } ++#endif ++ } ++ ++ if (!vk11_features.storageBuffer16BitAccess) { ++ std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." 
<< std::endl; ++ throw std::runtime_error("Unsupported device"); ++ } ++ ++ device_extensions.push_back("VK_KHR_16bit_storage"); ++ ++#ifdef GGML_VULKAN_VALIDATE ++ device_extensions.push_back("VK_KHR_shader_non_semantic_info"); ++#endif ++ ++ if (device->fp16) { ++ device_extensions.push_back("VK_KHR_shader_float16_int8"); ++ } ++ ++ if (device->coopmat_support) { ++ // Query supported shapes ++ std::vector cm_props; ++ ++ PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR = ++ (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"); ++ ++ uint32_t cm_props_num; ++ ++ pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr); ++ ++ cm_props.resize(cm_props_num); ++ ++ for (auto& prop : cm_props) { ++ prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; ++ } ++ ++ pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data()); ++ ++ VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size()); ++ ++ for (auto& prop : cm_props) { ++ VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope)); ++ ++ if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 && ++ (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 && ++ (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup ++ ) { ++ if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 && ++ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) { ++ // coopmat sizes not set yet ++ if (device->coopmat_m == 0) { ++ device->coopmat_acc_f32_support = true; ++ device->coopmat_m = prop.MSize; ++ device->coopmat_n = prop.NSize; ++ device->coopmat_k = prop.KSize; ++ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { ++ // Only enable if shape is identical ++ device->coopmat_acc_f32_support = true; ++ } ++ } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && ++ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { ++ // coopmat sizes not set yet ++ if (device->coopmat_m == 0) { ++ device->coopmat_acc_f16_support = true; ++ device->coopmat_m = prop.MSize; ++ device->coopmat_n = prop.NSize; ++ device->coopmat_k = prop.KSize; ++ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { ++ // Only enable if shape is identical ++ device->coopmat_acc_f16_support = true; ++ } ++ } ++ } ++ } ++ ++ if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { ++ // No suitable matmul mode found ++ GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. 
Disabling matrix cores.\n"); ++ device->coopmat_support = false; ++ } ++ } ++ ++ if (device->coopmat_support) { ++ device_extensions.push_back("VK_KHR_cooperative_matrix"); ++ } ++ ++ device->name = GGML_VK_NAME + std::to_string(idx); ++ ++ device_create_info = { ++ vk::DeviceCreateFlags(), ++ device_queue_create_infos, ++ {}, ++ device_extensions ++ }; ++ device_create_info.setPNext(&device_features2); ++ device->device = device->physical_device.createDevice(device_create_info); ++ ++ // Queues ++ ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false); ++ ++ // Shaders ++ // Disable matmul tile sizes early if performance low or not supported ++ switch (device->vendor_id) { ++#ifndef GGML_VULKAN_RUN_TESTS ++ case VK_VENDOR_ID_AMD: ++ case VK_VENDOR_ID_INTEL: ++ device->mul_mat_l = false; ++ device->mul_mat_m = true; ++ device->mul_mat_s = true; ++ device->mul_mat_id_l = false; ++ device->mul_mat_id_m = true; ++ device->mul_mat_id_s = true; ++ break; ++ case VK_VENDOR_ID_APPLE: ++ device->mul_mat_l = false; ++ device->mul_mat_m = true; ++ device->mul_mat_s = false; ++ device->mul_mat_id_l = false; ++ device->mul_mat_id_m = true; ++ device->mul_mat_id_s = false; ++ break; ++#endif ++ default: ++ device->mul_mat_l = true; ++ device->mul_mat_m = true; ++ device->mul_mat_s = true; ++ device->mul_mat_id_l = true; ++ device->mul_mat_id_m = true; ++ device->mul_mat_id_s = true; ++ break; ++ } ++ ++ ggml_vk_load_shaders(device); ++ ++ if (!device->single_queue) { ++ const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; ++ ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); ++ } else { ++ // TODO: Use pointer or reference to avoid copy ++ device->transfer_queue = device->compute_queue; ++ } ++ ++ device->buffer_type = { ++ /* .iface = */ ggml_backend_vk_buffer_type_interface, ++ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx), ++ /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device }, ++ }; ++ ++ device->fence = device->device.createFence({}); ++ ++ device->idx = idx; ++ ++ return device; ++ } ++ ++ return vk_instance.devices[idx]; ++} ++ ++static void ggml_vk_print_gpu_info(size_t idx) { ++ GGML_ASSERT(idx < vk_instance.device_indices.size()); ++ size_t dev_num = vk_instance.device_indices[idx]; ++ VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")"); ++ GGML_ASSERT(vk_instance_initialized); ++ ++ std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); ++ ++ if (dev_num >= devices.size()) { ++ std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." 
<< std::endl; ++ throw std::runtime_error("Device not found"); ++ } ++ ++ vk::PhysicalDevice physical_device = devices[dev_num]; ++ std::vector ext_props = physical_device.enumerateDeviceExtensionProperties(); ++ ++ vk::PhysicalDeviceProperties2 props2; ++ vk::PhysicalDeviceMaintenance3Properties props3; ++ vk::PhysicalDeviceSubgroupProperties subgroup_props; ++ vk::PhysicalDeviceDriverProperties driver_props; ++ props2.pNext = &props3; ++ props3.pNext = &subgroup_props; ++ subgroup_props.pNext = &driver_props; ++ physical_device.getProperties2(&props2); ++ ++ const size_t subgroup_size = subgroup_props.subgroupSize; ++ const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; ++ ++ bool fp16_storage = false; ++ bool fp16_compute = false; ++ bool coopmat_support = false; ++ bool coopmat2_support = false; ++ ++ for (auto properties : ext_props) { ++ if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { ++ fp16_storage = true; ++ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { ++ fp16_compute = true; ++ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && ++ !getenv("GGML_VK_DISABLE_COOPMAT")) { ++ coopmat_support = true; ++#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) ++ } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && ++ !getenv("GGML_VK_DISABLE_COOPMAT2")) { ++ coopmat2_support = true; ++#endif ++ } ++ } ++ ++ if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) { ++ coopmat_support = false; ++ } ++ ++ const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); ++ bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; ++ ++ bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; ++ ++ vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures(); ++ ++ VkPhysicalDeviceFeatures2 device_features2; ++ device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; ++ device_features2.pNext = nullptr; ++ device_features2.features = (VkPhysicalDeviceFeatures)device_features; ++ ++ VkPhysicalDeviceVulkan11Features vk11_features; ++ vk11_features.pNext = nullptr; ++ vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; ++ device_features2.pNext = &vk11_features; ++ ++ VkPhysicalDeviceVulkan12Features vk12_features; ++ vk12_features.pNext = nullptr; ++ vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; ++ vk11_features.pNext = &vk12_features; ++ ++ // Pointer to the last chain element ++ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; ++ ++ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; ++ coopmat_features.pNext = nullptr; ++ coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; ++ coopmat_features.cooperativeMatrix = VK_FALSE; ++ ++ if (coopmat_support) { ++ last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; ++ last_struct = (VkBaseOutStructure *)&coopmat_features; ++ } ++ ++ vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); ++ ++ fp16 = fp16 && vk12_features.shaderFloat16; ++ ++ coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix; ++ ++ std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? 
"KHR_coopmat" : "none"; ++ ++ std::string device_name = props2.properties.deviceName.data(); ++ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n", ++ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str()); ++ ++ if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { ++ GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); ++ } ++} ++ ++static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); ++static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); ++ ++void ggml_vk_instance_init() { ++ if (vk_instance_initialized) { ++ return; ++ } ++ VK_LOG_DEBUG("ggml_vk_instance_init()"); ++ ++ vk_instance_initialized = true; ++ ++ vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; ++ ++ const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); ++ const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); ++#ifdef __APPLE__ ++ const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); ++#endif ++ ++ std::vector layers; ++ ++ if (validation_ext) { ++ layers.push_back("VK_LAYER_KHRONOS_validation"); ++ } ++ std::vector extensions; ++ if (validation_ext) { ++ extensions.push_back("VK_EXT_validation_features"); ++ } ++#ifdef __APPLE__ ++ if (portability_enumeration_ext) { ++ extensions.push_back("VK_KHR_portability_enumeration"); ++ } ++#endif ++ vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); ++#ifdef __APPLE__ ++ if (portability_enumeration_ext) { ++ instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; ++ } ++#endif ++ ++ std::vector features_enable; ++ vk::ValidationFeaturesEXT validation_features; ++ ++ if (validation_ext) { ++ features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; ++ validation_features = { ++ features_enable, ++ {}, ++ }; ++ validation_features.setPNext(nullptr); ++ instance_create_info.setPNext(&validation_features); ++ GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); ++ } ++ vk_instance.instance = vk::createInstance(instance_create_info); ++ ++ size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); ++ ++ // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan ++ char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); ++ if (devices_env != nullptr) { ++ std::string devices(devices_env); ++ std::replace(devices.begin(), devices.end(), ',', ' '); ++ ++ std::stringstream ss(devices); ++ size_t tmp; ++ while (ss >> tmp) { ++ if(tmp >= num_available_devices) { ++ std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl; ++ throw std::runtime_error("Invalid Vulkan device index"); ++ } ++ vk_instance.device_indices.push_back(tmp); ++ } ++ } else { ++ std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); ++ ++ // Make sure at least one device exists ++ if (devices.empty()) { ++ std::cerr << "ggml_vulkan: Error: No devices found." 
<< std::endl; ++ GGML_ABORT("fatal error"); ++ } ++ ++ // Default to using all dedicated GPUs ++ for (size_t i = 0; i < devices.size(); i++) { ++ vk::PhysicalDeviceProperties2 new_props; ++ vk::PhysicalDeviceDriverProperties new_driver; ++ vk::PhysicalDeviceIDProperties new_id; ++ new_props.pNext = &new_driver; ++ new_driver.pNext = &new_id; ++ devices[i].getProperties2(&new_props); ++ ++ if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) { ++ // Check if there are two physical devices corresponding to the same GPU ++ auto old_device = std::find_if( ++ vk_instance.device_indices.begin(), ++ vk_instance.device_indices.end(), ++ [&devices, &new_id](const size_t k){ ++ vk::PhysicalDeviceProperties2 old_props; ++ vk::PhysicalDeviceIDProperties old_id; ++ old_props.pNext = &old_id; ++ devices[k].getProperties2(&old_props); ++ return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); ++ } ++ ); ++ if (old_device == vk_instance.device_indices.end()) { ++ vk_instance.device_indices.push_back(i); ++ } else { ++ // There can be two physical devices corresponding to the same GPU if there are 2 different drivers ++ // This can cause error when splitting layers aross the devices, need to keep only 1 ++ VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID"); ++ ++ vk::PhysicalDeviceProperties2 old_props; ++ vk::PhysicalDeviceDriverProperties old_driver; ++ old_props.pNext = &old_driver; ++ devices[*old_device].getProperties2(&old_props); ++ ++ std::map driver_priorities {}; ++ int old_priority = std::numeric_limits::max(); ++ int new_priority = std::numeric_limits::max(); ++ ++ // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id ++ // Smaller number -> higher priority ++ switch (old_props.properties.vendorID) { ++ case VK_VENDOR_ID_AMD: ++ driver_priorities[vk::DriverId::eMesaRadv] = 1; ++ driver_priorities[vk::DriverId::eAmdOpenSource] = 2; ++ driver_priorities[vk::DriverId::eAmdProprietary] = 3; ++ break; ++ case VK_VENDOR_ID_INTEL: ++ driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1; ++ driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2; ++ break; ++ case VK_VENDOR_ID_NVIDIA: ++ driver_priorities[vk::DriverId::eNvidiaProprietary] = 1; ++#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235 ++ driver_priorities[vk::DriverId::eMesaNvk] = 2; ++#endif ++ break; ++ } ++ ++ if (driver_priorities.count(old_driver.driverID)) { ++ old_priority = driver_priorities[old_driver.driverID]; ++ } ++ if (driver_priorities.count(new_driver.driverID)) { ++ new_priority = driver_priorities[new_driver.driverID]; ++ } ++ ++ if (new_priority < old_priority) { ++ auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device); ++ vk_instance.device_indices.erase(r, vk_instance.device_indices.end()); ++ vk_instance.device_indices.push_back(i); ++ ++ VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName); ++ } ++ else { ++ VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl); ++ } ++ } ++ } ++ } ++ ++ // If no dedicated GPUs found, fall back to GPU 0 ++ if (vk_instance.device_indices.empty()) { ++ vk_instance.device_indices.push_back(0); ++ } ++ } ++ GGML_LOG_DEBUG("ggml_vulkan: 
Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); ++ ++ for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { ++ ggml_vk_print_gpu_info(i); ++ } ++} ++ ++static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ++ VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")"); ++ ggml_vk_instance_init(); ++ GGML_ASSERT(idx < vk_instance.device_indices.size()); ++ ++ ctx->name = GGML_VK_NAME + std::to_string(idx); ++ ++ ctx->device = ggml_vk_get_device(idx); ++ ++ ctx->semaphore_idx = 0; ++ ctx->event_idx = 0; ++ ++ ctx->prealloc_size_x = 0; ++ ctx->prealloc_size_y = 0; ++ ctx->prealloc_size_split_k = 0; ++ ++ ctx->fence = ctx->device->device.createFence({}); ++ ++#ifdef GGML_VULKAN_CHECK_RESULTS ++ const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); ++ vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); ++ const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR"); ++ vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor)); ++#endif ++} ++ ++static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) { ++ VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); ++ switch (type) { ++ case GGML_TYPE_F32: ++ case GGML_TYPE_Q4_0: ++ case GGML_TYPE_Q4_1: ++ case GGML_TYPE_Q5_0: ++ case GGML_TYPE_Q5_1: ++ case GGML_TYPE_Q8_0: ++ case GGML_TYPE_Q2_K: ++ case GGML_TYPE_Q3_K: ++ case GGML_TYPE_Q4_K: ++ case GGML_TYPE_Q5_K: ++ case GGML_TYPE_Q6_K: ++ case GGML_TYPE_IQ4_NL: ++ break; ++ default: ++ return nullptr; ++ } ++ ++ return ctx->device->pipeline_dequant[type]; ++} ++ ++static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { ++ VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); ++ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_matmul_f32; ++ } ++ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_matmul_f32_f16; ++ } ++ if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { ++ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_matmul_f16_f32.f16acc; ++ } ++ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_matmul_f16.f16acc; ++ } ++ } else { ++ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_matmul_f16_f32.f32acc; ++ } ++ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_matmul_f16.f32acc; ++ } ++ } ++ ++ if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { ++ return nullptr; ++ } ++ ++ switch (src0_type) { ++ case GGML_TYPE_Q4_0: ++ case GGML_TYPE_Q4_1: ++ case GGML_TYPE_Q5_0: ++ case GGML_TYPE_Q5_1: ++ case GGML_TYPE_Q8_0: ++ case GGML_TYPE_Q2_K: ++ case GGML_TYPE_Q3_K: ++ case GGML_TYPE_Q4_K: ++ case GGML_TYPE_Q5_K: ++ case GGML_TYPE_Q6_K: ++ case GGML_TYPE_IQ4_NL: ++ break; ++ default: ++ return nullptr; ++ } ++ ++ if (ctx->device->coopmat2) { ++ assert(src1_type == GGML_TYPE_F16); ++ return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc; ++ } ++ return ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; ++} ++ ++static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { ++ VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); ++ GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); ++ GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); ++ ++ switch (a_type) { ++ case GGML_TYPE_F32: ++ case GGML_TYPE_F16: ++ case GGML_TYPE_Q4_0: ++ case GGML_TYPE_Q4_1: ++ case GGML_TYPE_Q5_0: ++ case GGML_TYPE_Q5_1: ++ case GGML_TYPE_Q8_0: ++ case GGML_TYPE_Q2_K: ++ case GGML_TYPE_Q3_K: ++ case GGML_TYPE_Q4_K: ++ case GGML_TYPE_Q5_K: ++ case GGML_TYPE_Q6_K: ++ case GGML_TYPE_IQ4_NL: ++ break; ++ default: ++ return nullptr; ++ } ++ ++ return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1]; ++} ++ ++static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { ++ VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()"); ++ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_matmul_id_f32; ++ } ++ if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { ++ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_matmul_id_f16_f32.f16acc; ++ } ++ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_matmul_id_f16.f16acc; ++ } ++ } else { ++ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_matmul_id_f16_f32.f32acc; ++ } ++ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_matmul_id_f16.f32acc; ++ } ++ } ++ ++ GGML_ASSERT(src1_type == GGML_TYPE_F32); ++ ++ switch (src0_type) { ++ case GGML_TYPE_Q4_0: ++ case GGML_TYPE_Q4_1: ++ case GGML_TYPE_Q5_0: ++ case GGML_TYPE_Q5_1: ++ case GGML_TYPE_Q8_0: ++ case GGML_TYPE_Q2_K: ++ case GGML_TYPE_Q3_K: ++ case GGML_TYPE_Q4_K: ++ case GGML_TYPE_Q5_K: ++ case GGML_TYPE_Q6_K: ++ case GGML_TYPE_IQ4_NL: ++ break; ++ default: ++ return nullptr; ++ } ++ ++ return ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; ++} ++ ++static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { ++ VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); ++ GGML_ASSERT(b_type == GGML_TYPE_F32); ++ ++ switch (a_type) { ++ case GGML_TYPE_F32: ++ case GGML_TYPE_F16: ++ case GGML_TYPE_Q4_0: ++ case GGML_TYPE_Q4_1: ++ case GGML_TYPE_Q5_0: ++ case GGML_TYPE_Q5_1: ++ case GGML_TYPE_Q8_0: ++ case GGML_TYPE_Q2_K: ++ case GGML_TYPE_Q3_K: ++ case GGML_TYPE_Q4_K: ++ case GGML_TYPE_Q5_K: ++ case GGML_TYPE_Q6_K: ++ case GGML_TYPE_IQ4_NL: ++ break; ++ default: ++ return nullptr; ++ } ++ ++ return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; ++} ++ ++static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { ++ VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")"); ++ VK_LOG_MEMORY("ggml_vk_pool_malloc"); ++ ++ int best_i = -1; ++ size_t best_size = std::numeric_limits::max(); //smallest unused buffer that fits our needs ++ int worst_i = -1; ++ size_t worst_size = 0; //largest unused buffer seen so far ++ for (int i = 0; i < MAX_VK_BUFFERS; ++i) { ++ vk_buffer &b = ctx->buffer_pool[i]; ++ if (b != nullptr && b->size >= size && b->size < best_size) { ++ best_i = i; ++ best_size = b->size; ++ } ++ if (b != nullptr && b->size > worst_size) { ++ worst_i = i; ++ worst_size = b->size; ++ } ++ } ++ if(best_i != -1) { ++ //found the smallest buffer that fits our needs ++ vk_buffer b = ctx->buffer_pool[best_i]; ++ ctx->buffer_pool[best_i].reset(); ++ return b; ++ } ++ if(worst_i != -1) { ++ //no buffer that fits our needs, resize largest one to save memory ++ vk_buffer& b = ctx->buffer_pool[worst_i]; ++ ggml_vk_destroy_buffer(b); ++ } ++ ++ return ggml_vk_create_buffer_device(ctx->device, size); ++} ++ ++static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) { ++ VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")"); ++ for (int i = 0; i < MAX_VK_BUFFERS; ++i) { ++ vk_buffer& b = ctx->buffer_pool[i]; ++ if (b == nullptr) { ++ b = buffer; ++ return; ++ } ++ } ++ std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl; ++ ggml_vk_destroy_buffer(buffer); ++} ++ ++// Returns an available temporary buffer that may only be used temporarily, it will be reused ++static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) { ++ // Try to find existing temp buffer with enough capacity ++ for (auto& buffer : ctx->gc.temp_buffers) { ++ if (buffer->size >= size) { ++ return buffer; ++ } ++ } ++ ++ VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")"); ++ ++ // Otherwise create new buffer ++ vk_buffer buf = ggml_vk_pool_malloc(ctx, size); ++ ctx->gc.temp_buffers.push_back(buf); ++ ++ return buf; ++} ++ ++static void * ggml_vk_host_malloc(vk_device& device, size_t size) { ++ VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); ++ vk_buffer buf = ggml_vk_create_buffer(device, size, ++ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, ++ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); ++ ++ if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) { ++ fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n", ++ size/1024.0/1024.0); ++ device->device.freeMemory(buf->device_memory); 
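++    // Editor's note (not part of the original patch): the device memory and the
++    // buffer object are separate Vulkan handles, so both have to be released in
++    // this failure path even though only the memory-property check failed.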
++ device->device.destroyBuffer(buf->buffer); ++ return nullptr; ++ } ++ ++ device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); ++ ++ return buf->ptr; ++} ++ ++static void ggml_vk_host_free(vk_device& device, void* ptr) { ++ if (ptr == nullptr) { ++ return; ++ } ++ VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); ++ vk_buffer buf; ++ size_t index; ++ for (size_t i = 0; i < device->pinned_memory.size(); i++) { ++ const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); ++ const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); ++ if (ptr >= addr && ptr < endr) { ++ buf = std::get<2>(device->pinned_memory[i]); ++ index = i; ++ break; ++ } ++ } ++ if (buf == nullptr) { ++ fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n"); ++ return; ++ } ++ ++ ggml_vk_destroy_buffer(buf); ++ ++ device->pinned_memory.erase(device->pinned_memory.begin() + index); ++} ++ ++static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { ++ buf = nullptr; ++ buf_offset = 0; ++ for (size_t i = 0; i < device->pinned_memory.size(); i++) { ++ const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); ++ const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); ++ if (ptr >= addr && ptr < endr) { ++ buf = std::get<2>(device->pinned_memory[i]); ++ buf_offset = ((const uint8_t *)ptr) - addr; ++ break; ++ } ++ } ++} ++ ++static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) { ++ vk_submission s; ++ s.buffer = ggml_vk_create_cmd_buffer(device, q); ++ if (one_time) { ++ s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); ++ } else { ++ s.buffer.begin({ vk::CommandBufferUsageFlags{} }); ++ } ++ ++ return s; ++} ++ ++ ++ ++static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array elements) { ++ const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); ++ const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); ++ const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); ++ VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {"; ++ for (auto& buffer : descriptor_buffer_infos) { ++ std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; ++ } ++ std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); ++ GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size()); ++ GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count); ++ ++ vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++]; ++ vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; ++ ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); ++ ++ subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants); ++ subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); ++ subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, ++ pipeline->layout, ++ 0, ++ { descriptor_set }, ++ {}); ++ subctx->s->buffer.dispatch(wg0, wg1, wg2); ++} ++ ++static void ggml_vk_end_submission(vk_submission& s, 
std::vector wait_semaphores, std::vector signal_semaphores) { ++ s.buffer.end(); ++ ++ s.wait_semaphores = std::move(wait_semaphores); ++ s.signal_semaphores = std::move(signal_semaphores); ++} ++ ++static void ggml_vk_ctx_end(vk_context& ctx) { ++ VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")"); ++ if (ctx->s == nullptr) { ++ return; ++ } ++ ++ ctx->s->buffer.end(); ++ ctx->s = nullptr; ++} ++ ++static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { ++ VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")"); ++ if (subctx->s != nullptr) { ++ ggml_vk_ctx_end(subctx); ++ } ++ ++ subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) }); ++ subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); ++} ++ ++static size_t ggml_vk_align_size(size_t width, size_t align) { ++ VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")"); ++ return CEIL_DIV(width, align) * align; ++} ++ ++static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector* memcpys = nullptr) { ++ if (memcpys == nullptr) { ++ memcpy(dst, src, size); ++ } else { ++ memcpys->emplace_back(dst, src, size); ++ } ++} ++ ++static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) { ++ if (device->sync_staging == nullptr || device->sync_staging->size < size) { ++ VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")"); ++ ggml_vk_destroy_buffer(device->sync_staging); ++ device->sync_staging = ggml_vk_create_buffer_check(device, size, ++ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, ++ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); ++ } ++} ++ ++static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) { ++ VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")"); ++ GGML_ASSERT(!ggml_is_contiguous(tensor)); ++ // Buffer is already mapped ++ if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { ++ std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." 
<< std::endl; ++ GGML_ABORT("fatal error"); ++ } ++ // Check if src is pinned memory ++ vk_buffer buf = nullptr; ++ size_t buf_offset = 0; ++ ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset); ++ ++ const uint64_t ne0 = tensor->ne[0]; ++ const uint64_t ne1 = tensor->ne[1]; ++ const uint64_t ne2 = tensor->ne[2]; ++ const uint64_t ne3 = tensor->ne[3]; ++ const uint64_t nb0 = tensor->nb[0]; ++ const uint64_t nb1 = tensor->nb[1]; ++ const uint64_t nb2 = tensor->nb[2]; ++ const uint64_t nb3 = tensor->nb[3]; ++ const ggml_type type = tensor->type; ++ const uint64_t ts = ggml_type_size(type); ++ const uint64_t bs = ggml_blck_size(type); ++ ++ const uint64_t dstnb0 = ts; ++ const uint64_t dstnb1 = dstnb0*(ne0/bs); ++ const uint64_t dstnb2 = dstnb1*ne1; ++ const uint64_t dstnb3 = dstnb2*ne2; ++ ++ const uint64_t ne = ggml_nelements(tensor); ++ ++ if (buf != nullptr) { ++ // Memory is pinned, use as staging buffer ++ std::vector slices; ++ ++ for (uint64_t i3 = 0; i3 < ne3; i3++) { ++ for (uint64_t i2 = 0; i2 < ne2; i2++) { ++ // Find longest contiguous slice ++ if (ne1*nb1 == dstnb2) { ++ slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 }); ++ } else { ++ for (uint64_t i1 = 0; i1 < ne1; i1++) { ++ if (ne0*nb0/bs == dstnb1) { ++ slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 }); ++ } else { ++ const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; ++ const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; ++ for (uint64_t i0 = 0; i0 < ne0; i0++) { ++ slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); ++ } ++ } ++ } ++ } ++ } ++ } ++ ++ ggml_vk_sync_buffers(subctx); ++ subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); ++ return; ++ } ++ ++ if (!sync_staging) { ++ GGML_ABORT("Asynchronous write to non-pinned memory not supported"); ++ } ++ ++ // Staging buffer required ++ vk_buffer& staging = ctx->device->sync_staging; ++ const uint64_t copy_size = ts*ne/bs; ++ ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); ++ VkBufferCopy buf_copy{ 0, offset, copy_size }; ++ ++ ggml_vk_sync_buffers(subctx); ++ vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); ++ ++ for (uint64_t i3 = 0; i3 < ne3; i3++) { ++ for (uint64_t i2 = 0; i2 < ne2; i2++) { ++ // Find longest contiguous slice ++ if (ne1*nb1 == dstnb2) { ++ deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys); ++ } else { ++ for (uint64_t i1 = 0; i1 < ne1; i1++) { ++ if (ne0*nb0/bs == dstnb1) { ++ deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys); ++ } else { ++ const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; ++ const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1; ++ for (uint64_t i0 = 0; i0 < ne0; i0++) { ++ deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys); ++ } ++ } ++ } ++ } ++ } ++ } ++} ++ ++static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { ++ VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); ++ // Buffer is already mapped ++ 
if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { ++ std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl; ++ GGML_ABORT("fatal error"); ++ } ++ // Check if src is pinned memory ++ vk_buffer buf = nullptr; ++ size_t buf_offset = 0; ++ ggml_vk_host_get(dst->device, src, buf, buf_offset); ++ ++ if (buf != nullptr) { ++ // Memory is pinned, use as staging buffer ++ std::vector slices(1); ++ if (width == spitch) { ++ // Only do single write if stride is equal ++ slices[0].srcOffset = buf_offset; ++ slices[0].dstOffset = offset; ++ slices[0].size = width * height; ++ } else { ++ slices.resize(height); ++ for (size_t i = 0; i < height; i++) { ++ slices[i].srcOffset = buf_offset + i * spitch; ++ slices[i].dstOffset = offset + i * width; ++ slices[i].size = width; ++ } ++ } ++ ++ ggml_vk_sync_buffers(subctx); ++ subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); ++ return; ++ } ++ VK_LOG_DEBUG("STAGING"); ++ ++ if (!sync_staging) { ++ GGML_ABORT("Asynchronous write to non-pinned memory not supported"); ++ } ++ ++ // Staging buffer required ++ const size_t copy_size = width*height; ++ ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); ++ ++ vk_buffer& staging_buffer = dst->device->sync_staging; ++ ++ VkBufferCopy buf_copy = { ++ 0, ++ offset, ++ copy_size}; ++ ++ ggml_vk_sync_buffers(subctx); ++ vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); ++ ++ if (width == spitch) { ++ deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); ++ } else { ++ for (size_t i = 0; i < height; i++) { ++ deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); ++ } ++ } ++} ++ ++static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { ++ VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); ++ return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); ++} ++ ++static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { ++ VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); ++ // Buffer is already mapped ++ if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { ++ GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); ++ ++ for (size_t i = 0; i < height; i++) { ++ memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); ++ } ++ } else { ++ vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); ++ ggml_vk_ctx_begin(dst->device, subctx); ++ ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); ++ ggml_vk_ctx_end(subctx); ++ ++ for (auto& cpy : subctx->in_memcpys) { ++ memcpy(cpy.dst, cpy.src, cpy.n); ++ } ++ ++ ggml_vk_submit(subctx, dst->device->fence); ++ VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); ++ dst->device->device.resetFences({ dst->device->fence }); ++ } ++} ++ ++static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { ++ VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); ++ ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); ++} ++ ++static void 
ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { ++ VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")"); ++ GGML_ASSERT(width > 0); ++ GGML_ASSERT(height > 0); ++ GGML_ASSERT(src != nullptr); ++ ++ // TODO: staging_offset is not used ++ ++ // Check if dst is pinned memory ++ vk_buffer buf = nullptr; ++ size_t buf_offset = 0; ++ ggml_vk_host_get(src->device, dst, buf, buf_offset); ++ ++ std::vector slices(1); ++ if (width == spitch && width == dpitch) { ++ // Only do single write if stride is equal ++ slices[0].srcOffset = offset; ++ slices[0].dstOffset = buf_offset; ++ slices[0].size = width * height; ++ } else { ++ slices.resize(height); ++ for (size_t i = 0; i < height; i++) { ++ slices[i].srcOffset = offset + i * spitch; ++ slices[i].dstOffset = buf_offset + i * dpitch; ++ slices[i].size = width; ++ } ++ } ++ ++ if (buf != nullptr) { ++ // Memory is pinned, use as staging buffer ++ ggml_vk_sync_buffers(subctx); ++ subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); ++ ++ return; ++ } ++ VK_LOG_DEBUG("STAGING"); ++ ++ if (!sync_staging) { ++ GGML_ABORT("Asynchronous read from non-pinned memory not supported"); ++ } ++ ++ // Fall back to staging buffer ++ const size_t copy_size = dpitch * height; ++ ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); ++ ++ vk_buffer& staging_buffer = src->device->sync_staging; ++ ++ ggml_vk_sync_buffers(subctx); ++ subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); ++ ++ deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); ++} ++ ++static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) { ++ return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); ++} ++ ++static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { ++ VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); ++ ++ // If the device is not an UMA device the memory is host-accessible through rebar. While writing ++ // through PCIe is sufficient fast reading back data from PCIe is slower than going through ++ // the HW device to host copy path. 
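++    // Editor's illustration (not part of the original patch) of the read path chosen below:
++    //   host-visible memory on a UMA device -> plain memcpy from the mapped pointer
++    //   anything else (including ReBAR-mapped VRAM) -> copy into sync_staging on the
++    //   transfer queue, wait on the fence, then memcpy the staged bytes out, since
++    //   CPU reads over PCIe are much slower than the device-to-host copy engine.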
++ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { ++ GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); ++ ++ memcpy(dst, (uint8_t *) src->ptr + offset, size); ++ } else { ++ vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); ++ ggml_vk_ctx_begin(src->device, subctx); ++ ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); ++ ggml_vk_ctx_end(subctx); ++ ++ ggml_vk_submit(subctx, src->device->fence); ++ VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); ++ src->device->device.resetFences({ src->device->fence }); ++ ++ for (auto& cpy : subctx->out_memcpys) { ++ memcpy(cpy.dst, cpy.src, cpy.n); ++ } ++ } ++} ++ ++static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { ++ VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); ++ // Make sure both buffers are on same device ++ GGML_ASSERT(src->device == dst->device); ++ ++ VkBufferCopy bc{ src_offset, dst_offset, size }; ++ ++ vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); ++} ++ ++static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { ++ if (src->device == dst->device) { ++ VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); ++ // Copy within the device ++ vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); ++ ggml_vk_ctx_begin(src->device, subctx); ++ ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); ++ ggml_vk_ctx_end(subctx); ++ ggml_vk_submit(subctx, src->device->fence); ++ VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); ++ src->device->device.resetFences({ src->device->fence }); ++ } else { ++ VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); ++ // Copy device to device ++ ggml_vk_ensure_sync_staging_buffer(src->device, size); ++ ggml_vk_ensure_sync_staging_buffer(dst->device, size); ++ ++ // Copy to src staging buffer ++ ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); ++ // memcpy to dst staging buffer ++ memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size); ++ // Copy to dst buffer ++ ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size); ++ } ++} ++ ++static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { ++ VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); ++ ++ vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); ++ ggml_vk_ctx_begin(dst->device, subctx); ++ subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); ++ ggml_vk_ctx_end(subctx); ++ ++ ggml_vk_submit(subctx, dst->device->fence); ++ VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); ++ dst->device->device.resetFences({ dst->device->fence }); ++} ++ ++static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) { ++ VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); ++ ++ uint32_t split_k = 1; ++ if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) { ++ // If k is 
'large' and the SMs will fill less than halfway, use split_k. ++ uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); ++ uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); ++ if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) { ++ split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); ++ // Clamp to 2 or 4 ++ split_k = std::min(split_k, 4u); ++ if (split_k == 3) { ++ split_k = 2; ++ } ++ } ++ } ++ ++ return split_k; ++} ++ ++static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { ++ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); ++ ++ if (ctx->device->coopmat2) { ++ if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) { ++ return aligned ? mmp->a_l : mmp->l; ++ } ++ if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) { ++ return aligned ? mmp->a_m : mmp->m; ++ } ++ return aligned ? mmp->a_s : mmp->s; ++ } ++ ++ if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) { ++ return aligned ? mmp->a_s : mmp->s; ++ } ++ if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) { ++ return aligned ? mmp->a_m : mmp->m; ++ } ++ return aligned ? mmp->a_l : mmp->l; ++} ++ ++static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { ++ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); ++ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align; ++} ++ ++static void ggml_vk_matmul( ++ ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, ++ vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, ++ uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, ++ uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, ++ uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) { ++ VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? 
split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")"); ++ ggml_vk_sync_buffers(subctx); ++ if (split_k == 1) { ++ const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 }; ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); ++ return; ++ } ++ ++ GGML_ASSERT(batch_stride_d == m * n); ++ ++ const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 }; ++ // Make sure enough workgroups get assigned for split k to work ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); ++ ggml_vk_sync_buffers(subctx); ++ const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; ++ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); ++} ++ ++static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { ++ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); ++ ++ if (ctx->device->coopmat2) { ++ if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) { ++ return aligned ? mmp->a_l : mmp->l; ++ } ++ if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) { ++ return aligned ? mmp->a_m : mmp->m; ++ } ++ return aligned ? mmp->a_s : mmp->s; ++ } ++ ++ if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) { ++ return aligned ? mmp->a_s : mmp->s; ++ } ++ if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) { ++ return aligned ? mmp->a_m : mmp->m; ++ } ++ return aligned ? 
mmp->a_l : mmp->l; ++} ++ ++static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { ++ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); ++ return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align; ++} ++ ++static void ggml_vk_matmul_id( ++ ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, ++ vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, ++ uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, ++ uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, ++ uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) { ++ VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << ++ "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << ++ "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << ++ "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); ++ ggml_vk_sync_buffers(subctx); ++ const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, ++ nei0, nei1, nbi1, ne11 }; ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as }); ++} ++ ++static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { ++ return ++ tensor->nb[0] == ggml_type_size(tensor->type) && ++ tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && ++ tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; ++} ++ ++static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { ++ ++ // Choose "contiguous copy" shader if src/dst are contiguous ++ bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); ++ ++ if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { ++ if (contig) { ++ return ctx->device->pipeline_contig_cpy_f32_f32; ++ } else { ++ return ctx->device->pipeline_cpy_f32_f32; ++ } ++ } ++ if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) { ++ if (contig) { ++ return ctx->device->pipeline_contig_cpy_f32_f16; ++ } else { ++ return ctx->device->pipeline_cpy_f32_f16; ++ } ++ } ++ if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) { ++ if (contig) { ++ return ctx->device->pipeline_contig_cpy_f16_f16; ++ } else { ++ return ctx->device->pipeline_cpy_f16_f16; ++ } ++ } ++ ++ std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; ++ GGML_ABORT("fatal error"); ++} ++ ++static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) { ++ VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" 
<< tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), "; ++ std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); ++ const int tensor_type_size = ggml_type_size(tensor->type); ++ ++ const uint32_t ne = ggml_nelements(tensor); ++ std::array elements; ++ ++ if (ne > 262144) { ++ elements = { 512, 512, CEIL_DIV(ne, 262144) }; ++ } else if (ne > 512) { ++ elements = { 512, CEIL_DIV(ne, 512), 1 }; ++ } else { ++ elements = { ne, 1, 1 }; ++ } ++ ++ vk_op_unary_push_constants pc = { ++ (uint32_t)ne, ++ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, ++ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]), ++ 0, ++ 0.0f, 0.0f, ++ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ++ }; ++ init_pushconst_fastdiv(pc); ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); ++} ++ ++static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; ++ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; ++ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; ++ std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); ++ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT ++ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT ++ ++ const uint64_t ne00 = src0->ne[0]; ++ const uint64_t ne01 = src0->ne[1]; ++ const uint64_t ne02 = src0->ne[2]; ++ const uint64_t ne03 = src0->ne[3]; ++ ++ const uint64_t ne10 = src1->ne[0]; ++ const uint64_t ne11 = src1->ne[1]; ++ const uint64_t ne12 = src1->ne[2]; ++ const uint64_t ne13 = src1->ne[3]; ++ ++ const uint64_t ne20 = dst->ne[0]; ++ const uint64_t ne21 = dst->ne[1]; ++ ++ const uint64_t r2 = ne12 / ne02; ++ const uint64_t r3 = ne13 / ne03; ++ ++ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; ++ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; ++ ++ vk_buffer d_Qx = nullptr; ++ size_t qx_buf_offset = 0; ++ vk_buffer d_Qy = nullptr; ++ size_t qy_buf_offset = 0; ++ ++ bool src0_uma = false; ++ bool src1_uma = false; ++ ++ if (ctx->device->uma) { ++ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); ++ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); ++ src0_uma = d_Qx != nullptr; ++ src1_uma = d_Qy != nullptr; ++ } ++ ++ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); ++ // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf ++ const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || ++ !ggml_vk_dim01_contiguous(src1); ++ ++ const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; ++ ++ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); ++ ++ const bool qx_needs_dequant = mmp == nullptr || x_non_contig; ++ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; ++ ++ if (qx_needs_dequant) { ++ // Fall back to dequant + f16 mulmat ++ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); ++ } ++ ++ // Not implemented ++ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT ++ ++ const int x_ne = ne01 * ne00; ++ const int y_ne = ne11 * ne10; ++ const int d_ne = ne11 * ne01; ++ ++ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); ++ const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; ++ ++ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned); ++ ++ const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); ++ ++ const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); ++ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); ++ const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; ++ const uint64_t y_sz = y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; ++ const uint64_t d_sz = sizeof(float) * d_ne; ++ ++ vk_pipeline to_fp16_vk_0 = nullptr; ++ vk_pipeline to_fp16_vk_1 = nullptr; ++ ++ if (x_non_contig) { ++ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); ++ } else { ++ to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); ++ } ++ if (y_non_contig) { ++ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); ++ } else { ++ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); ++ } ++ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT ++ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT ++ ++ if (dryrun) { ++ const uint64_t x_sz_upd = x_sz * ne02 * ne03; ++ const uint64_t y_sz_upd = y_sz * ne12 * ne13; ++ const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; ++ if ( ++ (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || ++ (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || ++ (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) { ++ GGML_ABORT("Requested preallocation size is too large"); ++ } ++ if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ++ ctx->prealloc_size_x = x_sz_upd; ++ } ++ if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { ++ ctx->prealloc_size_y = y_sz_upd; ++ } ++ if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { ++ ctx->prealloc_size_split_k = split_k_size; ++ } ++ ++ // Request descriptor sets ++ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); ++ if (qx_needs_dequant) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); ++ } ++ if (qy_needs_dequant) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); ++ } ++ if (split_k > 1) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); ++ } ++ return; ++ } ++ ++ vk_buffer d_D = dst_buf_ctx->dev_buffer; ++ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; ++ GGML_ASSERT(d_D != nullptr); ++ GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); ++ vk_buffer d_X; ++ uint64_t x_buf_offset = 0; ++ vk_buffer d_Y; ++ uint64_t y_buf_offset = 0; ++ if (!src0_uma) { ++ d_Qx = src0_buf_ctx->dev_buffer; ++ qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; ++ GGML_ASSERT(d_Qx != nullptr); ++ } ++ if (!src1_uma) { ++ d_Qy = src1_buf_ctx->dev_buffer; ++ qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; ++ GGML_ASSERT(d_Qy != nullptr); ++ } ++ if (qx_needs_dequant) { ++ d_X = ctx->prealloc_x; ++ GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); ++ } else { ++ d_X = d_Qx; ++ x_buf_offset = qx_buf_offset; ++ GGML_ASSERT(qx_sz == x_sz); ++ } ++ if (qy_needs_dequant) { ++ d_Y = ctx->prealloc_y; ++ GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); ++ } else { ++ d_Y = d_Qy; ++ y_buf_offset = qy_buf_offset; ++ GGML_ASSERT(qy_sz == y_sz); ++ } ++ ++ if (x_non_contig) { ++ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); ++ } else if (qx_needs_dequant) { ++ const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * 
ne02 * ne03), 1, 1}); ++ } ++ if (y_non_contig) { ++ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); ++ } ++ ++ uint32_t stride_batch_x = ne00*ne01; ++ uint32_t stride_batch_y = ne10*ne11; ++ ++ if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { ++ stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); ++ } ++ ++ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { ++ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); ++ } ++ ++ // compute ++ ggml_vk_matmul( ++ ctx, subctx, pipeline, ++ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, ++ { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ++ ne01, ne11, ne10, ++ ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, ++ split_k, ne12*ne13, ne02, ne12, r2, r3 ++ ); // NOLINT ++} ++ ++static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; ++ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; ++ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; ++ std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)"); ++ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT ++ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT ++ ++ const uint64_t ne00 = src0->ne[0]; ++ const uint64_t ne01 = src0->ne[1]; ++ const uint64_t ne02 = src0->ne[2]; ++ const uint64_t ne03 = src0->ne[3]; ++ ++ const uint64_t ne10 = src1->ne[0]; ++ const uint64_t ne11 = src1->ne[1]; ++ const uint64_t ne12 = src1->ne[2]; ++ const uint64_t ne13 = src1->ne[3]; ++ ++ const uint64_t ne20 = dst->ne[0]; ++ const uint64_t ne21 = dst->ne[1]; ++ const uint64_t ne22 = dst->ne[2]; ++ const uint64_t ne23 = dst->ne[3]; ++ ++ const uint64_t r2 = ne12 / ne02; ++ const uint64_t r3 = ne13 / ne03; ++ ++ // batch_n indicates that we need to compute a few vector results, and this assumes ++ // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides. 
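++    // Worked example (editor's addition, not part of the original patch):
++    //   ne11 == 1, ne12*ne13 == 8    -> batch_n = false, the dispatch runs 8 regular batches
++    //   ne11 == 4, ne12 == ne13 == 1 -> batch_n = true, the num_cols == 4 mul_mat_vec
++    //   variant is selected and stride_batch_y/stride_batch_d are set to the row
++    //   strides ne10/ne20 so a single dispatch produces all four vector results.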
++ GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1); ++ bool batch_n = ne11 > 1; ++ ++ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; ++ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; ++ ++ vk_buffer d_Qx = nullptr; ++ size_t qx_buf_offset = 0; ++ vk_buffer d_Qy = nullptr; ++ size_t qy_buf_offset = 0; ++ ++ bool src0_uma = false; ++ bool src1_uma = false; ++ ++ if (ctx->device->uma) { ++ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); ++ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); ++ src0_uma = d_Qx != nullptr; ++ src1_uma = d_Qy != nullptr; ++ } ++ ++ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); ++ const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); ++ ++ const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; ++ ++ const bool qx_needs_dequant = x_non_contig; ++ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; ++ ++ // Not implemented ++ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT ++ ++ const uint64_t x_ne = ne01 * ne00; ++ const uint64_t y_ne = ne11 * ne10; ++ const uint64_t d_ne = ne11 * ne01; ++ ++ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); ++ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); ++ const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; ++ const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; ++ const uint64_t d_sz = sizeof(float) * d_ne; ++ ++ vk_pipeline to_fp16_vk_0 = nullptr; ++ vk_pipeline to_fp16_vk_1 = nullptr; ++ if (x_non_contig) { ++ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); ++ } ++ if (y_non_contig) { ++ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); ++ } else { ++ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); ++ } ++ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11); ++ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT ++ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT ++ GGML_ASSERT(dmmv != nullptr); ++ ++ if (dryrun) { ++ const uint64_t x_sz_upd = x_sz * ne02 * ne03; ++ const uint64_t y_sz_upd = y_sz * ne12 * ne13; ++ if ( ++ (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || ++ (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { ++ GGML_ABORT("Requested preallocation size is too large"); ++ } ++ if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ++ ctx->prealloc_size_x = x_sz_upd; ++ } ++ if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { ++ ctx->prealloc_size_y = y_sz_upd; ++ } ++ ++ // Request descriptor sets ++ if (qx_needs_dequant) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); ++ } ++ if (qy_needs_dequant) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); ++ } ++ ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); ++ return; ++ } ++ ++ vk_buffer d_D = dst_buf_ctx->dev_buffer; ++ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; ++ GGML_ASSERT(d_D != nullptr); ++ vk_buffer d_X; ++ uint64_t x_buf_offset = 0; ++ vk_buffer d_Y; ++ uint64_t y_buf_offset = 0; ++ if(!src0_uma) { ++ d_Qx = src0_buf_ctx->dev_buffer; ++ qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; ++ GGML_ASSERT(d_Qx != nullptr); ++ } ++ if(!src1_uma) { ++ d_Qy = src1_buf_ctx->dev_buffer; ++ qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; ++ GGML_ASSERT(d_Qy != nullptr); ++ } ++ if (qx_needs_dequant) { ++ d_X = ctx->prealloc_x; ++ } else { ++ d_X = d_Qx; ++ x_buf_offset = qx_buf_offset; ++ GGML_ASSERT(qx_sz == x_sz); ++ } ++ if (qy_needs_dequant) { ++ d_Y = ctx->prealloc_y; ++ } else { ++ d_Y = d_Qy; ++ y_buf_offset = qy_buf_offset; ++ GGML_ASSERT(qy_sz == y_sz); ++ } ++ ++ if (x_non_contig) { ++ GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); ++ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); ++ } ++ if (y_non_contig) { ++ GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); ++ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); ++ } ++ ++ // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride ++ uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01; ++ uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11); ++ uint32_t stride_batch_d = batch_n ? 
ne20 : (ne20*ne21); ++ ++ if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { ++ stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); ++ } ++ ++ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { ++ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); ++ } ++ ++ const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; ++ ++ uint32_t groups_x = ne01; ++ uint32_t groups_z = 1; ++ ++ if (ne01 > max_groups_x) { ++ groups_z = 64; ++ groups_x = CEIL_DIV(groups_x, groups_z); ++ } ++ ++ // compute ++ const vk_mat_vec_push_constants pc = { ++ (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, ++ stride_batch_x, stride_batch_y, stride_batch_d, ++ (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, ++ }; ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, ++ { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, ++ sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); ++} ++ ++static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; ++ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; ++ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; ++ std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); ++ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); ++ GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT ++ GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT ++ GGML_ASSERT(src0->type == GGML_TYPE_F16); ++ GGML_ASSERT(src1->type == GGML_TYPE_F32); ++ ++ const uint64_t ne00 = src0->ne[0]; ++ const uint64_t ne01 = src0->ne[1]; ++ const uint64_t ne02 = src0->ne[2]; ++ // const uint64_t ne03 = src0->ne[3]; ++ ++ const uint64_t ne10 = src1->ne[0]; ++ const uint64_t ne11 = src1->ne[1]; ++ const uint64_t ne12 = src1->ne[2]; ++ // const uint64_t ne13 = src1->ne[3]; ++ ++ GGML_ASSERT(ne11 == 1); ++ ++ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; ++ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; ++ ++ vk_buffer d_Qy = nullptr; ++ size_t qy_buf_offset = 0; ++ ++ bool src1_uma = false; ++ ++ if (ctx->device->uma) { ++ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); ++ src1_uma = d_Qy != nullptr; ++ } ++ ++ const uint64_t x_ne = ne00 * ne01 * ne02; ++ const uint64_t y_ne = ne10 * ne11 * ne12; ++ const uint64_t d_ne = ne01 * ne11 * ne12; ++ ++ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); ++ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); ++ const uint64_t d_sz = sizeof(float) * d_ne; ++ ++ if (dryrun) { ++ // Request descriptor sets ++ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1); ++ return; ++ } ++ ++ vk_buffer d_D = dst_buf_ctx->dev_buffer; ++ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; ++ GGML_ASSERT(d_D != nullptr); ++ vk_buffer d_Qx = src0_buf_ctx->dev_buffer; ++ const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; ++ GGML_ASSERT(d_Qx != nullptr); ++ if (!src1_uma) { ++ d_Qy = src1_buf_ctx->dev_buffer; ++ qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; ++ GGML_ASSERT(d_Qx != nullptr); ++ } ++ ++ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; ++ const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; ++ ++ const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; ++ const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; ++ ++ // compute ++ const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); ++} ++ ++static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * 
src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; ++ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; ++ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; ++ std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); ++ GGML_ASSERT(!ggml_is_transposed(src0)); ++ GGML_ASSERT(!ggml_is_transposed(src1)); ++ GGML_ASSERT(!ggml_is_permuted(src0)); ++ GGML_ASSERT(src0->type == GGML_TYPE_F16); ++ GGML_ASSERT(src1->type == GGML_TYPE_F32); ++ ++ const uint64_t ne00 = src0->ne[0]; ++ const uint64_t ne01 = src0->ne[1]; ++ const uint64_t ne02 = src0->ne[2]; ++ // const uint64_t ne03 = src0->ne[3]; ++ ++ const uint64_t nb01 = src0->nb[1]; ++ const uint64_t nb02 = src0->nb[2]; ++ ++ // const uint64_t ne10 = src1->ne[0]; ++ const uint64_t ne11 = src1->ne[1]; ++ const uint64_t ne12 = src1->ne[2]; ++ // const uint64_t ne13 = src1->ne[3]; ++ ++ GGML_ASSERT(ne11 == 1); ++ ++ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; ++ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; ++ ++ vk_buffer d_Qy = nullptr; ++ size_t qy_buf_offset = 0; ++ ++ bool src1_uma = false; ++ ++ if (ctx->device->uma) { ++ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); ++ src1_uma = d_Qy != nullptr; ++ } ++ ++ const uint64_t d_ne = ne01 * ne11 * ne12; ++ ++ const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); ++ const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); ++ ++ const uint64_t qx_sz = ggml_nbytes(src0); ++ const uint64_t qy_sz = ggml_nbytes(src1); ++ const uint64_t d_sz = sizeof(float) * d_ne; ++ ++ if (dryrun) { ++ // Request descriptor sets ++ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); ++ return; ++ } ++ ++ vk_buffer d_D = dst_buf_ctx->dev_buffer; ++ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; ++ GGML_ASSERT(d_D != nullptr); ++ vk_buffer d_Qx = src0_buf_ctx->dev_buffer; ++ const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; ++ GGML_ASSERT(d_Qx != nullptr); ++ if (!src1_uma) { ++ d_Qy = src1_buf_ctx->dev_buffer; ++ qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; ++ GGML_ASSERT(d_Qx != nullptr); ++ } ++ ++ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; ++ const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; ++ ++ const uint64_t d_buffer_offset = (d_buf_offset / 
ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; ++ const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; ++ ++ // compute ++ const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, ++ { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); ++} ++ ++static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); ++ if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && ++ // detect 0213 permutation, and batch size of 1 ++ src0->nb[0] <= src0->nb[2] && ++ src0->nb[2] <= src0->nb[1] && ++ src0->nb[1] <= src0->nb[3] && ++ src1->nb[0] <= src1->nb[2] && ++ src1->nb[2] <= src1->nb[1] && ++ src1->nb[1] <= src1->nb[3] && ++ src0->ne[3] == 1 && ++ src1->ne[3] == 1) { ++ ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun); ++ } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && ++ !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { ++ ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun); ++ // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) ++ // when ne12 and ne13 are one.
++ } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && ++ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { ++ ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); ++ } else { ++ ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); ++ } ++} ++ ++static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { ++ VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; ++ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; ++ std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; ++ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); ++ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT ++ GGML_ASSERT(ids->type == GGML_TYPE_I32); ++ ++ const uint64_t ne00 = src0->ne[0]; ++ const uint64_t ne01 = src0->ne[1]; ++ const uint64_t ne02 = src0->ne[2]; ++ const uint64_t ne03 = src0->ne[3]; ++ ++ const uint64_t ne10 = src1->ne[0]; ++ const uint64_t ne11 = src1->ne[1]; ++ const uint64_t ne12 = src1->ne[2]; ++ const uint64_t ne13 = src1->ne[3]; ++ ++ const uint64_t nei0 = ids->ne[0]; ++ const uint64_t nei1 = ids->ne[1]; ++ GGML_ASSERT(nei0 * nei1 <= 3072); ++ ++ const uint32_t nbi1 = ids->nb[1]; ++ const uint32_t nbi2 = ids->nb[2]; ++ ++ const uint64_t ne20 = dst->ne[0]; ++ const uint64_t ne21 = dst->ne[1]; ++ const uint64_t ne22 = dst->ne[2]; ++ const uint64_t ne23 = dst->ne[3]; ++ ++ const uint64_t n_as = ne02; ++ ++ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; ++ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; ++ ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; ++ ++ vk_buffer d_Qx = nullptr; ++ size_t qx_buf_offset = 0; ++ vk_buffer d_Qy = nullptr; ++ size_t qy_buf_offset = 0; ++ vk_buffer d_ids = nullptr; ++ size_t ids_buf_offset = 0; ++ ++ bool src0_uma = false; ++ bool src1_uma = false; ++ bool ids_uma = false; ++ ++ if (ctx->device->uma) { ++ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); ++ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); ++ ggml_vk_host_get(ctx->device, ids->data, d_ids, 
ids_buf_offset); ++ src0_uma = d_Qx != nullptr; ++ src1_uma = d_Qy != nullptr; ++ ids_uma = d_ids != nullptr; ++ } ++ ++ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); ++ const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); ++ ++ const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; ++ ++ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); ++ ++ const bool qx_needs_dequant = mmp == nullptr || x_non_contig; ++ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; ++ ++ if (qx_needs_dequant) { ++ GGML_ABORT("fatal error"); ++ } ++ ++ // Not implemented ++ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT ++ ++ const uint64_t x_ne = ne01 * ne00; ++ const uint64_t y_ne = ne11 * ne10; ++ const uint64_t d_ne = ne21 * ne20; ++ ++ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1)); ++ const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; ++ ++ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned); ++ ++ const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); ++ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); ++ const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; ++ const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; ++ const uint64_t ids_sz = nbi2; ++ const uint64_t d_sz = sizeof(float) * d_ne; ++ ++ vk_pipeline to_fp16_vk_0 = nullptr; ++ vk_pipeline to_fp16_vk_1 = nullptr; ++ ++ if (x_non_contig) { ++ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); ++ } else { ++ to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); ++ } ++ if (y_non_contig) { ++ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); ++ } else { ++ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); ++ } ++ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT ++ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT ++ ++ if (dryrun) { ++ const uint64_t x_sz_upd = x_sz * ne02 * ne03; ++ const uint64_t y_sz_upd = y_sz * ne12 * ne13; ++ if ( ++ (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || ++ (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { ++ GGML_ABORT("Requested preallocation size is too large"); ++ } ++ if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ++ ctx->prealloc_size_x = x_sz_upd; ++ } ++ if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { ++ ctx->prealloc_size_y = y_sz_upd; ++ } ++ ++ // Request descriptor sets ++ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); ++ if (qx_needs_dequant) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); ++ } ++ if (qy_needs_dequant) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); ++ } ++ return; ++ } ++ ++ vk_buffer d_D = dst_buf_ctx->dev_buffer; ++ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; ++ GGML_ASSERT(d_D != nullptr); ++ vk_buffer d_X; ++ uint64_t x_buf_offset = 0; ++ vk_buffer d_Y; ++ uint64_t y_buf_offset = 0; ++ if (!src0_uma) { ++ d_Qx = src0_buf_ctx->dev_buffer; ++ qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; ++ GGML_ASSERT(d_Qx != nullptr); ++ } ++ if (!src1_uma) { ++ d_Qy = src1_buf_ctx->dev_buffer; ++ qy_buf_offset = 
vk_tensor_offset(src1) + src1->view_offs; ++ GGML_ASSERT(d_Qy != nullptr); ++ } ++ if (!ids_uma) { ++ d_ids = ids_buf_ctx->dev_buffer; ++ ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; ++ GGML_ASSERT(d_ids != nullptr); ++ } ++ if (qx_needs_dequant) { ++ d_X = ctx->prealloc_x; ++ GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); ++ } else { ++ d_X = d_Qx; ++ x_buf_offset = qx_buf_offset; ++ GGML_ASSERT(qx_sz == x_sz); ++ } ++ if (qy_needs_dequant) { ++ d_Y = ctx->prealloc_y; ++ GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); ++ } else { ++ d_Y = d_Qy; ++ y_buf_offset = qy_buf_offset; ++ GGML_ASSERT(qy_sz == y_sz); ++ } ++ ++ if (x_non_contig) { ++ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); ++ } else if (qx_needs_dequant) { ++ const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, ++ { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); ++ } ++ if (y_non_contig) { ++ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); ++ } ++ ++ uint32_t stride_batch_x = ne00*ne01; ++ uint32_t stride_batch_y = ne10*ne11; ++ ++ if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { ++ stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); ++ } ++ ++ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { ++ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); ++ } ++ ++ // compute ++ ggml_vk_matmul_id( ++ ctx, subctx, pipeline, ++ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, ++ { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, ++ ne01, ne21, ne10, ne10, ne10, ne01, ++ stride_batch_x, stride_batch_y, ne20*ne21, ++ n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11 ++ ); // NOLINT ++} ++ ++static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { ++ VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; ++ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; ++ std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; ++ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" 
<< dst->nb[3]; ++ std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); ++ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT ++ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT ++ GGML_ASSERT(ids->type == GGML_TYPE_I32); ++ ++ const uint64_t ne00 = src0->ne[0]; ++ const uint64_t ne01 = src0->ne[1]; ++ const uint64_t ne02 = src0->ne[2]; ++ const uint64_t ne03 = src0->ne[3]; ++ ++ const uint64_t ne10 = src1->ne[0]; ++ const uint64_t ne11 = src1->ne[1]; ++ const uint64_t ne12 = src1->ne[2]; ++ const uint64_t ne13 = src1->ne[3]; ++ ++ const uint64_t nei0 = ids->ne[0]; ++ const uint64_t nei1 = ids->ne[1]; ++ ++ const uint64_t nbi2 = ids->nb[2]; ++ ++ GGML_ASSERT(nei1 == 1); ++ ++ const uint64_t ne20 = dst->ne[0]; ++ const uint64_t ne21 = dst->ne[1]; ++ const uint64_t ne22 = dst->ne[2]; ++ const uint64_t ne23 = dst->ne[3]; ++ ++ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; ++ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; ++ ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; ++ ++ vk_buffer d_Qx = nullptr; ++ size_t qx_buf_offset = 0; ++ vk_buffer d_Qy = nullptr; ++ size_t qy_buf_offset = 0; ++ vk_buffer d_ids = nullptr; ++ size_t ids_buf_offset = 0; ++ ++ bool src0_uma = false; ++ bool src1_uma = false; ++ bool ids_uma = false; ++ ++ if (ctx->device->uma) { ++ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); ++ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); ++ ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); ++ src0_uma = d_Qx != nullptr; ++ src1_uma = d_Qy != nullptr; ++ ids_uma = d_ids != nullptr; ++ } ++ ++ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); ++ const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); ++ ++ const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; ++ ++ const bool qx_needs_dequant = x_non_contig; ++ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; ++ ++ // Not implemented ++ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT ++ ++ const uint64_t x_ne = ne01 * ne00; ++ const uint64_t y_ne = ne11 * ne10; ++ const uint64_t d_ne = ne21 * ne20; ++ ++ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); ++ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); ++ const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; ++ const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; ++ const uint64_t ids_sz = nbi2; ++ const uint64_t d_sz = sizeof(float) * d_ne; ++ ++ vk_pipeline to_fp16_vk_0 = nullptr; ++ vk_pipeline to_fp16_vk_1 = nullptr; ++ if (x_non_contig) { ++ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); ++ } ++ if (y_non_contig) { ++ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); ++ } else { ++ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); ++ } ++ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); ++ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT ++ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT ++ GGML_ASSERT(dmmv != nullptr); ++ ++ if (dryrun) { ++ const uint64_t x_sz_upd = x_sz * ne02 * ne03; ++ const uint64_t y_sz_upd = y_sz * ne12 * ne13; ++ if ( ++ (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || ++ (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { ++ GGML_ABORT("Requested preallocation size is too large"); ++ } ++ if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ++ ctx->prealloc_size_x = x_sz_upd; ++ } ++ if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { ++ ctx->prealloc_size_y = y_sz_upd; ++ } ++ ++ // Request descriptor sets ++ if (qx_needs_dequant) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); ++ } ++ if (qy_needs_dequant) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); ++ } ++ ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); ++ return; ++ } ++ ++ vk_buffer d_D = dst_buf_ctx->dev_buffer; ++ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; ++ GGML_ASSERT(d_D != nullptr); ++ vk_buffer d_X; ++ uint64_t x_buf_offset = 0; ++ vk_buffer d_Y; ++ uint64_t y_buf_offset = 0; ++ if(!src0_uma) { ++ d_Qx = src0_buf_ctx->dev_buffer; ++ qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; ++ GGML_ASSERT(d_Qx != nullptr); ++ } ++ if(!src1_uma) { ++ d_Qy = src1_buf_ctx->dev_buffer; ++ qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; ++ GGML_ASSERT(d_Qy != nullptr); ++ } ++ if(!ids_uma) { ++ d_ids = ids_buf_ctx->dev_buffer; ++ ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; ++ GGML_ASSERT(d_ids != nullptr); ++ } ++ if (qx_needs_dequant) { ++ d_X = ctx->prealloc_x; ++ } else { ++ d_X = d_Qx; ++ x_buf_offset = qx_buf_offset; ++ GGML_ASSERT(qx_sz == x_sz); ++ } ++ if (qy_needs_dequant) { ++ d_Y = ctx->prealloc_y; ++ } else { ++ d_Y = d_Qy; ++ y_buf_offset = qy_buf_offset; ++ GGML_ASSERT(qy_sz == y_sz); ++ } ++ ++ if (x_non_contig) { ++ GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); ++ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); ++ } ++ if (y_non_contig) { ++ GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); ++ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); ++ } ++ ++ uint32_t stride_batch_y = ne10*ne11; ++ ++ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { ++ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); ++ } ++ ++ const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; ++ ++ uint32_t groups_x = ne01; ++ uint32_t groups_z = 1; ++ ++ if (ne01 > max_groups_x) { ++ groups_z = 64; ++ groups_x = 
CEIL_DIV(groups_x, groups_z); ++ } ++ ++ // compute ++ const vk_mat_vec_id_push_constants pc = { ++ (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, ++ (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), ++ (uint32_t)nei0, (uint32_t)ne11, ++ }; ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, ++ { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, ++ vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, ++ sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z }); ++} ++ ++static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { ++ VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")"); ++ if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { ++ ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); ++ } else { ++ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); ++ } ++} ++ ++static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { ++ VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; ++ std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; ++ std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; ++ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; ++ std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); ++ ++ GGML_TENSOR_LOCALS(int64_t, neq, q, ne) ++ GGML_TENSOR_LOCALS(size_t, nbq, q, nb) ++ GGML_TENSOR_LOCALS(int64_t, nek, k, ne) ++ GGML_TENSOR_LOCALS(size_t, nbk, k, nb) ++ GGML_TENSOR_LOCALS(int64_t, nev, v, ne) ++ GGML_TENSOR_LOCALS(size_t, nbv, v, nb) ++ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) ++ GGML_TENSOR_LOCALS(size_t, nb, dst, nb) ++ ++ const uint32_t nem1 = mask ? mask->ne[1] : 0; ++ const uint32_t nbm1 = mask ? 
mask->nb[1] : 0; ++ ++ const uint32_t D = neq0; ++ const uint32_t N = neq1; ++ const uint32_t KV = nek1; ++ ++ GGML_ASSERT(ne0 == D); ++ GGML_ASSERT(ne2 == N); ++ ++ // input tensor rows must be contiguous ++ GGML_ASSERT(nbq0 == ggml_type_size(q->type)); ++ GGML_ASSERT(nbk0 == ggml_type_size(k->type)); ++ GGML_ASSERT(nbv0 == ggml_type_size(v->type)); ++ ++ GGML_ASSERT(neq0 == D); ++ GGML_ASSERT(nek0 == D); ++ GGML_ASSERT(nev0 == D); ++ ++ GGML_ASSERT(neq1 == N); ++ GGML_ASSERT(nev0 == D); ++ ++ GGML_ASSERT(nev1 == nek1); ++ ++ // dst cannot be transposed or permuted ++ GGML_ASSERT(nb0 == sizeof(float)); ++ GGML_ASSERT(nb0 <= nb1); ++ GGML_ASSERT(nb1 <= nb2); ++ GGML_ASSERT(nb2 <= nb3); ++ ++ assert(dst->type == GGML_TYPE_F32); ++ assert(q->type == GGML_TYPE_F32); ++ assert(k->type == v->type); ++ ++ vk_pipeline *pipelines; ++ // XXX TODO other backends may be changing accumulator precision to default to f32 soon ++ bool f32acc = dst->op_params[3] == GGML_PREC_F32; ++ bool small_rows = N <= flash_attention_num_small_rows; ++ switch (D) { ++ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; ++ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; ++ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; ++ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; ++ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; ++ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; ++ default: ++ assert(!"unsupported D value"); ++ return; ++ } ++ assert(pipelines); ++ ++ bool aligned = (KV % pipelines[1]->align) == 0; ++ vk_pipeline pipeline = pipelines[aligned]; ++ assert(pipeline); ++ ++ if (dryrun) { ++ // Request descriptor sets ++ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); ++ return; ++ } ++ ++ float scale = 1.0f; ++ float max_bias = 0.0f; ++ float logit_softcap = 0.0f; ++ ++ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); ++ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); ++ memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); ++ ++ if (logit_softcap != 0) { ++ scale /= logit_softcap; ++ } ++ ++ const uint32_t n_head_kv = neq2; ++ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); ++ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); ++ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); ++ ++ ggml_vk_sync_buffers(subctx); ++ ++ vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; ++ size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; ++ ++ bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false; ++ ++ if (ctx->device->uma) { ++ ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); ++ ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset); ++ ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset); ++ ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset); ++ Q_uma = d_Q != nullptr; ++ K_uma = d_K != nullptr; ++ V_uma = d_V != nullptr; ++ D_uma = d_D != nullptr; ++ if (mask) { ++ ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset); ++ M_uma = d_M != nullptr; ++ } ++ } ++ ++ ++ ggml_backend_vk_buffer_context * d_buf_ctx = 
(ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context; ++ ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; ++ ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; ++ ++ if (!Q_uma) { ++ d_Q = q_buf_ctx->dev_buffer; ++ q_buf_offset = vk_tensor_offset(q) + q->view_offs; ++ } ++ if (!K_uma) { ++ d_K = k_buf_ctx->dev_buffer; ++ k_buf_offset = vk_tensor_offset(k) + k->view_offs; ++ } ++ if (!V_uma) { ++ d_V = v_buf_ctx->dev_buffer; ++ v_buf_offset = vk_tensor_offset(v) + v->view_offs; ++ } ++ if (!D_uma) { ++ d_D = d_buf_ctx->dev_buffer; ++ d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; ++ } ++ ++ if (!M_uma) { ++ d_M = d_Q; ++ m_buf_offset = q_buf_offset; ++ if (mask) { ++ ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context; ++ d_M = m_buf_ctx->dev_buffer; ++ m_buf_offset = vk_tensor_offset(mask) + mask->view_offs; ++ } ++ } ++ ++ const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 }; ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, ++ { ++ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, ++ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, ++ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, ++ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, ++ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, ++ }, ++ sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); ++} ++ ++static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { ++ switch (op) { ++ case GGML_OP_GET_ROWS: ++ GGML_ASSERT(src1->type == GGML_TYPE_I32); ++ if (dst->type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_get_rows[src0->type]; ++ } ++ if (dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_get_rows_f32[src0->type]; ++ } ++ return nullptr; ++ case GGML_OP_ACC: ++ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_acc_f32; ++ } ++ return nullptr; ++ case GGML_OP_ADD: ++ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32; ++ } ++ if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { ++ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; ++ } ++ return nullptr; ++ case GGML_OP_MUL: ++ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; ++ } ++ return nullptr; ++ case GGML_OP_DIV: ++ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32; ++ } ++ return nullptr; ++ case GGML_OP_CONCAT: ++ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_concat_f32; ++ } ++ if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_concat_f16; ++ } ++ if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { ++ return ctx->device->pipeline_concat_i32; ++ } ++ return nullptr; ++ case GGML_OP_UPSCALE: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_upscale_f32; ++ } ++ return nullptr; ++ case GGML_OP_SCALE: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_scale_f32; ++ } ++ return nullptr; ++ case GGML_OP_SQR: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_sqr_f32; ++ } ++ return nullptr; ++ case GGML_OP_SIN: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_sin_f32; ++ } ++ return nullptr; ++ case GGML_OP_COS: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_cos_f32; ++ } ++ return nullptr; ++ case GGML_OP_CLAMP: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_clamp_f32; ++ } ++ return nullptr; ++ case GGML_OP_PAD: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_pad_f32; ++ } ++ return nullptr; ++ case GGML_OP_REPEAT: ++ if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { ++ return ctx->device->pipeline_repeat_f32; ++ } ++ return nullptr; ++ case GGML_OP_CPY: ++ case GGML_OP_CONT: ++ case GGML_OP_DUP: ++ return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); ++ case GGML_OP_NORM: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_norm_f32; ++ } ++ return nullptr; ++ case GGML_OP_GROUP_NORM: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_group_norm_f32; ++ } ++ return nullptr; ++ case GGML_OP_RMS_NORM: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_rms_norm_f32; ++ } ++ return nullptr; ++ case GGML_OP_UNARY: ++ switch (ggml_get_unary_op(dst)) { ++ case GGML_UNARY_OP_SILU: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_silu_f32; ++ } ++ break; ++ case GGML_UNARY_OP_GELU: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_gelu_f32; ++ } ++ break; ++ case GGML_UNARY_OP_GELU_QUICK: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_gelu_quick_f32; ++ } ++ break; ++ case GGML_UNARY_OP_RELU: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_relu_f32; ++ } ++ break; ++ case GGML_UNARY_OP_TANH: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_tanh_f32; ++ } ++ break; ++ default: ++ break; ++ } ++ return nullptr; ++ case GGML_OP_DIAG_MASK_INF: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_diag_mask_inf_f32; ++ } ++ return nullptr; ++ case 
GGML_OP_SOFT_MAX: ++ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); ++ ++ if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { ++ return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; ++ } ++ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { ++ return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; ++ } ++ return nullptr; ++ case GGML_OP_ROPE: ++ { ++ const int mode = ((const int32_t *) dst->op_params)[2]; ++ const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; ++ ++ if (is_neox) { ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_rope_neox_f32; ++ } ++ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_rope_neox_f16; ++ } ++ } else { ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_rope_norm_f32; ++ } ++ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_rope_norm_f16; ++ } ++ } ++ return nullptr; ++ } ++ case GGML_OP_ARGSORT: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { ++ return ctx->device->pipeline_argsort_f32; ++ } ++ return nullptr; ++ case GGML_OP_SUM_ROWS: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_sum_rows_f32; ++ } ++ return nullptr; ++ case GGML_OP_IM2COL: ++ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_im2col_f32; ++ } ++ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { ++ return ctx->device->pipeline_im2col_f32_f16; ++ } ++ return nullptr; ++ case GGML_OP_TIMESTEP_EMBEDDING: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_timestep_embedding_f32; ++ } ++ return nullptr; ++ case GGML_OP_POOL_2D: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_pool2d_f32; ++ } ++ return nullptr; ++ case GGML_OP_RWKV_WKV6: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_rwkv_wkv6_f32; ++ } ++ return nullptr; ++ case GGML_OP_LEAKY_RELU: ++ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { ++ return ctx->device->pipeline_leaky_relu_f32; ++ } ++ return nullptr; ++ default: ++ return nullptr; ++ } ++ ++ GGML_UNUSED(src2); ++} ++ ++static bool ggml_vk_op_supports_incontiguous(ggml_op op) { ++ switch (op) { ++ case GGML_OP_CPY: ++ case GGML_OP_GET_ROWS: ++ case GGML_OP_ADD: ++ case GGML_OP_MUL: ++ case GGML_OP_DIV: ++ case GGML_OP_CONCAT: ++ case GGML_OP_UPSCALE: ++ case GGML_OP_SQR: ++ case GGML_OP_SIN: ++ case GGML_OP_COS: ++ case GGML_OP_CLAMP: ++ case GGML_OP_PAD: ++ case GGML_OP_REPEAT: ++ return true; ++ default: ++ return false; ++ } ++} ++ ++static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t) ++{ ++ return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));; ++} ++ ++template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++ GGML_UNUSED(p); ++ GGML_UNUSED(src0); ++ GGML_UNUSED(src1); ++ GGML_UNUSED(src2); ++ GGML_UNUSED(dst);
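++ // NOTE: this generic overload covers push-constant types that carry no misalignment fields; it only ++ // asserts that every tensor offset is already aligned to minStorageBufferOffsetAlignment. The ++ // specializations below instead pack the per-tensor misalignment (in elements) into the push constants.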
++ static_assert(!std::is_const<T>::value, "unexpected type"); ++ GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); ++ GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); ++ GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0); ++ GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); ++} ++ ++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); ++ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); ++ ++ p.misalign_offsets = (a_offset << 16) | d_offset; ++ ++ GGML_UNUSED(src1); ++ GGML_UNUSED(src2); ++} ++ ++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); ++ const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); ++ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); ++ ++ GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0)); ++ ++ p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; ++ ++ GGML_UNUSED(src2); ++} ++ ++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { ++ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); ++ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); ++ ++ p.a_offset = a_offset; ++ p.d_offset = d_offset; ++ ++ GGML_UNUSED(src1); ++ GGML_UNUSED(src2); ++} ++ ++template <typename PC> ++static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { ++ VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; ++ if (src1 != nullptr) { ++ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; ++ } ++ if (src2 != nullptr) { ++ std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; ++ } ++ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; ++ std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ?
"dryrun" : "") << ")"); ++ GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT ++ GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT ++ GGML_ASSERT(dst->buffer != nullptr); ++ const uint64_t ne00 = src0->ne[0]; ++ const uint64_t ne01 = src0->ne[1]; ++ const uint64_t ne02 = src0->ne[2]; ++ const uint64_t ne03 = src0->ne[3]; ++ const uint64_t ne0 = ne00 * ne01; ++ ++ const bool use_src1 = src1 != nullptr; ++ const uint64_t ne10 = use_src1 ? src1->ne[0] : 0; ++ const uint64_t ne11 = use_src1 ? src1->ne[1] : 0; ++ const uint64_t ne12 = use_src1 ? src1->ne[2] : 0; ++ const uint64_t ne13 = use_src1 ? src1->ne[3] : 0; ++ const uint64_t ne1 = ne10 * ne11; ++ // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0; ++ ++ const bool use_src2 = src2 != nullptr; ++ const uint64_t ne20 = use_src2 ? src2->ne[0] : 0; ++ const uint64_t ne21 = use_src2 ? src2->ne[1] : 0; ++ const uint64_t ne22 = use_src2 ? src2->ne[2] : 0; ++ const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; ++ const uint64_t ne2 = ne20 * ne21; ++ ++ const uint64_t ned0 = dst->ne[0]; ++ const uint64_t ned1 = dst->ne[1]; ++ const uint64_t ned2 = dst->ne[2]; ++ const uint64_t ned3 = dst->ne[3]; ++ const uint64_t ned = ned0 * ned1; ++ ++ init_pushconst_fastdiv(pc); ++ ++ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); ++ ++ if (pipeline == nullptr) { ++ std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type); ++ if (src1 != nullptr) { ++ std::cerr << " and " << ggml_type_name(src1->type); ++ } ++ std::cerr << " to " << ggml_type_name(dst->type) << std::endl; ++ GGML_ABORT("fatal error"); ++ } ++ ++ if (dryrun) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); ++ return; ++ } ++ ++ const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); ++ ++ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; ++ ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; ++ ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; ++ ++ vk_buffer d_X = nullptr; ++ size_t x_buf_offset = 0; ++ vk_buffer d_Y = nullptr; ++ size_t y_buf_offset = 0; ++ vk_buffer d_Z = nullptr; ++ size_t z_buf_offset = 0; ++ ++ bool src0_uma = false; ++ bool src1_uma = false; ++ bool src2_uma = false; ++ ++ if (ctx->device->uma) { ++ ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset); ++ src0_uma = d_X != nullptr; ++ if (use_src1) { ++ ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset); ++ src1_uma = d_Y != nullptr; ++ } ++ if (use_src2) { ++ ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset); ++ src2_uma = d_Z != nullptr; ++ } ++ } ++ ++ uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0; ++ uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0; ++ uint64_t z_sz = use_src2 ? 
ggml_type_size(src2->type) * ne2 : 0; ++ uint64_t d_sz = ggml_type_size(dst->type) * ned; ++ ++ vk_buffer d_D = dst_buf_ctx->dev_buffer; ++ ++ // Workaround for tiny tensor inputs on ROPE ++ if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) { ++ y_sz = VK_WHOLE_SIZE; ++ } ++ ++ GGML_ASSERT(d_D != nullptr); ++ uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; ++ if(!src0_uma) { ++ d_X = src0_buf_ctx->dev_buffer; ++ x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; ++ GGML_ASSERT(d_X != nullptr); ++ } ++ if (use_src1 && !src1_uma) { ++ d_Y = src1_buf_ctx->dev_buffer; ++ y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; ++ GGML_ASSERT(d_Y != nullptr); ++ } ++ if (use_src2 && !src2_uma) { ++ d_Z = src2_buf_ctx->dev_buffer; ++ z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; ++ GGML_ASSERT(d_Z != nullptr); ++ } ++ // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets. ++ init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst); ++ x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); ++ y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); ++ z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); ++ d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); ++ ++ if (op_supports_incontiguous) { ++ x_sz = ggml_nbytes(src0); ++ y_sz = use_src1 ? ggml_nbytes(src1) : 0; ++ z_sz = use_src2 ? ggml_nbytes(src2) : 0; ++ d_sz = ggml_nbytes(dst); ++ ++ if (x_buf_offset + x_sz >= d_X->size) { ++ x_sz = VK_WHOLE_SIZE; ++ } ++ if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { ++ y_sz = VK_WHOLE_SIZE; ++ } ++ if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { ++ z_sz = VK_WHOLE_SIZE; ++ } ++ if (d_buf_offset + d_sz >= d_D->size) { ++ d_sz = VK_WHOLE_SIZE; ++ } ++ } ++ ++ std::array elements; ++ ++ // Single call if dimension 2 is contiguous ++ GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); ++ ++ switch (op) { ++ case GGML_OP_NORM: ++ case GGML_OP_RMS_NORM: ++ case GGML_OP_SOFT_MAX: ++ case GGML_OP_SUM_ROWS: ++ { ++ const uint32_t nr = ggml_nrows(src0); ++ if (nr > 262144) { ++ elements = { 512, 512, CEIL_DIV(nr, 262144) }; ++ } else if (nr > 512) { ++ elements = { 512, CEIL_DIV(nr, 512), 1 }; ++ } else { ++ elements = { nr, 1, 1 }; ++ } ++ } break; ++ case GGML_OP_GROUP_NORM: ++ { ++ const uint32_t num_groups = dst->op_params[0]; ++ elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; ++ } break; ++ case GGML_OP_DIAG_MASK_INF: ++ case GGML_OP_ROPE: ++ elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; ++ break; ++ case GGML_OP_GET_ROWS: ++ elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; ++ break; ++ case GGML_OP_ARGSORT: ++ elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; ++ break; ++ case GGML_OP_IM2COL: ++ { ++ const bool is_2D = dst->op_params[6] == 1; ++ ++ const uint32_t IC = src1->ne[is_2D ? 2 : 1]; ++ ++ const uint32_t KH = is_2D ? src0->ne[1] : 1; ++ const uint32_t KW = src0->ne[0]; ++ ++ const uint32_t OH = is_2D ? dst->ne[2] : 1; ++ const uint32_t OW = dst->ne[1]; ++ ++ const uint32_t batch = src1->ne[is_2D ? 
3 : 2]; ++ ++ elements = { OW * KW * KH, OH, batch * IC }; ++ } break; ++ case GGML_OP_TIMESTEP_EMBEDDING: ++ { ++ const uint32_t dim = dst->op_params[0]; ++ uint32_t half_ceil = (dim + 1) / 2; ++ elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; ++ } break; ++ case GGML_OP_POOL_2D: ++ { ++ const uint32_t N = dst->ne[3]; ++ const uint32_t OC = dst->ne[2]; ++ const uint32_t OH = dst->ne[1]; ++ const uint32_t OW = dst->ne[0]; ++ elements = { N * OC * OH * OW, 1, 1}; ++ } break; ++ case GGML_OP_ADD: ++ case GGML_OP_DIV: ++ case GGML_OP_MUL: ++ case GGML_OP_SCALE: ++ case GGML_OP_SQR: ++ case GGML_OP_SIN: ++ case GGML_OP_COS: ++ case GGML_OP_CLAMP: ++ case GGML_OP_PAD: ++ case GGML_OP_REPEAT: ++ case GGML_OP_CPY: ++ case GGML_OP_CONCAT: ++ case GGML_OP_UPSCALE: ++ case GGML_OP_UNARY: ++ { ++ const uint32_t ne = ggml_nelements(dst); ++ if (ne > 262144) { ++ elements = { 512, 512, CEIL_DIV(ne, 262144) }; ++ } else if (ne > 512) { ++ elements = { 512, CEIL_DIV(ne, 512), 1 }; ++ } else { ++ elements = { ne, 1, 1 }; ++ } ++ } break; ++ default: ++ elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; ++ break; ++ } ++ ++ if (!op_supports_incontiguous) { ++ if (x_sz != VK_WHOLE_SIZE) { ++ x_sz *= ne02 * ne03; ++ } ++ if (use_src1 && y_sz != VK_WHOLE_SIZE) { ++ y_sz *= ne12 * ne13; ++ } ++ if (use_src2 && z_sz != VK_WHOLE_SIZE) { ++ z_sz *= ne22 * ne23; ++ } ++ if (d_sz != VK_WHOLE_SIZE) { ++ d_sz *= ned2 * ned3; ++ } ++ } ++ ++ if (op == GGML_OP_SOFT_MAX) { ++ // Empty src1 is possible in soft_max, but the shader needs a buffer ++ vk_subbuffer subbuf_y; ++ if (use_src1) { ++ subbuf_y = { d_Y, y_buf_offset, y_sz }; ++ } else { ++ subbuf_y = { d_X, 0, x_sz }; ++ } ++ ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); ++ } else if (op == GGML_OP_ROPE) { ++ // Empty src2 is possible in rope, but the shader needs a buffer ++ vk_subbuffer subbuf_z; ++ if (use_src2) { ++ subbuf_z = { d_Z, z_buf_offset, z_sz }; ++ } else { ++ subbuf_z = { d_X, 0, x_sz }; ++ } ++ ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); ++ } else if (op == GGML_OP_IM2COL) { ++ // im2col uses only src1 and dst buffers ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); ++ } else if (use_src2) { ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); ++ } else if (use_src1) { ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); ++ } else { ++ ggml_vk_sync_buffers(subctx); ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); ++ } ++} ++ ++static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const 
ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t src1_type_size = ggml_type_size(src1->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t src1_type_size = ggml_type_size(src1->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 ++ int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 ++ // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused ++ int offset = dst->op_params[3] / 4; // offset in bytes ++ ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, offset, ++ }, dryrun); ++} ++ ++static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t src1_type_size = ggml_type_size(src1->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / 
src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t src1_type_size = ggml_type_size(src1->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t src1_type_size = ggml_type_size(src1->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) { ++ const ggml_tensor * k = dst->src[0]; ++ const ggml_tensor * v = dst->src[1]; ++ const ggml_tensor * r = dst->src[2]; ++ const ggml_tensor * tf = dst->src[3]; ++ const ggml_tensor * td = dst->src[4]; ++ const ggml_tensor * state = dst->src[5]; ++ ++ GGML_ASSERT(!ggml_is_quantized(k->type)); ++ GGML_ASSERT(!ggml_is_quantized(v->type)); ++ GGML_ASSERT(!ggml_is_quantized(r->type)); ++ GGML_ASSERT(!ggml_is_quantized(tf->type)); ++ GGML_ASSERT(!ggml_is_quantized(td->type)); ++ 
GGML_ASSERT(!ggml_is_quantized(state->type)); ++ GGML_ASSERT(dst->buffer != nullptr); ++ ++ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); ++ GGML_ASSERT(pipeline != nullptr); ++ ++ if (dryrun) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); ++ return; ++ } ++ ++ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; ++ ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; ++ ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; ++ ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; ++ ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; ++ ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; ++ ++ ggml_vk_sync_buffers(subctx); ++ ++ vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr; ++ size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0; ++ bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; ++ ++ if (ctx->device->uma) { ++ ggml_vk_host_get(ctx->device, k->data, d_K, k_offset); ++ ggml_vk_host_get(ctx->device, v->data, d_V, v_offset); ++ ggml_vk_host_get(ctx->device, r->data, d_R, r_offset); ++ ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset); ++ ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset); ++ ggml_vk_host_get(ctx->device, state->data, d_State, state_offset); ++ ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); ++ ++ K_uma = d_K != nullptr; ++ V_uma = d_V != nullptr; ++ R_uma = d_R != nullptr; ++ TF_uma = d_TF != nullptr; ++ TD_uma = d_TD != nullptr; ++ STATE_uma = d_State != nullptr; ++ DST_uma = d_D != nullptr; ++ } ++ ++ if (!K_uma) { ++ d_K = k_buf_ctx->dev_buffer; ++ k_offset = vk_tensor_offset(k) + k->view_offs; ++ } ++ if (!V_uma) { ++ d_V = v_buf_ctx->dev_buffer; ++ v_offset = vk_tensor_offset(v) + v->view_offs; ++ } ++ if (!R_uma) { ++ d_R = r_buf_ctx->dev_buffer; ++ r_offset = vk_tensor_offset(r) + r->view_offs; ++ } ++ if (!TF_uma) { ++ d_TF = tf_buf_ctx->dev_buffer; ++ tf_offset = vk_tensor_offset(tf) + tf->view_offs; ++ } ++ if (!TD_uma) { ++ d_TD = td_buf_ctx->dev_buffer; ++ td_offset = vk_tensor_offset(td) + td->view_offs; ++ } ++ if (!STATE_uma) { ++ d_State = state_buf_ctx->dev_buffer; ++ state_offset = vk_tensor_offset(state) + state->view_offs; ++ } ++ if (!DST_uma) { ++ d_D = dst_buf_ctx->dev_buffer; ++ dst_offset = vk_tensor_offset(dst) + dst->view_offs; ++ } ++ ++ const uint64_t k_size = ggml_nbytes(k); ++ const uint64_t v_size = ggml_nbytes(v); ++ const uint64_t r_size = ggml_nbytes(r); ++ const uint64_t tf_size = ggml_nbytes(tf); ++ const uint64_t td_size = ggml_nbytes(td); ++ const uint64_t state_size = ggml_nbytes(state); ++ const uint64_t dst_size = ggml_nbytes(dst); ++ ++ std::array elements = { ++ (uint32_t)(pc.B * pc.H), ++ 1, ++ 1 ++ }; ++ ++ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { ++ vk_subbuffer{ d_K, k_offset, k_size }, ++ vk_subbuffer{ d_V, v_offset, v_size }, ++ vk_subbuffer{ d_R, r_offset, r_size }, ++ vk_subbuffer{ d_TF, tf_offset, tf_size }, ++ vk_subbuffer{ d_TD, td_offset, 
td_size }, ++ vk_subbuffer{ d_State, state_offset, state_size }, ++ vk_subbuffer{ d_D, dst_offset, dst_size } ++ }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); ++} ++ ++static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { ++ const size_t seq_length = dst->src[0]->ne[3]; ++ const size_t n_embed = dst->ne[0]; ++ const size_t n_heads = dst->src[0]->ne[2]; ++ const size_t n_seqs = dst->src[5]->ne[1]; ++ ++ ggml_vk_op_f32_rwkv6( ++ ctx, subctx, dst, ++ { ++ (uint32_t)n_seqs, ++ (uint32_t)seq_length, ++ (uint32_t)n_embed, ++ (uint32_t)n_heads, ++ }, ++ dryrun ++ ); ++} ++ ++static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ int * op_params = (int *)dst->op_params; ++ ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t src1_type_size = ggml_type_size(src1->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, { ++ (uint32_t)ggml_nelements(dst), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, op_params[0], ++ }, dryrun); ++} ++ ++static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ ++ const float sf0 = (float)dst->ne[0] / src0->ne[0]; ++ const float sf1 = (float)dst->ne[1] / src0->ne[1]; ++ const float sf2 = (float)dst->ne[2] / src0->ne[2]; ++ const float sf3 = (float)dst->ne[3] / src0->ne[3]; ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { ++ (uint32_t)ggml_nelements(dst), 0, 0, ++ (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], ++ sf0, sf1, sf2, sf3, ++ }, dryrun); ++} ++ ++static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ float * op_params = (float *)dst->op_params; ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t) 
dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ op_params[0], 0.0f, ++ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, ++ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, ++ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, ++ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ float * op_params = (float *)dst->op_params; ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t dst_type_size = 
ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ op_params[0], op_params[1], ++ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, { ++ (uint32_t)ggml_nelements(dst), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, ++ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, { ++ (uint32_t)ggml_nelements(dst), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, ++ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t src0_type_size = ggml_type_size(src0->type); ++ const uint32_t dst_type_size = ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { ++ (uint32_t)ggml_nelements(src0), ++ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, ++ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) 
dst->nb[3] / dst_type_size, ++ 0, ++ 0.0f, 0.0f, ++ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ++ }, dryrun); ++} ++ ++static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ float * op_params = (float *)dst->op_params; ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); ++} ++ ++static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ const int * int_op_params = (const int *)dst->op_params; ++ const float * float_op_params = (const float *)dst->op_params; ++ ++ const uint32_t num_groups = int_op_params[0]; ++ const float eps = float_op_params[1]; ++ const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); ++} ++ ++static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ float * op_params = (float *)dst->op_params; ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); ++} ++ ++static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); ++} ++ ++static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ int32_t * op_params = (int32_t *)dst->op_params; ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); ++} ++ ++static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ float * op_params = (float *)dst->op_params; ++ ++ float scale = op_params[0]; ++ float max_bias = op_params[1]; ++ ++ const uint32_t ncols = (uint32_t)src0->ne[0]; ++ const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); ++ const uint32_t nrows_y = (uint32_t)src0->ne[1]; ++ ++ const uint32_t n_head_kv = nrows_x/nrows_y; ++ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); ++ ++ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); ++ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { ++ ncols, ++ src1 != nullptr ? 
nrows_y : (uint32_t)0, ++ scale, max_bias, ++ m0, m1, ++ n_head_log2, ++ nrows_x, ++ }, dryrun); ++} ++ ++static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { ++ const int n_dims = ((int32_t *) dst->op_params)[1]; ++ // const int mode = ((int32_t *) dst->op_params)[2]; ++ // const int n_ctx = ((int32_t *) dst->op_params)[3]; ++ const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; ++ const float freq_base = ((float *) dst->op_params)[5]; ++ const float freq_scale = ((float *) dst->op_params)[6]; ++ const float ext_factor = ((float *) dst->op_params)[7]; ++ const float attn_factor = ((float *) dst->op_params)[8]; ++ const float beta_fast = ((float *) dst->op_params)[9]; ++ const float beta_slow = ((float *) dst->op_params)[10]; ++ ++ float corr_dims[2]; ++ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); ++ ++ const float theta_scale = powf(freq_base, -2.0f/n_dims); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { ++ (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], ++ freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, ++ src2 != nullptr, ++ }, dryrun); ++} ++ ++static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ int32_t * op_params = (int32_t *)dst->op_params; ++ ++ uint32_t ncols = src0->ne[0]; ++ ++ uint32_t ncols_pad = 1; ++ while (ncols_pad < ncols) { ++ ncols_pad *= 2; ++ } ++ ++ GGML_ASSERT(ncols_pad <= 1024); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { ++ ncols, ++ ncols_pad, ++ op_params[0], ++ }, dryrun); ++} ++ ++static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); ++} ++ ++static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { ++ const int32_t s0 = dst->op_params[0]; ++ const int32_t s1 = dst->op_params[1]; ++ const int32_t p0 = dst->op_params[2]; ++ const int32_t p1 = dst->op_params[3]; ++ const int32_t d0 = dst->op_params[4]; ++ const int32_t d1 = dst->op_params[5]; ++ ++ const bool is_2D = dst->op_params[6] == 1; ++ ++ const uint32_t IC = src1->ne[is_2D ? 2 : 1]; ++ const uint32_t IH = is_2D ? src1->ne[1] : 1; ++ const uint32_t IW = src1->ne[0]; ++ ++ const uint32_t KH = is_2D ? src0->ne[1] : 1; ++ const uint32_t KW = src0->ne[0]; ++ ++ const uint32_t OH = is_2D ? dst->ne[2] : 1; ++ const uint32_t OW = dst->ne[1]; ++ ++ const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 ++ const uint32_t batch_offset = src1->nb[is_2D ? 
3 : 2] / 4; // nb is byte offset, src is type float32 ++ ++ const uint32_t pelements = OW * KW * KH; ++ ++ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { ++ batch_offset, offset_delta, ++ IC, IW, IH, OW, OH, KW, KH, ++ pelements, ++ IC * KH * KW, ++ s0, s1, p0, p1, d0, d1, ++ }, dryrun); ++} ++ ++static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ const uint32_t dim = dst->op_params[0]; ++ const uint32_t max_period = dst->op_params[1]; ++ const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { ++ nb1, dim, max_period, ++ }, dryrun); ++} ++ ++static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ uint32_t op = static_cast(dst->op_params[0]); ++ const int32_t k1 = dst->op_params[1]; ++ const int32_t k0 = dst->op_params[2]; ++ const int32_t s1 = dst->op_params[3]; ++ const int32_t s0 = dst->op_params[4]; ++ const int32_t p1 = dst->op_params[5]; ++ const int32_t p0 = dst->op_params[6]; ++ ++ const uint32_t IH = src0->ne[1]; ++ const uint32_t IW = src0->ne[0]; ++ ++ const uint32_t N = dst->ne[3]; ++ ++ const uint32_t OC = dst->ne[2]; ++ const uint32_t OH = dst->ne[1]; ++ const uint32_t OW = dst->ne[0]; ++ ++ const uint32_t parallel_elements = N * OC * OH * OW; ++ ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, { ++ IW, IH, OW, OH, OC, ++ parallel_elements, ++ op, ++ k0, k1, s0, s1, p0, p1, ++ }, dryrun); ++} ++ ++static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ++ const float * op_params = (const float *)dst->op_params; ++ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); ++} ++ ++#ifdef GGML_VULKAN_RUN_TESTS ++static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) { ++ if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) { ++ return; ++ } ++ i0 = std::max(i0, 5); ++ i1 = std::max(i1, 5); ++ i2 = std::max(i2, 0); ++ fprintf(stderr, " "); ++ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { ++ fprintf(stderr, "%7d ", idx1); ++ } ++ fprintf(stderr, "\n"); ++ for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { ++ fprintf(stderr, "%7d: ", idx0); ++ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { ++ if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) { ++ float val; ++ if (type == GGML_TYPE_F32) { ++ val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0); ++ } else if (type == GGML_TYPE_F16) { ++ val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0)); ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ fprintf(stderr, "% 7.2f ", val); ++ } else { ++ fprintf(stderr, " "); ++ } ++ } ++ fprintf(stderr, "\n"); ++ } ++} ++ ++template ++static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) { ++ VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")"); ++ const size_t x_ne = m * k * batch; ++ const size_t y_ne = k * n * batch; ++ const size_t d_ne = m * n * batch; ++ ++ vk_pipeline 
p; ++ std::string shname; ++ if (shader_size == 0) { ++ if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32->a_s; ++ shname = "F32_ALIGNED_S"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32_f16->a_s; ++ shname = "F32_F16_ALIGNED_S"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s; ++ shname = "F16_F32_ALIGNED_S"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16.f32acc->a_s; ++ shname = "F16_ALIGNED_S"; ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ } else if (shader_size == 1) { ++ if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32->a_m; ++ shname = "F32_ALIGNED_M"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32_f16->a_m; ++ shname = "F32_F16_ALIGNED_M"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m; ++ shname = "F16_F32_ALIGNED_M"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16.f32acc->a_m; ++ shname = "F16_ALIGNED_M"; ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ } else if (shader_size == 2) { ++ if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32->a_l; ++ shname = "F32_ALIGNED_L"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32_f16->a_l; ++ shname = "F32_F16_ALIGNED_L"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l; ++ shname = "F16_F32_ALIGNED_L"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16.f32acc->a_l; ++ shname = "F16_ALIGNED_L"; ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ } else { ++ GGML_ASSERT(0); ++ } ++ ++ const size_t kpad = ggml_vk_align_size(k, p->align); ++ ++ if (k != kpad) { ++ if (shader_size == 0) { ++ if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32->s; ++ shname = "F32_S"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32_f16->s; ++ shname = "F32_F16_S"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16_f32.f32acc->s; ++ shname = "F16_F32_S"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16.f32acc->s; ++ shname = "F16_S"; ++ } ++ } else if (shader_size == 1) { ++ if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32->m; ++ shname = "F32_M"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32_f16->m; ++ shname = "F32_F16_M"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16_f32.f32acc->m; ++ shname = "F16_F32_M"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16.f32acc->m; ++ shname = "F16_M"; ++ } ++ } else if (shader_size == 2) { ++ if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32->l; ++ shname = "F32_L"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f32_f16->l; ++ shname = "F32_F16_L"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16_f32.f32acc->l; ++ shname = "F16_F32_L"; ++ } else if (std::is_same() && std::is_same()) { ++ p = ctx->device->pipeline_matmul_f16.f32acc->l; ++ shname = "F16_L"; ++ } ++ } ++ } ++ ++ 
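++    // Note: the aligned (a_s/a_m/a_l) shader variants selected above are only valid when k is a
++    // multiple of the pipeline's alignment; the kpad check falls back to the unaligned s/m/l
++    // variants otherwise. Each benchmark iteration needs its own descriptor set, and split_k > 1
++    // additionally needs the split_k reduce pipeline plus a float scratch buffer of d_ne * split_k.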
ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); ++ if (split_k > 1) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); ++ ++ if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { ++ // Resize buffer ++ if (ctx->prealloc_split_k != nullptr) { ++ ggml_vk_destroy_buffer(ctx->prealloc_split_k); ++ } ++ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ } ++ } ++ ++ ggml_pipeline_allocate_descriptor_sets(ctx->device); ++ ++ vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ ++ X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne); ++ Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne); ++ float* d = (float *) malloc(sizeof(float) * d_ne); ++ ++ for (size_t i = 0; i < x_ne; i++) { ++ if (std::is_same()) { ++ x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; ++ // x[i] = 1.0f; ++ // x[i] = i + 1; ++ // x[i] = (i % k == i / k) ? 1.0f : 0.0f; ++ } else if (std::is_same()) { ++ x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); ++ // x[i] = ggml_fp32_to_fp16(1.0f); ++ // x[i] = ggml_fp32_to_fp16(i + 1); ++ // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ } ++ for (size_t i = 0; i < y_ne; i++) { ++ if (std::is_same()) { ++ y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; ++ // y[i] = (i % k == i / k) ? 1.0f : 0.0f; ++ // y[i] = i + 1; ++ } else if (std::is_same()) { ++ y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); ++ // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 
1.0f : 0.0f); ++ // y[i] = ggml_fp32_to_fp16(i + 1); ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ } ++ ++ ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); ++ ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); ++ ++ vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ++ ggml_vk_ctx_begin(ctx->device, subctx); ++ for (size_t i = 0; i < num_it; i++) { ++ ggml_vk_matmul( ++ ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), ++ m, n, k, ++ k, k, m, k*m, k*n, m*n, ++ split_k, batch, batch, batch, 1, 1 ++ ); ++ } ++ ggml_vk_ctx_end(subctx); ++ ++ auto begin = std::chrono::high_resolution_clock::now(); ++ ggml_vk_submit(subctx, ctx->fence); ++ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences"); ++ ctx->device->device.resetFences({ ctx->fence }); ++ ++ auto end = std::chrono::high_resolution_clock::now(); ++ double time = std::chrono::duration_cast(end-begin).count() / 1000.0; ++ ++ // copy dst to host ++ ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne); ++ ++ float * d_chk = (float *) malloc(sizeof(float) * d_ne); ++ ++ ggml_init_params iparams = { ++ /*.mem_size =*/ 1024*1024*1024, ++ /*.mem_buffer =*/ NULL, ++ /*.no_alloc =*/ true, ++ }; ++ ++ ggml_context * ggml_ctx = ggml_init(iparams); ++ ++ ggml_type src0_type; ++ ggml_type src1_type; ++ ++ if (std::is_same()) { ++ src0_type = GGML_TYPE_F32; ++ } else if (std::is_same()) { ++ src0_type = GGML_TYPE_F16; ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ if (std::is_same()) { ++ src1_type = GGML_TYPE_F32; ++ } else if (std::is_same()) { ++ src1_type = GGML_TYPE_F16; ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ ++ ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch); ++ ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch); ++ ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); ++ ++ src0_ggml->data = x; ++ src1_ggml->data = y; ++ tensor_ggml->data = d_chk; ++ ++ ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); ++ ggml_build_forward_expand(cgraph, tensor_ggml); ++ ++ ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); ++ ++ ggml_free(ggml_ctx); ++ ++ double avg_err = 0.0; ++ int first_err_n = -1; ++ int first_err_m = -1; ++ int first_err_b = -1; ++ ++ for (size_t i = 0; i < m*n*batch; i++) { ++ double err = std::fabs(d[i] - d_chk[i]); ++ avg_err += err; ++ ++ if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { ++ first_err_b = i / (m * n); ++ first_err_n = (i % (m * n)) / m; ++ first_err_m = (i % (m * n)) % m; ++ } ++ } ++ ++ avg_err /= m * n; ++ ++ double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0); ++ ++ std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; ++ ++ if (avg_err > 0.1 || std::isnan(avg_err)) { ++ std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; ++ std::cerr << "Actual result: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ std::cerr << "Expected result: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ ++ if (split_k > 1) { ++ float * 
split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); ++ ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); ++ ++ std::cerr << "d_buf0: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ ++ std::cerr << "d_buf1: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ ++ std::cerr << "d_buf2: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ ++ std::cerr << "d_buf3: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ ++ free(split_k_buf); ++ } ++ } ++ ++ free(d_chk); ++ ++ ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); ++ ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); ++ ++ ggml_vk_destroy_buffer(d_X); ++ ggml_vk_destroy_buffer(d_Y); ++ ggml_vk_destroy_buffer(d_D); ++ ++ ggml_pipeline_cleanup(p); ++ ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce); ++ ++ free(x); ++ free(y); ++ free(d); ++} ++ ++static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) { ++ if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) { ++ return; ++ } ++ i0 = std::max(i0, 5); ++ i1 = std::max(i1, 5); ++ i2 = std::max(i2, 0); ++ i3 = std::max(i3, 0); ++ fprintf(stderr, " "); ++ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { ++ fprintf(stderr, "%7d ", idx1); ++ } ++ fprintf(stderr, "\n"); ++ for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { ++ fprintf(stderr, "%7d: ", idx0); ++ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { ++ if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { ++ float val; ++ if (tensor->type == GGML_TYPE_F32) { ++ val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); ++ } else if (tensor->type == GGML_TYPE_F16) { ++ val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ fprintf(stderr, "% 7.2f ", val); ++ } else { ++ fprintf(stderr, " "); ++ } ++ } ++ fprintf(stderr, "\n"); ++ } ++} ++ ++static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) { ++ ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr); ++} ++ ++static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) { ++ if (quant == GGML_TYPE_F32) { ++ memcpy(to, from, sizeof(float) * ne); ++ return; ++ } ++ ++ const auto * tt = ggml_get_type_traits(quant); ++ ++ ggml_to_float_t dequant_fn = tt->to_float; ++ ++ dequant_fn(from, to, ne); ++} ++ ++static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { ++ VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")"); ++ const size_t x_sz = sizeof(float) * ne; ++ const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne; ++ const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); ++ float * x = (float *) malloc(x_sz); ++ void * qx = malloc(qx_sz); ++ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ vk_buffer x_buf = 
ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ float * x_ref = (float *) malloc(x_sz); ++ ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16); ++ ++ for (size_t i = 0; i < ne; i++) { ++ x[i] = rand() / (float)RAND_MAX; ++ } ++ ++ vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant); ++ ++ ggml_vk_quantize_data(x, qx, ne, quant); ++ ggml_vk_dequantize_data(qx, x_ref, ne, quant); ++ ++ ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); ++ ++ ggml_pipeline_allocate_descriptor_sets(ctx->device); ++ ++ ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); ++ ++ vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ++ ggml_vk_ctx_begin(ctx->device, subctx); ++ const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; ++ ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); ++ ggml_vk_ctx_end(subctx); ++ ++ auto begin = std::chrono::high_resolution_clock::now(); ++ ++ ggml_vk_submit(subctx, ctx->fence); ++ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); ++ ctx->device->device.resetFences({ ctx->fence }); ++ ++ auto end = std::chrono::high_resolution_clock::now(); ++ ++ double ms_dequant = std::chrono::duration_cast(end-begin).count() / 1000.0; ++ ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16); ++ ++ int first_err = -1; ++ ++ double avg_err = 0.0; ++ for (size_t i = 0; i < ne; i++) { ++ double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i])); ++ avg_err += error; ++ ++ if (first_err < 0 && error > 0.05) { ++ first_err = i; ++ } ++ } ++ ++ avg_err /= ne; ++ ++ std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl; ++ ++ if (avg_err > 0.1) { ++ std::cerr << "first_error = " << first_err << std::endl; ++ std::cerr << "Actual result: " << std::endl << std::endl; ++ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { ++ std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", "; ++ } ++ std::cerr << std::endl << "Expected result: " << std::endl << std::endl; ++ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { ++ std::cerr << x_ref[i] << ", "; ++ } ++ std::cerr << std::endl; ++ } ++ ++ ggml_vk_destroy_buffer(x_buf); ++ ggml_vk_destroy_buffer(qx_buf); ++ ++ free(x); ++ free(qx); ++ free(x_ref); ++ free(x_chk); ++} ++ ++static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) { ++ VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); ++ const size_t x_ne = m * k * batch; ++ const size_t y_ne = k * n * batch; ++ const size_t d_ne = m * n * batch; ++ ++ vk_pipeline p; ++ std::string shname; ++ if (shader_size == 0) { ++ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s; ++ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; ++ } else if (shader_size == 1) { ++ p = ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m; ++ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; ++ } else if (shader_size == 2) { ++ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l; ++ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; ++ } else { ++ GGML_ASSERT(0); ++ } ++ ++ const size_t kpad = ggml_vk_align_size(k, p->align); ++ ++ if (k != kpad) { ++ if (shader_size == 0) { ++ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s; ++ shname = std::string(ggml_type_name(quant)) + "_S"; ++ } else if (shader_size == 1) { ++ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m; ++ shname = std::string(ggml_type_name(quant)) + "_M"; ++ } else if (shader_size == 2) { ++ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l; ++ shname = std::string(ggml_type_name(quant)) + "_L"; ++ } else { ++ GGML_ASSERT(0); ++ } ++ } ++ ++ const size_t x_sz = sizeof(float) * x_ne; ++ const size_t y_sz = sizeof(float) * y_ne; ++ const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); ++ const size_t d_sz = sizeof(float) * d_ne; ++ float * x = (float *) malloc(x_sz); ++ float * y = (float *) malloc(y_sz); ++ void * qx = malloc(qx_sz); ++ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ float * d = (float *) malloc(d_sz); ++ float * d_chk = (float *) malloc(d_sz); ++ ++ for (size_t i = 0; i < x_ne; i++) { ++ x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; ++ } ++ ++ ggml_vk_quantize_data(x, qx, x_ne, quant); ++ ++ for (size_t i = 0; i < y_ne; i++) { ++ // y[i] = rand() / (float)RAND_MAX; ++ y[i] = (i % k == i / k) ? 
1.0f : 0.0f; ++ } ++ ++ ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); ++ if (split_k > 1) { ++ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); ++ ++ if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { ++ // Resize buffer ++ if (ctx->prealloc_split_k != nullptr) { ++ ggml_vk_destroy_buffer(ctx->prealloc_split_k); ++ } ++ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); ++ } ++ } ++ ++ ggml_pipeline_allocate_descriptor_sets(ctx->device); ++ ++ ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); ++ ggml_vk_buffer_write(y_buf, 0, y, y_sz); ++ ++ vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ++ ggml_vk_ctx_begin(ctx->device, subctx); ++ for (size_t i = 0; i < num_it; i++) { ++ ggml_vk_matmul( ++ ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), ++ m, n, k, ++ k, k, m, k*m, k*n, m*n, ++ split_k, batch, batch, batch, 1, 1 ++ ); ++ } ++ ggml_vk_ctx_end(subctx); ++ ++ auto begin = std::chrono::high_resolution_clock::now(); ++ ++ ggml_vk_submit(subctx, ctx->fence); ++ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); ++ ctx->device->device.resetFences({ ctx->fence }); ++ ++ auto end = std::chrono::high_resolution_clock::now(); ++ ++ double time_ms = std::chrono::duration_cast(end-begin).count() / 1000.0; ++ ggml_vk_buffer_read(d_buf, 0, d, d_sz); ++ ++ ggml_init_params iparams = { ++ /*.mem_size =*/ 1024*1024*1024, ++ /*.mem_buffer =*/ NULL, ++ /*.no_alloc =*/ true, ++ }; ++ ++ ggml_context * ggml_ctx = ggml_init(iparams); ++ ++ ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch); ++ ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch); ++ ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); ++ ++ src0_ggml->data = qx; ++ src1_ggml->data = y; ++ tensor_ggml->data = d_chk; ++ ++ ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); ++ ggml_build_forward_expand(cgraph, tensor_ggml); ++ ++ ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); ++ ++ ggml_free(ggml_ctx); ++ ++ double avg_err = 0.0; ++ int first_err_n = -1; ++ int first_err_m = -1; ++ int first_err_b = -1; ++ ++ for (size_t i = 0; i < m*n*batch; i++) { ++ double err = std::fabs(d[i] - d_chk[i]); ++ avg_err += err; ++ ++ if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { ++ first_err_b = i / (m * n); ++ first_err_n = (i % (m * n)) / m; ++ first_err_m = (i % (m * n)) % m; ++ } ++ } ++ ++ avg_err /= m * n; ++ ++ double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); ++ ++ std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; ++ ++ if (avg_err > 0.01 || std::isnan(avg_err)) { ++ std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; ++ std::cerr << "Actual result: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ std::cerr << std::endl; ++ std::cerr << "Expected result: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, 
first_err_m, first_err_n, first_err_b); ++ ++ if (split_k > 1) { ++ float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); ++ ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); ++ ++ std::cerr << "d_buf0: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ ++ std::cerr << "d_buf1: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ ++ std::cerr << "d_buf2: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ ++ std::cerr << "d_buf3: " << std::endl << std::endl; ++ ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ++ ++ free(split_k_buf); ++ } ++ } ++ ++ ggml_vk_destroy_buffer(qx_buf); ++ ggml_vk_destroy_buffer(y_buf); ++ ggml_vk_destroy_buffer(d_buf); ++ ++ free(x); ++ free(qx); ++ free(y); ++ free(d); ++ free(d_chk); ++} ++#endif ++ ++static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { ++#if defined(GGML_VULKAN_RUN_TESTS) ++ const std::vector vals { ++ 512, 512, 128, ++ 128, 512, 512, ++ 4096, 512, 4096, ++ 11008, 512, 4096, ++ 4096, 512, 11008, ++ 32000, 512, 4096, ++ 8, 8, 8, ++ 100, 46, 576, ++ 623, 111, 128, ++ 100, 46, 558, ++ 512, 1, 256, ++ 128, 110, 622, ++ 511, 511, 127, ++ 511, 511, 7, ++ 511, 511, 17, ++ 49, 49, 128, ++ 128, 49, 49, ++ 4096, 49, 4096, ++ }; ++ const size_t num_it = 100; ++ ++ for (size_t i = 0; i < vals.size(); i += 3) { ++ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); ++ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); ++ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2); ++ std::cerr << '\n'; ++ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0); ++ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1); ++ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2); ++ std::cerr << '\n'; ++ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0); ++ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1); ++ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2); ++ std::cerr << '\n' << std::endl; ++ ++ if (vals[i + 2] % 32 == 0) { ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0); ++ std::cerr << '\n'; ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0); ++ std::cerr << '\n'; ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0); ++ std::cerr << '\n' << std::endl; ++ } ++ ++ 
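++        // Q4_0 packs 32 values per block and Q4_K uses 256-value superblocks, so the quantized
++        // matmul tests are only run when k is a multiple of the corresponding block size.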
if (vals[i + 2] % 256 == 0) { ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K); ++ std::cerr << '\n'; ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K); ++ std::cerr << '\n'; ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K); ++ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K); ++ std::cerr << '\n' << std::endl; ++ } ++ } ++ ++ GGML_ABORT("fatal error"); ++#endif ++ ++ if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { ++ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")"); ++ // Resize buffer ++ if (ctx->prealloc_x != nullptr) { ++ ggml_vk_destroy_buffer(ctx->prealloc_x); ++ } ++ ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x); ++ } ++ if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) { ++ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")"); ++ // Resize buffer ++ if (ctx->prealloc_y != nullptr) { ++ ggml_vk_destroy_buffer(ctx->prealloc_y); ++ } ++ ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); ++ } ++ if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { ++ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); ++ // Resize buffer ++ if (ctx->prealloc_split_k != nullptr) { ++ ggml_vk_destroy_buffer(ctx->prealloc_split_k); ++ } ++ ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); ++ } ++} ++ ++static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence); ++ ++// Returns true if node has enqueued work into the queue, false otherwise ++// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. 
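++// A dryrun pass, by contrast, only requests the pipelines and descriptor sets each node will need so
++// they can be allocated up front; no compute context is created and no commands are recorded then.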
++static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){ ++ if (ggml_is_empty(node) || !node->buffer) { ++ return false; ++ } ++ ++ VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); ++ ctx->semaphore_idx = 0; ++ ++ const ggml_tensor * src0 = node->src[0]; ++ const ggml_tensor * src1 = node->src[1]; ++ const ggml_tensor * src2 = node->src[2]; ++ const ggml_tensor * src3 = node->src[3]; ++ ++ switch (node->op) { ++ // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor ++ case GGML_OP_RESHAPE: ++ case GGML_OP_VIEW: ++ case GGML_OP_PERMUTE: ++ case GGML_OP_TRANSPOSE: ++ case GGML_OP_NONE: ++ return false; ++ case GGML_OP_UNARY: ++ switch (ggml_get_unary_op(node)) { ++ case GGML_UNARY_OP_SILU: ++ case GGML_UNARY_OP_GELU: ++ case GGML_UNARY_OP_GELU_QUICK: ++ case GGML_UNARY_OP_RELU: ++ case GGML_UNARY_OP_TANH: ++ break; ++ default: ++ return false; ++ } ++ break; ++ case GGML_OP_REPEAT: ++ case GGML_OP_GET_ROWS: ++ case GGML_OP_ADD: ++ case GGML_OP_ACC: ++ case GGML_OP_MUL: ++ case GGML_OP_DIV: ++ case GGML_OP_CONCAT: ++ case GGML_OP_UPSCALE: ++ case GGML_OP_SCALE: ++ case GGML_OP_SQR: ++ case GGML_OP_SIN: ++ case GGML_OP_COS: ++ case GGML_OP_CLAMP: ++ case GGML_OP_PAD: ++ case GGML_OP_CPY: ++ case GGML_OP_CONT: ++ case GGML_OP_DUP: ++ case GGML_OP_NORM: ++ case GGML_OP_GROUP_NORM: ++ case GGML_OP_RMS_NORM: ++ case GGML_OP_DIAG_MASK_INF: ++ case GGML_OP_SOFT_MAX: ++ case GGML_OP_ROPE: ++ case GGML_OP_MUL_MAT: ++ case GGML_OP_MUL_MAT_ID: ++ case GGML_OP_ARGSORT: ++ case GGML_OP_SUM_ROWS: ++ case GGML_OP_IM2COL: ++ case GGML_OP_TIMESTEP_EMBEDDING: ++ case GGML_OP_POOL_2D: ++ case GGML_OP_RWKV_WKV6: ++ case GGML_OP_LEAKY_RELU: ++ case GGML_OP_FLASH_ATTN_EXT: ++ break; ++ default: ++ std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; ++ GGML_ABORT("fatal error"); ++ return false; ++ } ++ ++ vk_context compute_ctx; ++ ++ if (!dryrun) { ++ if (ctx->compute_ctx.expired()) { ++ compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ++ ctx->compute_ctx = compute_ctx; ++ ggml_vk_ctx_begin(ctx->device, compute_ctx); ++ } else { ++ compute_ctx = ctx->compute_ctx.lock(); ++ } ++ } else { ++ switch (node->op) { ++ case GGML_OP_REPEAT: ++ case GGML_OP_ACC: ++ case GGML_OP_GET_ROWS: ++ case GGML_OP_ADD: ++ case GGML_OP_MUL: ++ case GGML_OP_DIV: ++ case GGML_OP_CONCAT: ++ case GGML_OP_UPSCALE: ++ case GGML_OP_SCALE: ++ case GGML_OP_SQR: ++ case GGML_OP_SIN: ++ case GGML_OP_COS: ++ case GGML_OP_CLAMP: ++ case GGML_OP_PAD: ++ case GGML_OP_CPY: ++ case GGML_OP_CONT: ++ case GGML_OP_DUP: ++ case GGML_OP_NORM: ++ case GGML_OP_GROUP_NORM: ++ case GGML_OP_RMS_NORM: ++ case GGML_OP_UNARY: ++ case GGML_OP_DIAG_MASK_INF: ++ case GGML_OP_SOFT_MAX: ++ case GGML_OP_ROPE: ++ case GGML_OP_ARGSORT: ++ case GGML_OP_SUM_ROWS: ++ case GGML_OP_IM2COL: ++ case GGML_OP_TIMESTEP_EMBEDDING: ++ case GGML_OP_POOL_2D: ++ case GGML_OP_LEAKY_RELU: ++ { ++ // These operations all go through ggml_vk_op_f32, so short-circuit and ++ // do the only thing needed for the dryrun. 
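++                // That is: look up the op's pipeline and request a single descriptor set for it,
++                // without creating a compute context or recording any dispatch.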
++ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); ++ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); ++ return false; ++ } ++ default: ++ break; ++ } ++ } ++ ++ switch (node->op) { ++ case GGML_OP_REPEAT: ++ ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_ACC: ++ ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); ++ ++ break; ++ case GGML_OP_GET_ROWS: ++ ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); ++ ++ break; ++ case GGML_OP_ADD: ++ ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); ++ ++ break; ++ case GGML_OP_MUL: ++ ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); ++ ++ break; ++ case GGML_OP_DIV: ++ ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun); ++ ++ break; ++ case GGML_OP_CONCAT: ++ ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun); ++ ++ break; ++ case GGML_OP_UPSCALE: ++ ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_SCALE: ++ ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_SQR: ++ ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_SIN: ++ ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_COS: ++ ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_CLAMP: ++ ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_PAD: ++ ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_CPY: ++ case GGML_OP_CONT: ++ case GGML_OP_DUP: ++ ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_NORM: ++ ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_GROUP_NORM: ++ ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_RMS_NORM: ++ ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_UNARY: ++ switch (ggml_get_unary_op(node)) { ++ case GGML_UNARY_OP_SILU: ++ case GGML_UNARY_OP_GELU: ++ case GGML_UNARY_OP_GELU_QUICK: ++ case GGML_UNARY_OP_RELU: ++ case GGML_UNARY_OP_TANH: ++ ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); ++ break; ++ default: ++ return false; ++ } ++ break; ++ case GGML_OP_DIAG_MASK_INF: ++ ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_SOFT_MAX: ++ ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); ++ ++ break; ++ case GGML_OP_ROPE: ++ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun); ++ ++ break; ++ case GGML_OP_ARGSORT: ++ ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_SUM_ROWS: ++ ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_IM2COL: ++ ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); ++ ++ break; ++ case GGML_OP_TIMESTEP_EMBEDDING: ++ ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_POOL_2D: ++ ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_LEAKY_RELU: ++ ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); ++ ++ break; ++ case GGML_OP_MUL_MAT: ++ ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); ++ ++ break; ++ case GGML_OP_MUL_MAT_ID: ++ ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); ++ ++ break; ++ ++ case GGML_OP_FLASH_ATTN_EXT: ++ ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); ++ ++ break; ++ ++ 
case GGML_OP_RWKV_WKV6: ++ ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); ++ ++ break; ++ default: ++ return false; ++ } ++ ++ if (dryrun) { ++ return false; ++ } ++ ++ ctx->tensor_ctxs[node_idx] = compute_ctx; ++ ++#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF) ++ // Force context reset on each node so that each tensor ends up in its own context ++ // and can be run and compared to its CPU equivalent separately ++ last_node = true; ++#endif ++ ++ if (submit || last_node) { ++ ggml_vk_ctx_end(compute_ctx); ++ ++ // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward ++ if (last_node) { ++ compute_ctx->exit_tensor_idx = node_idx_begin; ++ } ++ else { ++ compute_ctx->exit_tensor_idx = -1; ++ } ++ ++ ctx->compute_ctx.reset(); ++ ++ bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false); ++ if (!ok) { ++ if (node->op == GGML_OP_UNARY) { ++ std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; ++ } ++ else { ++ std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; ++ } ++ } ++ ++ } ++ return true; ++} ++ ++static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){ ++ ggml_backend_buffer * buf = nullptr; ++ ++ switch (tensor->op) { ++ case GGML_OP_ADD: ++ case GGML_OP_ACC: ++ case GGML_OP_GET_ROWS: ++ case GGML_OP_MUL: ++ case GGML_OP_DIV: ++ case GGML_OP_CONCAT: ++ case GGML_OP_UPSCALE: ++ case GGML_OP_SCALE: ++ case GGML_OP_SQR: ++ case GGML_OP_SIN: ++ case GGML_OP_COS: ++ case GGML_OP_CLAMP: ++ case GGML_OP_PAD: ++ case GGML_OP_CPY: ++ case GGML_OP_CONT: ++ case GGML_OP_DUP: ++ case GGML_OP_NORM: ++ case GGML_OP_GROUP_NORM: ++ case GGML_OP_RMS_NORM: ++ case GGML_OP_DIAG_MASK_INF: ++ case GGML_OP_SOFT_MAX: ++ case GGML_OP_ROPE: ++ case GGML_OP_RESHAPE: ++ case GGML_OP_VIEW: ++ case GGML_OP_PERMUTE: ++ case GGML_OP_TRANSPOSE: ++ case GGML_OP_NONE: ++ case GGML_OP_ARGSORT: ++ case GGML_OP_SUM_ROWS: ++ case GGML_OP_IM2COL: ++ case GGML_OP_TIMESTEP_EMBEDDING: ++ case GGML_OP_POOL_2D: ++ case GGML_OP_RWKV_WKV6: ++ case GGML_OP_LEAKY_RELU: ++ case GGML_OP_REPEAT: ++ buf = tensor->buffer; ++ ++ break; ++ case GGML_OP_UNARY: ++ switch (ggml_get_unary_op(tensor)) { ++ case GGML_UNARY_OP_SILU: ++ case GGML_UNARY_OP_GELU: ++ case GGML_UNARY_OP_GELU_QUICK: ++ case GGML_UNARY_OP_RELU: ++ case GGML_UNARY_OP_TANH: ++ buf = tensor->buffer; ++ break; ++ default: ++ return false; ++ } ++ break; ++ case GGML_OP_MUL_MAT: ++ case GGML_OP_MUL_MAT_ID: ++ case GGML_OP_FLASH_ATTN_EXT: ++ buf = tensor->buffer; ++ ++ break; ++ default: ++ return false; ++ } ++ ++ if (buf == nullptr) { ++ return false; ++ } ++ ++ VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); ++ ++ vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); ++ ++ // always wait for the GPU work to be done for the last submit ++ if (tensor_idx == subctx->exit_tensor_idx) { ++ use_fence = true; ++ } ++ ++ // Only run if ctx hasn't been 
submitted yet ++ if (!subctx->seqs.empty()) { ++#ifdef GGML_VULKAN_CHECK_RESULTS ++ ggml_vk_check_results_0(tensor); ++ use_fence = true; ++#endif ++ ++ // Do staging buffer copies ++ for (auto& cpy : subctx->in_memcpys) { ++ memcpy(cpy.dst, cpy.src, cpy.n); ++ } ++ ++ ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); ++ ++ if (use_fence) { ++ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); ++ ++ ctx->device->device.resetFences({ ctx->fence }); ++ } ++#ifdef GGML_VULKAN_CHECK_RESULTS ++ ggml_vk_check_results_1(tensor); ++#endif ++ } ++ ++ if (tensor_idx == subctx->exit_tensor_idx) { ++ // Do staging buffer copies ++ for (auto& cpy : subctx->out_memcpys) { ++ memcpy(cpy.dst, cpy.src, cpy.n); ++ } ++ subctx->in_memcpys.clear(); ++ subctx->out_memcpys.clear(); ++ } ++ ++ return true; ++} ++ ++// Clean up after graph processing is done ++static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ++ VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); ++ for (auto& buffer : ctx->gc.temp_buffers) { ++ ggml_vk_pool_free(ctx, buffer); ++ } ++ ctx->gc.temp_buffers.clear(); ++ ++ for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) { ++ vk_pipeline_ref plr = ctx->device->pipelines[dsr.first]; ++ ++ if (plr.expired()) { ++ continue; ++ } ++ ++ vk_pipeline pl = plr.lock(); ++ ggml_pipeline_cleanup(pl); ++ } ++ ++ ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); ++ ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); ++ ++ for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { ++ ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); ++ } ++ ctx->gc.semaphores.clear(); ++ ++ for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) { ++ ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s }); ++ } ++ ctx->gc.tl_semaphores.clear(); ++ ctx->semaphore_idx = 0; ++ ++ ctx->event_idx = 0; ++ ++ for (auto& event : ctx->gc.events) { ++ ctx->device->device.resetEvent(event); ++ } ++ ++ ctx->tensor_ctxs.clear(); ++ ctx->gc.contexts.clear(); ++ ctx->device->pipeline_descriptor_set_requirements.clear(); ++} ++ ++// Clean up on backend free ++static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ++ VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); ++ ggml_vk_graph_cleanup(ctx); ++ ++ ggml_vk_destroy_buffer(ctx->prealloc_x); ++ ggml_vk_destroy_buffer(ctx->prealloc_y); ++ ggml_vk_destroy_buffer(ctx->prealloc_split_k); ++ ++ for (auto& buffer : ctx->buffer_pool) { ++ ggml_vk_destroy_buffer(buffer); ++ } ++ ++ ctx->prealloc_size_x = 0; ++ ctx->prealloc_size_y = 0; ++ ctx->prealloc_size_split_k = 0; ++ ++ for (auto& event : ctx->gc.events) { ++ ctx->device->device.destroyEvent(event); ++ } ++ ctx->gc.events.clear(); ++ ++ ctx->device->device.destroyFence(ctx->fence); ++} ++ ++static int ggml_vk_get_device_count() { ++ ggml_vk_instance_init(); ++ ++ return vk_instance.device_indices.size(); ++} ++ ++static void ggml_vk_get_device_description(int device, char * description, size_t description_size) { ++ ggml_vk_instance_init(); ++ ++ std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); ++ ++ vk::PhysicalDeviceProperties props; ++ devices[device].getProperties(&props); ++ ++ snprintf(description, description_size, "%s", props.deviceName.data()); ++} ++ ++// backend interface ++ ++#define UNUSED GGML_UNUSED ++ ++// device backend ++ ++static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) { ++ return buffer->buft->iface.get_name == 
ggml_backend_vk_buffer_type_name; ++} ++ ++static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { ++ VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()"); ++ ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ++ ggml_vk_destroy_buffer(ctx->dev_buffer); ++ delete ctx; ++} ++ ++static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { ++ return vk_ptr_base; ++ ++ UNUSED(buffer); ++} ++ ++static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ++ VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); ++ if (tensor->view_src != nullptr) { ++ GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); ++ } ++} ++ ++static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ++ VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ++ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; ++ vk_buffer buf = buf_ctx->dev_buffer; ++ ++ ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); ++} ++ ++static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { ++ VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ++ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; ++ ++ vk_buffer buf = buf_ctx->dev_buffer; ++ ++ ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); ++} ++ ++static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { ++ if (ggml_backend_buffer_is_vk(src->buffer)) { ++ ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; ++ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ++ vk_buffer src_buf = src_buf_ctx->dev_buffer; ++ vk_buffer dst_buf = dst_buf_ctx->dev_buffer; ++ ++ ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); ++ ++ return true; ++ } ++ return false; ++ ++ UNUSED(buffer); ++} ++ ++static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { ++ ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ++ ++ ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size); ++} ++ ++static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { ++ /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, ++ /* .get_base = */ ggml_backend_vk_buffer_get_base, ++ /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, ++ /* .memset_tensor = */ NULL, ++ /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, ++ /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, ++ /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, ++ /* .clear = */ ggml_backend_vk_buffer_clear, ++ /* .reset = */ NULL, ++}; ++ ++// vk buffer type ++static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) { ++ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context 
*)buft->context; ++ ++ return ctx->name.c_str(); ++} ++ ++static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ++ VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")"); ++ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; ++ ++ vk_buffer dev_buffer = nullptr; ++ try { ++ dev_buffer = ggml_vk_create_buffer_device(ctx->device, size); ++ } catch (const vk::SystemError& e) { ++ return nullptr; ++ } ++ ++ ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name); ++ ++ return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size); ++} ++ ++static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { ++ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; ++ return ctx->device->properties.limits.minStorageBufferOffsetAlignment; ++} ++ ++static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { ++ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; ++ return ctx->device->max_memory_allocation_size; ++} ++ ++static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { ++ return ggml_nbytes(tensor); ++ ++ UNUSED(buft); ++} ++ ++ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { ++ ggml_vk_instance_init(); ++ ++ VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")"); ++ ++ vk_device dev = ggml_vk_get_device(dev_num); ++ ++ return &dev->buffer_type; ++} ++ ++// host buffer type ++ ++static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) { ++ return GGML_VK_NAME "_Host"; ++ ++ UNUSED(buft); ++} ++ ++static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { ++ return GGML_VK_NAME "_Host"; ++ ++ UNUSED(buffer); ++} ++ ++static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { ++ VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ++ ggml_vk_host_free(vk_instance.devices[0], buffer->context); ++} ++ ++static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ++ VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")"); ++ ++ size += 32; // Behave like the CPU buffer type ++ void * ptr = nullptr; ++ try { ++ ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); ++ } catch (vk::SystemError& e) { ++ std::cerr << "ggml_vulkan: Failed to allocate pinned memory." 
<< std::endl; ++ std::cerr << "ggml_vulkan: " << e.what() << std::endl; ++ // fallback to cpu buffer ++ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); ++ } ++ ++ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); ++ buffer->buft = buft; ++ buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer; ++ ++ return buffer; ++ ++ UNUSED(buft); ++} ++ ++static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { ++ return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment; ++ ++ UNUSED(buft); ++} ++ ++// Should be changed to return device-specific host buffer type ++// but that probably requires changes in llama.cpp ++ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { ++ static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = { ++ /* .iface = */ { ++ /* .get_name = */ ggml_backend_vk_host_buffer_type_name, ++ /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer, ++ /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment, ++ /* .get_max_size = */ NULL, // defaults to SIZE_MAX ++ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, ++ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, ++ }, ++ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0), ++ /* .context = */ nullptr, ++ }; ++ ++ // Make sure device 0 is initialized ++ ggml_vk_instance_init(); ++ ggml_vk_get_device(0); ++ ++ return &ggml_backend_vk_buffer_type_host; ++} ++ ++ ++// backend ++ ++static const char * ggml_backend_vk_name(ggml_backend_t backend) { ++ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ++ ++ return ctx->name.c_str(); ++} ++ ++static void ggml_backend_vk_free(ggml_backend_t backend) { ++ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ++ VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")"); ++ ++ ggml_vk_cleanup(ctx); ++ ++ delete ctx; ++ delete backend; ++} ++ ++static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { ++ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ++ ++ return &ctx->device->buffer_type; ++} ++ ++static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ++ VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); ++ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ++ GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); ++ ++ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; ++ ++ vk_context transfer_ctx; ++ ++ if (ctx->transfer_ctx.expired()) { ++ // Initialize new transfer context ++ transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); ++ ctx->transfer_ctx = transfer_ctx; ++ ggml_vk_ctx_begin(ctx->device, transfer_ctx); ++ } else { ++ transfer_ctx = ctx->transfer_ctx.lock(); ++ } ++ ++ vk_buffer buf = buf_ctx->dev_buffer; ++ ++ ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); ++} ++ ++static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { ++ 
VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); ++ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ++ GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); ++ ++ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; ++ ++ vk_context transfer_ctx; ++ ++ if (ctx->transfer_ctx.expired()) { ++ // Initialize new transfer context ++ transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); ++ ctx->transfer_ctx = transfer_ctx; ++ ggml_vk_ctx_begin(ctx->device, transfer_ctx); ++ } else { ++ transfer_ctx = ctx->transfer_ctx.lock(); ++ } ++ ++ vk_buffer buf = buf_ctx->dev_buffer; ++ ++ ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); ++} ++ ++static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { ++ VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); ++ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ++ if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { ++ ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; ++ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ++ ++ vk_context transfer_ctx; ++ ++ if (ctx->transfer_ctx.expired()) { ++ // Initialize new transfer context ++ transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); ++ ctx->transfer_ctx = transfer_ctx; ++ ggml_vk_ctx_begin(ctx->device, transfer_ctx); ++ } else { ++ transfer_ctx = ctx->transfer_ctx.lock(); ++ } ++ ++ vk_buffer src_buf = src_buf_ctx->dev_buffer; ++ vk_buffer dst_buf = dst_buf_ctx->dev_buffer; ++ ++ ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); ++ return true; ++ } ++ ++ return false; ++} ++ ++static void ggml_backend_vk_synchronize(ggml_backend_t backend) { ++ VK_LOG_DEBUG("ggml_backend_vk_synchronize()"); ++ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ++ if(ctx->transfer_ctx.expired()) { ++ return; ++ } ++ ++ vk_context transfer_ctx = ctx->transfer_ctx.lock(); ++ ++ ggml_vk_ctx_end(transfer_ctx); ++ ++ for (auto& cpy : transfer_ctx->in_memcpys) { ++ memcpy(cpy.dst, cpy.src, cpy.n); ++ } ++ ++ ggml_vk_submit(transfer_ctx, ctx->fence); ++ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences"); ++ ctx->device->device.resetFences({ ctx->fence }); ++ ++ for (auto& cpy : transfer_ctx->out_memcpys) { ++ memcpy(cpy.dst, cpy.src, cpy.n); ++ } ++ ++ ctx->transfer_ctx.reset(); ++} ++ ++static bool ggml_vk_is_empty(ggml_tensor * node) { ++ return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; ++} ++ ++static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ++ VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ++ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; ++ ++ for (int i = 0; i 
< cgraph->n_nodes; i++) { ++ ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); ++ } ++ ggml_vk_preallocate_buffers(ctx); ++ ggml_pipeline_allocate_descriptor_sets(ctx->device); ++ ++ int last_node = cgraph->n_nodes - 1; ++ ++ // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly ++ while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { ++ last_node -= 1; ++ } ++ ++ // Reserve tensor context space for all nodes ++ ctx->tensor_ctxs.resize(cgraph->n_nodes); ++ ++ bool first_node_in_batch = true; // true if next node will be first node in a batch ++ int submit_node_idx = 0; // index to first node in a batch ++ ++ // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution. ++ // Start with a smaller count to get work submitted right away, and increase it after each submit. ++ int nodes_per_submit = 20; ++ int submitted_nodes = 0; ++ int submit_count = 0; ++ for (int i = 0; i < cgraph->n_nodes; i++) { ++ if (first_node_in_batch) { ++ submit_node_idx = i; ++ } ++ ++ bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node); ++ ++ bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); ++ ++ if (enqueued) { ++ ++submitted_nodes; ++ ++#ifndef GGML_VULKAN_CHECK_RESULTS ++ if (first_node_in_batch) { ++ first_node_in_batch = false; ++ } ++#endif ++ } ++ ++ if (submit) { ++ first_node_in_batch = true; ++ submitted_nodes = 0; ++ switch (submit_count) { ++ case 0: ++ nodes_per_submit = 50; ++ break; ++ default: ++ nodes_per_submit = 100; ++ break; ++ } ++ submit_count++; ++ } ++ } ++ ++#ifdef GGML_VULKAN_PERF ++ ctx->device->perf_logger->print_timings(); ++#endif ++ ++ ggml_vk_graph_cleanup(ctx); ++ ++ return GGML_STATUS_SUCCESS; ++ ++ UNUSED(backend); ++} ++ ++// TODO: enable async and synchronize ++static ggml_backend_i ggml_backend_vk_interface = { ++ /* .get_name = */ ggml_backend_vk_name, ++ /* .free = */ ggml_backend_vk_free, ++ /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, ++ /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async, ++ /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, ++ /* .synchronize = */ NULL, // ggml_backend_vk_synchronize, ++ /* .graph_plan_create = */ NULL, ++ /* .graph_plan_free = */ NULL, ++ /* .graph_plan_update = */ NULL, ++ /* .graph_plan_compute = */ NULL, ++ /* .graph_compute = */ ggml_backend_vk_graph_compute, ++ /* .event_record = */ NULL, ++ /* .event_wait = */ NULL, ++}; ++ ++static ggml_guid_t ggml_backend_vk_guid() { ++ static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; ++ return &guid; ++} ++ ++ggml_backend_t ggml_backend_vk_init(size_t dev_num) { ++ VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")"); ++ ++ ggml_backend_vk_context * ctx = new ggml_backend_vk_context; ++ ggml_vk_init(ctx, dev_num); ++ ++ ggml_backend_t vk_backend = new ggml_backend { ++ /* .guid = */ ggml_backend_vk_guid(), ++ /* .interface = */ ggml_backend_vk_interface, ++ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), ++ /* .context = */ ctx, ++ }; ++ ++ return vk_backend; ++} ++ ++bool ggml_backend_is_vk(ggml_backend_t backend) { ++ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid()); ++} ++ ++int ggml_backend_vk_get_device_count() { ++ return ggml_vk_get_device_count(); ++} ++ 
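With the backend vtable, GUID, and constructors in place, the public entry points can be exercised from host code. The following is a minimal usage sketch, assuming the usual ggml-backend headers and the public declarations in ggml-vulkan.h; graph construction and error handling are omitted.

```cpp
#include <cstdio>

#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-vulkan.h"   // assumed header exposing the public functions used below

int main() {
    // Enumerate the Vulkan devices visible to the backend.
    const int n_dev = ggml_backend_vk_get_device_count();
    for (int i = 0; i < n_dev; i++) {
        char desc[256];
        ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
        printf("Vulkan device %d: %s\n", i, desc);
    }

    // Create a backend instance on device 0; the GUID check confirms it is the Vulkan backend.
    ggml_backend_t backend = ggml_backend_vk_init(0);
    if (backend != nullptr && ggml_backend_is_vk(backend)) {
        // ... allocate tensors with ggml_backend_vk_buffer_type(0), build a ggml_cgraph,
        // and run it with ggml_backend_graph_compute(backend, graph) ...
        ggml_backend_free(backend);
    }
    return 0;
}
```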
++void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) { ++ GGML_ASSERT(device < (int) vk_instance.device_indices.size()); ++ int dev_idx = vk_instance.device_indices[device]; ++ ggml_vk_get_device_description(dev_idx, description, description_size); ++} ++ ++void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { ++ GGML_ASSERT(device < (int) vk_instance.device_indices.size()); ++ ++ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; ++ ++ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); ++ ++ for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { ++ if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { ++ *total = heap.size; ++ *free = heap.size; ++ break; ++ } ++ } ++} ++ ++////////////////////////// ++ ++struct ggml_backend_vk_device_context { ++ size_t device; ++ std::string name; ++ std::string description; ++}; ++ ++static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ return ctx->name.c_str(); ++} ++ ++static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) { ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ return ctx->description.c_str(); ++} ++ ++static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; ++ ggml_backend_vk_get_device_memory(ctx->device, free, total); ++} ++ ++static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ return ggml_backend_vk_buffer_type(ctx->device); ++} ++ ++static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) { ++ UNUSED(dev); ++ return ggml_backend_vk_host_buffer_type(); ++} ++ ++static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { ++ UNUSED(dev); ++ return GGML_BACKEND_DEVICE_TYPE_GPU; ++} ++ ++static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { ++ props->name = ggml_backend_vk_device_get_name(dev); ++ props->description = ggml_backend_vk_device_get_description(dev); ++ props->type = ggml_backend_vk_device_get_type(dev); ++ ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); ++ props->caps = { ++ /* .async = */ false, ++ /* .host_buffer = */ true, ++ /* .buffer_from_host_ptr = */ false, ++ /* .events = */ false, ++ }; ++} ++ ++static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { ++ UNUSED(params); ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ return ggml_backend_vk_init(ctx->device); ++} ++ ++static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { ++ switch (op->op) { ++ case GGML_OP_UNARY: ++ switch (ggml_get_unary_op(op)) { ++ case GGML_UNARY_OP_GELU: ++ case GGML_UNARY_OP_GELU_QUICK: ++ case GGML_UNARY_OP_SILU: ++ case GGML_UNARY_OP_RELU: ++ case GGML_UNARY_OP_TANH: ++ return ggml_is_contiguous(op->src[0]); ++ default: ++ return false; ++ } ++ break; ++ case GGML_OP_MUL_MAT: ++ case GGML_OP_MUL_MAT_ID: ++ { ++ 
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ const vk_device& device = ggml_vk_get_device(ctx->device); ++ if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) { ++ // If there's not enough shared memory for row_ids and the result tile, fallback to CPU ++ return false; ++ } ++ switch (op->src[0]->type) { ++ case GGML_TYPE_F32: ++ case GGML_TYPE_F16: ++ case GGML_TYPE_Q4_0: ++ case GGML_TYPE_Q4_1: ++ case GGML_TYPE_Q5_0: ++ case GGML_TYPE_Q5_1: ++ case GGML_TYPE_Q8_0: ++ case GGML_TYPE_Q2_K: ++ case GGML_TYPE_Q3_K: ++ case GGML_TYPE_Q4_K: ++ case GGML_TYPE_Q5_K: ++ case GGML_TYPE_Q6_K: ++ case GGML_TYPE_IQ4_NL: ++ break; ++ default: ++ return false; ++ } ++ struct ggml_tensor * a; ++ struct ggml_tensor * b; ++ if (op->op == GGML_OP_MUL_MAT) { ++ a = op->src[0]; ++ b = op->src[1]; ++ } else { ++ a = op->src[2]; ++ b = op->src[1]; ++ } ++ if (a->ne[3] != b->ne[3]) { ++ return false; ++ } ++ if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) || ++ !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { ++ return false; ++ } ++ ++ return true; ++ } break; ++ case GGML_OP_FLASH_ATTN_EXT: ++ { ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ if (!ggml_vk_get_device(ctx->device)->coopmat2) { ++ return false; ++ } ++ switch (op->src[0]->ne[0]) { ++ case 64: ++ case 80: ++ case 96: ++ case 112: ++ case 128: ++ case 256: ++ break; ++ default: ++ return false; ++ } ++ if (op->src[0]->type != GGML_TYPE_F32) { ++ return false; ++ } ++ if (op->type != GGML_TYPE_F32) { ++ return false; ++ } ++ if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { ++ return false; ++ } ++ // It's straightforward to support different K/V dequant, but would ++ // significantly increase the number of pipelines ++ if (op->src[1]->type != op->src[2]->type) { ++ return false; ++ } ++ switch (op->src[1]->type) { ++ case GGML_TYPE_F16: ++ case GGML_TYPE_Q4_0: ++ case GGML_TYPE_Q4_1: ++ case GGML_TYPE_Q5_0: ++ case GGML_TYPE_Q5_1: ++ case GGML_TYPE_Q8_0: ++ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently ++ //case GGML_TYPE_Q2_K: ++ //case GGML_TYPE_Q3_K: ++ //case GGML_TYPE_Q4_K: ++ //case GGML_TYPE_Q5_K: ++ //case GGML_TYPE_Q6_K: ++ case GGML_TYPE_IQ4_NL: ++ break; ++ default: ++ return false; ++ } ++ return true; ++ } ++ case GGML_OP_GET_ROWS: ++ { ++ switch (op->src[0]->type) { ++ case GGML_TYPE_F32: ++ case GGML_TYPE_F16: ++ case GGML_TYPE_Q4_0: ++ case GGML_TYPE_Q4_1: ++ case GGML_TYPE_Q5_0: ++ case GGML_TYPE_Q5_1: ++ case GGML_TYPE_Q8_0: ++ case GGML_TYPE_IQ4_NL: ++ return true; ++ default: ++ return false; ++ } ++ } break; ++ case GGML_OP_CONT: ++ case GGML_OP_CPY: ++ case GGML_OP_DUP: ++ { ++ ggml_type src0_type = op->src[0]->type; ++ ggml_type src1_type = op->src[1] != nullptr ? 
op->src[1]->type : src0_type; ++ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { ++ return true; ++ } ++ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { ++ return true; ++ } ++ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { ++ return true; ++ } ++ return false; ++ } break; ++ case GGML_OP_REPEAT: ++ return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); ++ case GGML_OP_ROPE: ++ { ++ const int mode = ((const int32_t *) op->op_params)[2]; ++ if (mode & GGML_ROPE_TYPE_MROPE) { ++ return false; ++ } ++ if (mode & GGML_ROPE_TYPE_VISION) { ++ return false; ++ } ++ return ggml_is_contiguous(op->src[0]); ++ } ++ case GGML_OP_NONE: ++ case GGML_OP_RESHAPE: ++ case GGML_OP_VIEW: ++ case GGML_OP_PERMUTE: ++ case GGML_OP_TRANSPOSE: ++ case GGML_OP_NORM: ++ case GGML_OP_GROUP_NORM: ++ case GGML_OP_RMS_NORM: ++ case GGML_OP_ADD: ++ case GGML_OP_ACC: ++ case GGML_OP_MUL: ++ case GGML_OP_DIV: ++ case GGML_OP_CONCAT: ++ case GGML_OP_UPSCALE: ++ case GGML_OP_SCALE: ++ case GGML_OP_SQR: ++ case GGML_OP_SIN: ++ case GGML_OP_COS: ++ case GGML_OP_CLAMP: ++ case GGML_OP_PAD: ++ case GGML_OP_DIAG_MASK_INF: ++ case GGML_OP_SOFT_MAX: ++ case GGML_OP_ARGSORT: ++ case GGML_OP_SUM_ROWS: ++ case GGML_OP_IM2COL: ++ case GGML_OP_TIMESTEP_EMBEDDING: ++ case GGML_OP_POOL_2D: ++ case GGML_OP_RWKV_WKV6: ++ case GGML_OP_LEAKY_RELU: ++ return true; ++ default: ++ return false; ++ } ++ ++ UNUSED(dev); ++} ++ ++static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { ++ if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) { ++ return false; ++ } ++ ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context; ++ ++ return buft_ctx->device->idx == ctx->device; ++} ++ ++static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { ++ const int min_batch_size = 32; ++ ++ return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || ++ (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); ++ ++ UNUSED(dev); ++} ++ ++static const struct ggml_backend_device_i ggml_backend_vk_device_i = { ++ /* .get_name = */ ggml_backend_vk_device_get_name, ++ /* .get_description = */ ggml_backend_vk_device_get_description, ++ /* .get_memory = */ ggml_backend_vk_device_get_memory, ++ /* .get_type = */ ggml_backend_vk_device_get_type, ++ /* .get_props = */ ggml_backend_vk_device_get_props, ++ /* .init_backend = */ ggml_backend_vk_device_init, ++ /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, ++ /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, ++ /* .buffer_from_host_ptr = */ NULL, ++ /* .supports_op = */ ggml_backend_vk_device_supports_op, ++ /* .supports_buft = */ ggml_backend_vk_device_supports_buft, ++ /* .offload_op = */ ggml_backend_vk_device_offload_op, ++ /* .event_new = */ NULL, ++ /* .event_free = */ NULL, ++ /* .event_synchronize = */ NULL, ++}; ++ ++static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) { ++ UNUSED(reg); ++ return GGML_VK_NAME; ++} ++ ++static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) { ++ UNUSED(reg); ++ return ggml_backend_vk_get_device_count(); ++} ++ ++static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) { ++ static std::vector devices; ++ ++ static bool 
initialized = false; ++ ++ { ++ static std::mutex mutex; ++ std::lock_guard lock(mutex); ++ if (!initialized) { ++ for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ++ ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; ++ char desc[256]; ++ ggml_backend_vk_get_device_description(i, desc, sizeof(desc)); ++ ctx->device = i; ++ ctx->name = GGML_VK_NAME + std::to_string(i); ++ ctx->description = desc; ++ devices.push_back(new ggml_backend_device { ++ /* .iface = */ ggml_backend_vk_device_i, ++ /* .reg = */ reg, ++ /* .context = */ ctx, ++ }); ++ } ++ initialized = true; ++ } ++ } ++ ++ GGML_ASSERT(device < devices.size()); ++ return devices[device]; ++} ++ ++static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = { ++ /* .get_name = */ ggml_backend_vk_reg_get_name, ++ /* .get_device_count = */ ggml_backend_vk_reg_get_device_count, ++ /* .get_device = */ ggml_backend_vk_reg_get_device, ++ /* .get_proc_address = */ NULL, ++}; ++ ++ggml_backend_reg_t ggml_backend_vk_reg() { ++ static ggml_backend_reg reg = { ++ /* .api_version = */ GGML_BACKEND_API_VERSION, ++ /* .iface = */ ggml_backend_vk_reg_i, ++ /* .context = */ nullptr, ++ }; ++ ++ return ® ++} ++ ++// Extension availability ++static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions) { ++#ifdef GGML_VULKAN_VALIDATE ++ bool portability_enumeration_ext = false; ++ // Check for portability enumeration extension for MoltenVK support ++ for (const auto& properties : instance_extensions) { ++ if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { ++ return true; ++ } ++ } ++ if (!portability_enumeration_ext) { ++ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; ++ } ++#endif ++ return false; ++ ++ UNUSED(instance_extensions); ++} ++static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions) { ++#ifdef __APPLE__ ++ bool portability_enumeration_ext = false; ++ // Check for portability enumeration extension for MoltenVK support ++ for (const auto& properties : instance_extensions) { ++ if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { ++ return true; ++ } ++ } ++ if (!portability_enumeration_ext) { ++ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." 
<< std::endl; ++ } ++#endif ++ return false; ++ ++ UNUSED(instance_extensions); ++} ++ ++static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) { ++ switch (props.vendorID) { ++ case VK_VENDOR_ID_INTEL: ++ // Intel drivers don't support coopmat properly yet ++ return false; ++ case VK_VENDOR_ID_AMD: ++ if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { ++ // Workaround for AMD proprietary driver reporting support on all GPUs ++ const std::string name = props.deviceName; ++ return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs ++ name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs ++ name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs ++ } ++ return true; ++ default: ++ return true; ++ } ++} ++ ++// checks ++ ++#ifdef GGML_VULKAN_CHECK_RESULTS ++static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector& done, int level = 0) { ++ if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) { ++ return; ++ } ++ for (int j = 0; j < level; j++) { ++ std::cerr << " "; ++ } ++ std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl; ++ ++ done.push_back(tensor); ++ ++ for (int i = 0; i < GGML_MAX_SRC; i++) { ++ if (tensor->src[i] != nullptr) { ++ ggml_vk_print_graph_origin(tensor->src[i], done, level + 1); ++ } ++ } ++} ++ ++static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) { ++ if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) { ++ return; ++ } ++ i0 = std::max(i0, 5); ++ i1 = std::max(i1, 5); ++ i2 = std::max(i2, 0); ++ i3 = std::max(i3, 0); ++ fprintf(stderr, " "); ++ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { ++ fprintf(stderr, "%7d ", idx1); ++ } ++ fprintf(stderr, "\n"); ++ for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { ++ fprintf(stderr, "%7d: ", idx0); ++ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { ++ if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { ++ float val; ++ if (tensor->type == GGML_TYPE_F32) { ++ val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); ++ } else if (tensor->type == GGML_TYPE_F16) { ++ val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); ++ } else if (tensor->type == GGML_TYPE_I32) { ++ val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ fprintf(stderr, "% 7.2f ", val); ++ } else { ++ fprintf(stderr, " "); ++ } ++ } ++ fprintf(stderr, "\n"); ++ } ++} ++ ++static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) { ++ void * tensor_data = tensor->data; ++ ++ const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer); ++ ++ if (is_gpu) { ++ const size_t tensor_size = ggml_nbytes(tensor); ++ tensor_data = malloc(tensor_size); ++ ++ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context 
*)tensor->buffer->context; ++ ++ vk_buffer buffer_gpu = buf_ctx->dev_buffer; ++ ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size); ++ } ++ ++ std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; ++ std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl; ++ if (tensor->src[0] != nullptr) { ++ std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl; ++ } ++ if (tensor->src[1] != nullptr) { ++ std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl; ++ } ++ std::cerr << std::endl << "Result:" << std::endl; ++ ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); ++ std::cerr << std::endl; ++ std::vector done; ++ ggml_vk_print_graph_origin(tensor, done); ++ ++ if (is_gpu) { ++ free(tensor_data); ++ } ++} ++ ++void * comp_result; ++size_t comp_size; ++size_t comp_nb[GGML_MAX_DIMS]; ++size_t check_counter = 0; ++static void ggml_vk_check_results_0(ggml_tensor * tensor) { ++ if (tensor->op == GGML_OP_TRANSPOSE) { ++ return; ++ } ++ ++ check_counter++; ++ if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { ++ return; ++ } ++ ++ VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")"); ++ ++ ggml_tensor * src0 = tensor->src[0]; ++ ggml_tensor * src1 = tensor->src[1]; ++ ggml_tensor * src2 = tensor->src[2]; ++ ggml_tensor * src3 = tensor->src[3]; ++ ++ struct ggml_init_params iparams = { ++ /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, ++ /*.mem_buffer =*/ NULL, ++ /*.no_alloc =*/ false, ++ }; ++ ++ struct ggml_context * ggml_ctx = ggml_init(iparams); ++ ++ struct ggml_tensor * src0_clone = nullptr; ++ struct ggml_tensor * src1_clone = nullptr; ++ struct ggml_tensor * src2_clone = nullptr; ++ struct ggml_tensor * src3_clone = nullptr; ++ struct ggml_tensor * tensor_clone = nullptr; ++ ++ size_t src0_size; ++ size_t src1_size; ++ size_t src2_size; ++ size_t src3_size; ++ ++ void * src0_buffer = nullptr; ++ void * src1_buffer = nullptr; ++ void * src2_buffer = nullptr; ++ void * src3_buffer = nullptr; ++ ++ if (src0 != nullptr) { ++ src0_clone = ggml_dup_tensor(ggml_ctx, src0); ++ ++ src0_size = ggml_nbytes(src0); ++ ++ src0_buffer = malloc(src0_size); ++ src0_clone->data = src0_buffer; ++ if (ggml_backend_buffer_is_host(src0->buffer)) { ++ memcpy(src0_clone->data, src0->data, src0_size); ++ memcpy(src0_clone->nb, src0->nb, 
sizeof(size_t) * GGML_MAX_DIMS); ++ } else if (ggml_backend_buffer_is_vk(src0->buffer)) { ++ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; ++ vk_buffer& buffer_gpu = buf_ctx->dev_buffer; ++ uint64_t offset = vk_tensor_offset(src0) + src0->view_offs; ++ if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { ++ for (int i3 = 0; i3 < src0->ne[3]; i3++) { ++ for (int i2 = 0; i2 < src0->ne[2]; i2++) { ++ const int idx = i3*src0->ne[2] + i2; ++ ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]); ++ } ++ } ++ ++ src0_clone->nb[0] = src0->nb[0]; ++ src0_clone->nb[1] = src0->nb[1]; ++ for (int i = 2; i < GGML_MAX_DIMS; i++) { ++ src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1]; ++ } ++ } else { ++ if (offset + src0_size >= buffer_gpu->size) { ++ src0_size = buffer_gpu->size - offset; ++ } ++ ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size); ++ memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); ++ } ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ ++ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { ++ ggml_vk_print_tensor(src0, "src0"); ++ } ++ } ++ if (src1 != nullptr) { ++ src1_clone = ggml_dup_tensor(ggml_ctx, src1); ++ ++ src1_size = ggml_nbytes(src1); ++ ++ src1_buffer = malloc(src1_size); ++ src1_clone->data = src1_buffer; ++ if (ggml_backend_buffer_is_host(src1->buffer)) { ++ memcpy(src1_clone->data, src1->data, src1_size); ++ memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); ++ } else if (ggml_backend_buffer_is_vk(src1->buffer)) { ++ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; ++ vk_buffer& buffer_gpu = buf_ctx->dev_buffer; ++ uint64_t offset = vk_tensor_offset(src1) + src1->view_offs; ++ if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { ++ for (int i3 = 0; i3 < src1->ne[3]; i3++) { ++ for (int i2 = 0; i2 < src1->ne[2]; i2++) { ++ const int idx = i3*src1->ne[2] + i2; ++ ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]); ++ } ++ } ++ ++ src1_clone->nb[0] = src1->nb[0]; ++ src1_clone->nb[1] = src1->nb[1]; ++ for (int i = 2; i < GGML_MAX_DIMS; i++) { ++ src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1]; ++ } ++ } else { ++ if (offset + src1_size >= buffer_gpu->size) { ++ src1_size = buffer_gpu->size - offset; ++ } ++ ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size); ++ memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); ++ } ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ ++ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { ++ ggml_vk_print_tensor(src1, "src1"); ++ } ++ } ++ if (src2 != nullptr) { ++ src2_clone = ggml_dup_tensor(ggml_ctx, src2); ++ ++ src2_size = ggml_nbytes(src2); ++ ++ src2_buffer = malloc(src2_size); ++ src2_clone->data = src2_buffer; ++ if (ggml_backend_buffer_is_host(src2->buffer)) { ++ memcpy(src2_clone->data, src2->data, src2_size); ++ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); ++ } else if (ggml_backend_buffer_is_vk(src2->buffer)) { ++ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context; ++ vk_buffer& buffer_gpu = buf_ctx->dev_buffer; ++ uint64_t offset = vk_tensor_offset(src2) + src2->view_offs; ++ if (!ggml_is_contiguous(src2) && 
ggml_vk_dim01_contiguous(src2)) { ++ for (int i3 = 0; i3 < src2->ne[3]; i3++) { ++ for (int i2 = 0; i2 < src2->ne[2]; i2++) { ++ const int idx = i3*src2->ne[2] + i2; ++ ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]); ++ } ++ } ++ ++ src2_clone->nb[0] = src2->nb[0]; ++ src2_clone->nb[1] = src2->nb[1]; ++ for (int i = 2; i < GGML_MAX_DIMS; i++) { ++ src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1]; ++ } ++ } else { ++ if (offset + src2_size >= buffer_gpu->size) { ++ src2_size = buffer_gpu->size - offset; ++ } ++ ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size); ++ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); ++ } ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ ++ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { ++ ggml_vk_print_tensor(src2, "src2"); ++ } ++ } ++ if (src3 != nullptr) { ++ src3_clone = ggml_dup_tensor(ggml_ctx, src3); ++ ++ src3_size = ggml_nbytes(src3); ++ ++ src3_buffer = malloc(src3_size); ++ src3_clone->data = src3_buffer; ++ if (ggml_backend_buffer_is_host(src3->buffer)) { ++ memcpy(src3_clone->data, src3->data, src3_size); ++ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); ++ } else if (ggml_backend_buffer_is_vk(src3->buffer)) { ++ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context; ++ vk_buffer& buffer_gpu = buf_ctx->dev_buffer; ++ uint64_t offset = vk_tensor_offset(src3) + src3->view_offs; ++ if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) { ++ for (int i3 = 0; i3 < src3->ne[3]; i3++) { ++ for (int i2 = 0; i2 < src3->ne[2]; i2++) { ++ const int idx = i3*src3->ne[2] + i2; ++ ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]); ++ } ++ } ++ ++ src3_clone->nb[0] = src3->nb[0]; ++ src3_clone->nb[1] = src3->nb[1]; ++ for (int i = 2; i < GGML_MAX_DIMS; i++) { ++ src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1]; ++ } ++ } else { ++ if (offset + src3_size >= buffer_gpu->size) { ++ src3_size = buffer_gpu->size - offset; ++ } ++ ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size); ++ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); ++ } ++ } else { ++ GGML_ABORT("fatal error"); ++ } ++ ++ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { ++ ggml_vk_print_tensor(src3, "src3"); ++ } ++ } ++ ++ if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { ++ const float *params = (const float *)tensor->op_params; ++ tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]); ++ } else if (tensor->op == GGML_OP_MUL_MAT) { ++ tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone); ++ } else if (tensor->op == GGML_OP_MUL_MAT_ID) { ++ tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone); ++ } else if (tensor->op == GGML_OP_MUL) { ++ tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone); ++ } else if (tensor->op == GGML_OP_DIV) { ++ tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone); ++ } else if (tensor->op == GGML_OP_CONCAT) { ++ tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params); ++ } else if (tensor->op == GGML_OP_UPSCALE) { ++ tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); ++ } else if 
(tensor->op == GGML_OP_SCALE) { ++ tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]); ++ } else if (tensor->op == GGML_OP_SQR) { ++ tensor_clone = ggml_sqr(ggml_ctx, src0_clone); ++ } else if (tensor->op == GGML_OP_SIN) { ++ tensor_clone = ggml_sin(ggml_ctx, src0_clone); ++ } else if (tensor->op == GGML_OP_COS) { ++ tensor_clone = ggml_cos(ggml_ctx, src0_clone); ++ } else if (tensor->op == GGML_OP_CLAMP) { ++ tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); ++ } else if (tensor->op == GGML_OP_PAD) { ++ tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]); ++ } else if (tensor->op == GGML_OP_REPEAT) { ++ tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor); ++ } else if (tensor->op == GGML_OP_ADD) { ++ tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); ++ } else if (tensor->op == GGML_OP_ACC) { ++ tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); ++ } else if (tensor->op == GGML_OP_NORM) { ++ tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); ++ } else if (tensor->op == GGML_OP_GROUP_NORM) { ++ tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); ++ } else if (tensor->op == GGML_OP_RMS_NORM) { ++ tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); ++ } else if (tensor->op == GGML_OP_SOFT_MAX) { ++ if (src1 != nullptr) { ++ tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); ++ } else { ++ tensor_clone = ggml_soft_max(ggml_ctx, src0_clone); ++ } ++ } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { ++ tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params); ++ } else if (tensor->op == GGML_OP_ROPE) { ++ const int n_dims = ((int32_t *) tensor->op_params)[1]; ++ const int mode = ((int32_t *) tensor->op_params)[2]; ++ //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; ++ const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; ++ const float freq_base = ((float *) tensor->op_params)[5]; ++ const float freq_scale = ((float *) tensor->op_params)[6]; ++ const float ext_factor = ((float *) tensor->op_params)[7]; ++ const float attn_factor = ((float *) tensor->op_params)[8]; ++ const float beta_fast = ((float *) tensor->op_params)[9]; ++ const float beta_slow = ((float *) tensor->op_params)[10]; ++ tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); ++ } else if (tensor->op == GGML_OP_UNARY) { ++ switch (ggml_get_unary_op(tensor)) { ++ case GGML_UNARY_OP_SILU: ++ tensor_clone = ggml_silu(ggml_ctx, src0_clone); ++ break; ++ case GGML_UNARY_OP_GELU: ++ tensor_clone = ggml_gelu(ggml_ctx, src0_clone); ++ break; ++ case GGML_UNARY_OP_GELU_QUICK: ++ tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone); ++ break; ++ case GGML_UNARY_OP_RELU: ++ tensor_clone = ggml_relu(ggml_ctx, src0_clone); ++ break; ++ case GGML_UNARY_OP_TANH: ++ tensor_clone = ggml_tanh(ggml_ctx, src0_clone); ++ break; ++ default: ++ std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; ++ GGML_ABORT("fatal 
error"); ++ } ++ } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { ++ if (src1 == nullptr) { ++ tensor_clone = ggml_dup(ggml_ctx, src0_clone); ++ tensor_clone->type = tensor->type; ++ } else { ++ tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone); ++ } ++ } else if (tensor->op == GGML_OP_CONT) { ++ tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); ++ } else if (tensor->op == GGML_OP_RESHAPE) { ++ tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); ++ } else if (tensor->op == GGML_OP_VIEW) { ++ tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); ++ } else if (tensor->op == GGML_OP_PERMUTE) { ++ int32_t * params = (int32_t *)tensor->op_params; ++ tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]); ++ } else if (tensor->op == GGML_OP_TRANSPOSE) { ++ tensor_clone = ggml_transpose(ggml_ctx, src0_clone); ++ } else if (tensor->op == GGML_OP_GET_ROWS) { ++ tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone); ++ } else if (tensor->op == GGML_OP_ARGSORT) { ++ tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params); ++ } else if (tensor->op == GGML_OP_SUM_ROWS) { ++ tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone); ++ } else if (tensor->op == GGML_OP_IM2COL) { ++ const int32_t s0 = tensor->op_params[0]; ++ const int32_t s1 = tensor->op_params[1]; ++ const int32_t p0 = tensor->op_params[2]; ++ const int32_t p1 = tensor->op_params[3]; ++ const int32_t d0 = tensor->op_params[4]; ++ const int32_t d1 = tensor->op_params[5]; ++ ++ const bool is_2D = tensor->op_params[6] == 1; ++ tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type); ++ } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { ++ const int32_t dim = tensor->op_params[0]; ++ const int32_t max_period = tensor->op_params[1]; ++ tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); ++ } else if (tensor->op == GGML_OP_POOL_2D) { ++ enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]); ++ const int32_t k0 = tensor->op_params[1]; ++ const int32_t k1 = tensor->op_params[2]; ++ const int32_t s0 = tensor->op_params[3]; ++ const int32_t s1 = tensor->op_params[4]; ++ const int32_t p0 = tensor->op_params[5]; ++ const int32_t p1 = tensor->op_params[6]; ++ ++ tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1); ++ } else if (tensor->op == GGML_OP_LEAKY_RELU) { ++ const float * op_params = (const float *)tensor->op_params; ++ tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); ++ } else if (tensor->op == GGML_OP_RWKV_WKV6) { ++ tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], ++ tensor->src[4], tensor->src[5]); ++ } ++ else { ++ std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; ++ GGML_ABORT("fatal 
malloc(comp_size); ++ memcpy(comp_result, tensor_clone->data, comp_size); ++ memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); ++ ++ if (src0 != nullptr) { ++ free(src0_buffer); ++ } ++ if (src1 != nullptr) { ++ free(src1_buffer); ++ } ++ ++ ggml_free(ggml_ctx); ++ ++ VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); ++} ++ ++static void ggml_vk_check_results_1(ggml_tensor * tensor) { ++ if (tensor->op == GGML_OP_TRANSPOSE) { ++ return; ++ } ++ if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { ++ return; ++ } ++ ++ VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")"); ++ ++ ggml_tensor * src0 = tensor->src[0]; ++ ggml_tensor * src1 = tensor->src[1]; ++ ggml_tensor * src2 = tensor->src[2]; ++ ++ void * tensor_data = tensor->data; ++ ++ if (ggml_backend_buffer_is_vk(tensor->buffer)) { ++ size_t tensor_size = ggml_nbytes(tensor); ++ tensor_data = malloc(tensor_size); ++ ++ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; ++ ++ vk_buffer& buffer_gpu = buf_ctx->dev_buffer; ++ uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs; ++ if (offset + tensor_size >= buffer_gpu->size) { ++ tensor_size = buffer_gpu->size - offset; ++ } ++ ++ ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size); ++ } ++ ++ float first_error_result = -1.0f; ++ float first_error_correct = -1.0f; ++ std::array<int, 4> first_error = { -1, -1, -1, -1 }; ++ double avg_err = 0.0; ++ size_t counter = 0; ++ ++ for (int i3 = 0; i3 < tensor->ne[3]; i3++) { ++ for (int i2 = 0; i2 < tensor->ne[2]; i2++) { ++ for (int i1 = 0; i1 < tensor->ne[1]; i1++) { ++ for (int i0 = 0; i0 < tensor->ne[0]; i0++) { ++ const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size; ++ float correct = 0.0f; ++ float result = 0.0f; ++ ++ if (buffer_size_fit) { ++ if (tensor->type == GGML_TYPE_F32) { ++ correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); ++ result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); ++ } else if (tensor->type == GGML_TYPE_F16) { ++ correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); ++ result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); ++ } else if (tensor->type == GGML_TYPE_I32) { ++ correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); ++ result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); ++ } else { ++ std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; ++ } ++ } else { ++ std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl; ++ GGML_ABORT("fatal error"); ++ } ++ ++ if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) { ++ std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl; ++ std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << 
ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; ++ if (src0 != nullptr) { ++ std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; ++ } ++ if (src1 != nullptr) { ++ std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; ++ } ++ if (src2 != nullptr) { ++ std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; ++ } ++ std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; ++ std::cerr << std::endl << "Result:" << std::endl; ++ ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); ++ std::cerr << std::endl << "Correct:" << std::endl; ++ ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3); ++ std::cerr << std::endl; ++ std::vector<const ggml_tensor *> done; ++ ggml_vk_print_graph_origin(tensor, done); ++ GGML_ABORT("fatal error"); ++ } ++ if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) { ++ first_error[0] = i0; ++ first_error[1] = i1; ++ first_error[2] = i2; ++ first_error[3] = i3; ++ first_error_result = result; ++ first_error_correct = correct; ++ } ++ ++ // Special case, value is infinite, avoid NaN result in avg_err ++ // NaN also appears in results, if both are nan error is 0 ++ if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) { ++ avg_err += std::fabs(correct - result); ++ } ++ counter++; ++ } ++ } ++ } ++ } ++ ++ avg_err /= counter; ++ ++ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { ++ std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; ++ std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; ++ if (src0 != nullptr) { ++ std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << 
src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; ++ } ++ if (src1 != nullptr) { ++ std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; ++ } ++ if (src2 != nullptr) { ++ std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; ++ } ++ std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; ++ std::cerr << std::endl << "Result:" << std::endl; ++ ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); ++ std::cerr << std::endl << "Correct:" << std::endl; ++ ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0); ++ std::cerr << std::endl; ++ std::vector<const ggml_tensor *> done; ++ ggml_vk_print_graph_origin(tensor, done); ++ } ++ ++ if (avg_err > 0.05 || std::isnan(avg_err)) { ++ std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; ++ std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; ++ if (src0 != nullptr) { ++ std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; ++ } ++ if (src1 != nullptr) { ++ std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; ++ } ++ if (src2 != nullptr) { ++ std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; ++ } ++ std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << 
std::endl; ++ std::cerr << std::endl << "Result:" << std::endl; ++ ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); ++ std::cerr << std::endl << "Correct:" << std::endl; ++ ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]); ++ std::cerr << std::endl; ++ std::vector<const ggml_tensor *> done; ++ ggml_vk_print_graph_origin(tensor, done); ++ GGML_ABORT("fatal error"); ++ } else { ++ std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl; ++ } ++ ++ free(comp_result); ++ comp_result = nullptr; ++ comp_size = 0; ++ ++ if (ggml_backend_buffer_is_vk(tensor->buffer)) { ++ free(tensor_data); ++ } ++ ++ VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")"); ++} ++#endif ++ ++GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg) +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +new file mode 100644 +index 00000000..bd0c74cb +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +@@ -0,0 +1,9 @@ ++find_package (Threads REQUIRED) ++find_package(Vulkan COMPONENTS glslc REQUIRED) ++ ++set(TARGET vulkan-shaders-gen) ++add_executable(${TARGET} vulkan-shaders-gen.cpp) ++install(TARGETS ${TARGET} RUNTIME) ++target_compile_features(${TARGET} PRIVATE cxx_std_17) ++target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) ++target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan) +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +new file mode 100644 +index 00000000..d896f1ef +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +@@ -0,0 +1,29 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_binary_head.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ const uint idx = gl_GlobalInvocationID.x; ++ if (idx >= p.ne) { ++ return; ++ } ++ ++ const uint offset = p.param3; ++ const uint src1_i = idx - offset; ++ const uint oz = src1_i / p.nb02; ++ const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; ++ const uint ox = src1_i % p.nb01; ++ ++ uint i00, i01, i02, i03; ++ get_indices(idx, i00, i01, i02, i03); ++ ++ if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { ++ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); ++ } else { ++ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); ++ } ++} ++ +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +new file mode 100644 +index 00000000..2b4085c4 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +@@ -0,0 +1,29 @@ ++#version 450 ++ ++#extension GL_EXT_shader_16bit_storage : require ++ ++#include "types.comp" ++#include "generic_binary_head.comp" ++ ++const uint num_threads = 256; ++ ++layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ uint idx = get_idx(); ++ ++ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation ++ const uint num_iter = 2; ++ ++ 
[[unroll]] for (uint i = 0; i < num_iter; ++i) { ++ if (idx >= p.ne) { ++ continue; ++ } ++ uint i00, i01, i02, i03; ++ get_indices(idx, i00, i01, i02, i03); ++ ++ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); ++ ++ idx += num_threads; ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +new file mode 100644 +index 00000000..d4fa45b1 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +@@ -0,0 +1,69 @@ ++#version 450 ++ ++#include "types.comp" ++ ++#define BLOCK_SIZE 1024 ++#define ASC 0 ++ ++layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) buffer D {int data_d[];}; ++ ++layout (push_constant) uniform parameter { ++ uint ncols; ++ uint ncols_pad; ++ uint order; ++} p; ++ ++shared int dst_row[BLOCK_SIZE]; ++ ++void swap(uint idx0, uint idx1) { ++ int tmp = dst_row[idx0]; ++ dst_row[idx0] = dst_row[idx1]; ++ dst_row[idx1] = tmp; ++} ++ ++void main() { ++ // bitonic sort ++ const int col = int(gl_LocalInvocationID.x); ++ const uint row = gl_WorkGroupID.y; ++ ++ const uint row_offset = row * p.ncols; ++ ++ // initialize indices ++ if (col < p.ncols_pad) { ++ dst_row[col] = col; ++ } ++ barrier(); ++ ++ for (uint k = 2; k <= p.ncols_pad; k *= 2) { ++ for (uint j = k / 2; j > 0; j /= 2) { ++ const uint ixj = col ^ j; ++ if (col < p.ncols_pad && ixj > col) { ++ if ((col & k) == 0) { ++ if (dst_row[col] >= p.ncols || ++ (dst_row[ixj] < p.ncols && (p.order == ASC ? ++ data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] : ++ data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]])) ++ ) { ++ swap(col, ixj); ++ } ++ } else { ++ if (dst_row[ixj] >= p.ncols || ++ (dst_row[col] < p.ncols && (p.order == ASC ? ++ data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] : ++ data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]])) ++ ) { ++ swap(col, ixj); ++ } ++ } ++ } ++ barrier(); ++ } ++ } ++ ++ if (col < p.ncols) { ++ data_d[row_offset + col] = dst_row[col]; ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +new file mode 100644 +index 00000000..1e5cb8da +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +@@ -0,0 +1,17 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_unary_head.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ const uint idx = get_idx(); ++ ++ if (idx >= p.ne) { ++ return; ++ } ++ ++ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); ++ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? 
p.param2 : val)); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +new file mode 100644 +index 00000000..9ee2f1fa +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +@@ -0,0 +1,41 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_binary_head.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++ const int dim = p.param3; ++ ++ if (idx >= p.ne) { ++ return; ++ } ++ ++ const uint i3 = idx / (p.ne22*p.ne21*p.ne20); ++ const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20; ++ const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20); ++ const uint i2_offset = i2*p.ne21*p.ne20; ++ const uint i1 = (idx - i3_offset - i2_offset) / p.ne20; ++ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20; ++ ++ uint o[4] = {0, 0, 0, 0}; ++ o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03)); ++ ++ const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; ++ const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10; ++ const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20; ++ ++ const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; ++ ++#ifndef OPTIMIZATION_ERROR_WORKAROUND ++ data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]); ++#else ++ if (is_src0) { ++ data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx]; ++ } else { ++ data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx]; ++ } ++#endif ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +new file mode 100644 +index 00000000..dd828c23 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +@@ -0,0 +1,42 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_unary_head.comp" ++ ++#extension GL_EXT_control_flow_attributes : require ++ ++const uint num_threads = 128; ++ ++layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ uint idx = get_idx(); ++ ++ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation ++ const uint num_iter = 4; ++ ++ // fast path for when all four iterations are in-bounds ++ if (idx + (num_iter-1)*num_threads < p.ne) { ++ [[unroll]] for (uint i = 0; i < num_iter; ++i) { ++#ifndef OPTIMIZATION_ERROR_WORKAROUND ++ data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); ++#else ++ data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; ++#endif ++ idx += num_threads; ++ } ++ } else { ++ [[unroll]] for (uint i = 0; i < num_iter; ++i) { ++ if (idx >= p.ne) { ++ continue; ++ } ++ ++#ifndef OPTIMIZATION_ERROR_WORKAROUND ++ data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); ++#else ++ data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; ++#endif ++ idx += num_threads; ++ } ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +new file mode 100644 +index 00000000..29c90649 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +@@ -0,0 +1,20 @@ 
++#version 450 ++ ++#include "types.comp" ++#include "generic_unary_head.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ const uint idx = get_idx(); ++ ++ if (idx >= p.ne) { ++ return; ++ } ++ ++#ifndef OPTIMIZATION_ERROR_WORKAROUND ++ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); ++#else ++ data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; ++#endif ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +new file mode 100644 +index 00000000..0b8d02f5 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +@@ -0,0 +1,17 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_unary_head.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ const uint idx = get_idx(); ++ ++ if (idx >= p.ne) { ++ return; ++ } ++ ++ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); ++ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +new file mode 100644 +index 00000000..a4d3fca5 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +@@ -0,0 +1,20 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {float data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ const uint i = gl_GlobalInvocationID.x * 16; ++ ++ if (i >= p.nel) { ++ return; ++ } ++ ++ [[unroll]] for (uint l = 0; l < 16; l++) { ++ data_b[i + l] = D_TYPE(data_a[i + l]); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +new file mode 100644 +index 00000000..91bb8f8d +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +@@ -0,0 +1,118 @@ ++#if !defined(DATA_A_F32) && !defined(DATA_A_F16) ++#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require ++#endif ++ ++#include "types.comp" ++ ++#if defined(A_TYPE_PACKED16) ++layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; ++#endif ++#if defined(A_TYPE_PACKED32) ++layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; ++#endif ++ ++#if defined(DATA_A_F32) ++vec2 dequantize(uint ib, uint iqs, uint a_offset) { ++ return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); ++} ++#endif ++ ++#if defined(DATA_A_F16) ++vec2 dequantize(uint ib, uint iqs, uint a_offset) { ++ return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); ++} ++#endif ++ ++#if defined(DATA_A_Q4_0) ++vec2 dequantize(uint ib, uint iqs, uint a_offset) { ++ const uint vui = uint(data_a[a_offset + ib].qs[iqs]); ++ return (vec2(vui & 0xF, vui >> 4) - 8.0f); ++} ++vec4 dequantize4(uint ib, uint iqs, uint a_offset) { ++ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); ++ return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); ++} ++#endif ++ ++#if defined(DATA_A_Q4_1) ++vec2 dequantize(uint ib, uint iqs, uint a_offset) { ++ const uint vui = uint(data_a[a_offset + ib].qs[iqs]); ++ return vec2(vui & 0xF, vui >> 
4); ++} ++vec4 dequantize4(uint ib, uint iqs, uint a_offset) { ++ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); ++ return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); ++} ++#endif ++ ++#if defined(DATA_A_Q5_0) ++vec2 dequantize(uint ib, uint iqs, uint a_offset) { ++ const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; ++ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); ++ const uint vui = uint(data_a[a_offset + ib].qs[iqs]); ++ return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f); ++} ++vec4 dequantize4(uint ib, uint iqs, uint a_offset) { ++ const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0]; ++ const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); ++ const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); ++ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); ++ return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f); ++} ++#endif ++ ++#if defined(DATA_A_Q5_1) ++vec2 dequantize(uint ib, uint iqs, uint a_offset) { ++ const uint uint_qh = data_a[a_offset + ib].qh; ++ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); ++ const uint vui = uint(data_a[a_offset + ib].qs[iqs]); ++ return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); ++} ++vec4 dequantize4(uint ib, uint iqs, uint a_offset) { ++ const uint uint_qh = data_a_packed16[a_offset + ib].qh; ++ const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); ++ const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); ++ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); ++ return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y); ++} ++#endif ++ ++#if defined(DATA_A_Q8_0) ++vec2 dequantize(uint ib, uint iqs, uint a_offset) { ++ return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); ++} ++vec4 dequantize4(uint ib, uint iqs, uint a_offset) { ++ uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2]; ++ uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1]; ++ return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8)); ++} ++#endif ++ ++#if defined(DATA_A_IQ4_NL) ++vec2 dequantize(uint ib, uint iqs, uint a_offset) { ++ const uint vui = uint(data_a[a_offset + ib].qs[iqs]); ++ return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]); ++} ++vec4 dequantize4(uint ib, uint iqs, uint a_offset) { ++ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); ++ return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]); ++} ++#endif ++ ++#if defined(DATA_A_F32) || defined(DATA_A_F16) ++vec2 get_dm(uint ib, uint a_offset) { ++ return vec2(0, 0); ++} ++#endif ++ ++#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) ++vec2 get_dm(uint ib, uint a_offset) { ++ return vec2(float(data_a[a_offset + ib].d), 0); ++} ++#endif ++ ++#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) ++vec2 get_dm(uint ib, uint a_offset) { ++ return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); ++} ++#endif +diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +new file mode 100644 +index 00000000..94b78598 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +@@ -0,0 +1,325 @@ ++ ++#include "types.comp" ++ ++layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { ++ block_q4_0_packed16 block; ++}; ++ ++float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ const float16_t d = bl.block.d; ++ const uint idx = coordInBlock[1]; ++ const uint shift = (idx & 0x10) >> 2; ++ uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]); ++ qs >>= shift; ++ qs &= 0x0F0F; ++ qs = unpack8(qs)[idx & 1]; ++ float16_t ret = (float16_t(qs) - float16_t(8)) * d; ++ return ret; ++} ++ ++layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { ++ block_q4_1 block; ++}; ++ ++float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ const float16_t d = bl.block.d; ++ const float16_t m = bl.block.m; ++ const uint idx = coordInBlock[1]; ++ const uint iqs = idx & 0xF; ++ const uint shift = (idx & 0x10) >> 2; ++ uint32_t qs = bl.block.qs[iqs]; ++ qs >>= shift; ++ qs &= 0xF; ++ float16_t ret = float16_t(qs) * d + m; ++ return ret; ++} ++ ++layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { ++ block_q5_0 block; ++}; ++ ++float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ const float16_t d = bl.block.d; ++ const uint idx = coordInBlock[1]; ++ const uint iqs = idx & 0xF; ++ ++ const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0]; ++ const uint qh = ((uint_qh >> idx) << 4) & 0x10; ++ ++ const uint shift = (idx & 0x10) >> 2; ++ uint32_t qs = bl.block.qs[iqs]; ++ qs >>= shift; ++ qs &= 0xF; ++ ++ float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d; ++ return ret; ++} ++ ++layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { ++ block_q5_1 block; ++}; ++ ++float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ const float16_t d = bl.block.d; ++ const float16_t m = bl.block.m; ++ const uint idx = coordInBlock[1]; ++ const uint iqs = idx & 0xF; ++ ++ const uint uint_qh = bl.block.qh; ++ const uint qh = ((uint_qh >> idx) << 4) & 0x10; ++ ++ const uint shift = (idx & 0x10) >> 2; ++ uint32_t qs = bl.block.qs[iqs]; ++ qs >>= shift; ++ qs &= 0xF; ++ ++ float16_t ret = float16_t(qs | qh) * d + m; ++ return ret; ++} ++ ++layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { ++ block_q8_0_packed16 block; ++}; ++ ++float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ const float16_t d = bl.block.d; ++ const uint idx = coordInBlock[1]; ++ const uint iqs = idx; ++ ++ // Load 16b and select the byte for this element ++ int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1]; ++ float16_t ret = float16_t(qs) * d; ++ return ret; ++} ++ ++layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { ++ block_q2_K block; ++}; ++ ++float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ const f16vec2 d = bl.block.d; ++ 
const uint idx = coordInBlock[1]; ++ const uint iqs = idx; ++ ++ const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31 ++ const uint scalesi = iqs / 16; // 0..15 ++ const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6 ++ ++ uint32_t qs = bl.block.qs[qsi]; ++ const uint scales = bl.block.scales[scalesi]; ++ float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4); ++ return ret; ++} ++ ++layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { ++ block_q3_K block; ++}; ++ ++float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ const uint idx = coordInBlock[1]; ++ const uint iqs = idx; ++ ++ const uint n = iqs / 128; // 0,1 ++ const uint qsi = n * 32 + (iqs % 32); // 0..63 ++ const uint hmi = (iqs % 32); // 0..31 ++ const uint j = (iqs % 128) / 8; // 0..15 ++ const uint is = iqs / 16; // 0..15 ++ const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3 ++ const uint qsshift = halfsplit * 2; // 0,2,4,6 ++ const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 ++ ++ uint32_t scaleidx0 = (is < 8) ? is : (is-8); ++ uint32_t scaleidx0shift = (is < 8) ? 0 : 4; ++ uint32_t scaleidx1 = is + 8 - (is/4)*4; ++ uint32_t scaleidx1shift = (is/4)*2; ++ ++ const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); ++ ++ const float16_t dl = bl.block.d * float16_t(us - 32); ++ ++ float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4)); ++ ++ return ret; ++} ++ ++layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { ++ block_q4_K block; ++}; ++ ++layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 { ++ block_q4_K_packed16 block; ++}; ++ ++float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); ++ const uint idx = coordInBlock[1]; ++ ++ const uint b = (idx & 0x20) >> 5; // 0,1 ++ const uint is = (idx & 0xE0) >> 5; // 0..7 ++ ++ const f16vec2 loadd = bl.block.d; ++ ++ uint32_t sc; ++ uint32_t mbyte; ++ ++ uint32_t scidx0 = (is < 4) ? is : (is + 4); ++ uint32_t scidx1 = (is < 4) ? is : (is - 4); ++ uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ uint32_t scidxshift1 = (is < 4) ? 0 : 2; ++ uint32_t mbidx0 = is + 4; ++ uint32_t mbidx1 = (is < 4) ? is + 4 : is; ++ uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; ++ uint32_t mbidxshift0 = (is < 4) ? 0 : 4; ++ uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ uint32_t mbidxshift1 = (is < 4) ? 
0 : 2; ++ ++ sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); ++ mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); ++ ++ const float16_t d = loadd.x * float16_t(sc); ++ const float16_t m = loadd.y * float16_t(mbyte); ++ ++ uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); ++ qs = (qs >> (b * 4)) & 0x0F0F; ++ qs = unpack8(qs)[idx & 1]; ++ ++ float16_t ret = d * float16_t(qs) - m; ++ ++ return ret; ++} ++ ++layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { ++ block_q5_K block; ++}; ++ ++layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 { ++ block_q5_K_packed16 block; ++}; ++ ++float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); ++ const uint idx = coordInBlock[1]; ++ ++ const uint b = (idx & 0x20) >> 5; // 0,1 ++ const uint is = (idx & 0xE0) >> 5; // 0..7 ++ ++ const uint32_t hm = 0x0101 << is; ++ ++ const f16vec2 loadd = bl.block.d; ++ ++ uint32_t sc; ++ uint32_t mbyte; ++ ++ uint32_t scidx0 = (is < 4) ? is : (is + 4); ++ uint32_t scidx1 = (is < 4) ? is : (is - 4); ++ uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ uint32_t scidxshift1 = (is < 4) ? 0 : 2; ++ uint32_t mbidx0 = is + 4; ++ uint32_t mbidx1 = (is < 4) ? is + 4 : is; ++ uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; ++ uint32_t mbidxshift0 = (is < 4) ? 0 : 4; ++ uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ uint32_t mbidxshift1 = (is < 4) ? 0 : 2; ++ ++ sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); ++ mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); ++ ++ const float16_t d = loadd.x * float16_t(sc); ++ const float16_t m = loadd.y * float16_t(mbyte); ++ ++ uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); ++ qh = qh & hm; ++ qh = unpack8(qh)[idx & 1]; ++ ++ uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); ++ qs = (qs >> (b * 4)) & 0x0F0F; ++ qs = unpack8(qs)[idx & 1]; ++ ++ float16_t ret = d * (float16_t(qs) + (qh != 0 ? 
float16_t(16) : float16_t(0))) - m; ++ ++ return ret; ++} ++ ++layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { ++ block_q6_K block; ++}; ++ ++layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 { ++ block_q6_K_packed16 block; ++}; ++ ++float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); ++ const uint idx = coordInBlock[1]; ++ ++ const uint b = (idx & 0x40) >> 6; // 0,1 ++ const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 ++ const uint is = (idx & 0xF0) >> 4; // 0..15 ++ ++ const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); ++ ++ uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]); ++ ql = (ql >> (b * 4)) & 0x0F0F; ++ ++ uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); ++ qh = ((qh >> qhshift) & 0x0303) << 4; ++ ++ int q = unpack8(ql | qh)[idx & 1]; ++ ++ float16_t ret = dscale * float16_t(q - 32); ++ ++ return ret; ++} ++ ++#if defined(DATA_A_IQ4_NL) ++layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { ++ block_iq4_nl block; ++}; ++ ++float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ const float16_t d = bl.block.d; ++ const uint idx = coordInBlock[1]; ++ const uint iqs = idx & 0xF; ++ const uint shift = (idx & 0x10) >> 2; ++ uint32_t qs = bl.block.qs[iqs]; ++ qs >>= shift; ++ qs &= 0xF; ++ float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; ++ return ret; ++} ++#endif ++ ++#if defined(DATA_A_Q4_0) ++#define dequantFuncA dequantFuncQ4_0 ++#elif defined(DATA_A_Q4_1) ++#define dequantFuncA dequantFuncQ4_1 ++#elif defined(DATA_A_Q5_0) ++#define dequantFuncA dequantFuncQ5_0 ++#elif defined(DATA_A_Q5_1) ++#define dequantFuncA dequantFuncQ5_1 ++#elif defined(DATA_A_Q8_0) ++#define dequantFuncA dequantFuncQ8_0 ++#elif defined(DATA_A_Q2_K) ++#define dequantFuncA dequantFuncQ2_K ++#elif defined(DATA_A_Q3_K) ++#define dequantFuncA dequantFuncQ3_K ++#elif defined(DATA_A_Q4_K) ++#define dequantFuncA dequantFuncQ4_K ++#elif defined(DATA_A_Q5_K) ++#define dequantFuncA dequantFuncQ5_K ++#elif defined(DATA_A_Q6_K) ++#define dequantFuncA dequantFuncQ6_K ++#elif defined(DATA_A_IQ4_NL) ++#define dequantFuncA dequantFuncIQ4_NL ++#endif +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +new file mode 100644 +index 00000000..8d806435 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +@@ -0,0 +1,13 @@ ++#extension GL_EXT_control_flow_attributes : require ++#extension GL_EXT_shader_16bit_storage : require ++ ++layout (push_constant) uniform parameter ++{ ++ uint M; ++ uint K; ++ uint stride_a; ++ uint stride_b; ++ uint nel; ++} p; ++ ++#include "types.comp" +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +new file mode 100644 +index 00000000..8de14fc0 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +@@ -0,0 +1,32 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];}; ++layout (binding = 1) writeonly 
buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; ++ ++ init_iq4nl_shmem(); ++ ++ const uint tid = gl_LocalInvocationID.x % 64; ++ const uint il = tid/32; ++ const uint ir = tid%32; ++ const uint ib = 32*i + ir; ++ if (ib >= p.nel / 32) { ++ return; ++ } ++ ++ const uint q_idx = 8*il; ++ const uint b_idx = 1024*i + 32*ir + q_idx; ++ ++ const float d = float(data_a[ib].d); ++ ++ [[unroll]] for (uint l = 0; l < 8; ++l) { ++ data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); ++ data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +new file mode 100644 +index 00000000..157154af +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +@@ -0,0 +1,34 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { ++ const uint i = gl_WorkGroupID.x * 256 + wgy; ++ if (i >= p.M * p.K / QUANT_K) { ++ return; ++ } ++ ++ const uint tid = gl_LocalInvocationID.x; ++ const uint ip = tid / 32; ++ const uint il = tid - 32 * ip; ++ const uint is = 8 * ip + il / 16; ++ ++ const uint y_idx = i * QUANT_K + 128 * ip + il; ++ ++ const uint ql_idx = 32 * ip + il; ++ const uint8_t qs = data_a[i].qs[32 * ip + il]; ++ ++ FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); ++ FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); ++ data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); ++ data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); ++ data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); ++ data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4)); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +new file mode 100644 +index 00000000..c17dd0d9 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +@@ -0,0 +1,42 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { ++ const uint i = uint(gl_WorkGroupID.x * 256 + wgy); ++ if (i >= p.M * p.K / QUANT_K) { ++ return; ++ } ++ ++ const uint r = gl_LocalInvocationID.x / 4; ++ const uint tid = r / 2; ++ const uint is0 = r % 2; ++ const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4); ++ const uint n = tid / 4; ++ const uint j = tid - 4*n; ++ ++ const uint8_t m = uint8_t(1 << (4*n + j)); ++ const uint is = 8*n + 2*j + is0; ++ const uint shift = 2*j; ++ ++ const int8_t us = int8_t(is < 4 ? 
(data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : ++ is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : ++ is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) : ++ (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4)); ++ const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); ++ const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); ++ ++ const uint y_idx = i * QUANT_K + 128 * n + 32 * j; ++ const uint qs_idx = 32*n; ++ ++ for (uint l = l0; l < l0 + 4; ++l) { ++ data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4))); ++ } ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +new file mode 100644 +index 00000000..40818532 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +@@ -0,0 +1,30 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {block_q4_0 data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; ++ ++ const uint tid = gl_LocalInvocationID.x % 64; ++ const uint il = tid/32; ++ const uint ir = tid%32; ++ const uint ib = 32*i + ir; ++ if (ib >= p.nel / 32) { ++ return; ++ } ++ ++ const uint q_idx = 8*il; ++ const uint b_idx = 1024*i + 32*ir + q_idx; ++ ++ const float d = float(data_a[ib].d); ++ ++ [[unroll]] for (uint l = 0; l < 8; ++l) { ++ data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f)); ++ data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f)); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +new file mode 100644 +index 00000000..2f27eee6 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +@@ -0,0 +1,32 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {block_q4_1 data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; ++ ++ const uint tid = gl_LocalInvocationID.x % 64; ++ const uint il = tid/32; ++ const uint ir = tid%32; ++ const uint ib = 32*i + ir; ++ if (ib >= p.nel / 32) { ++ return; ++ } ++ ++ const uint b_idx = 1024*i + 32*ir + 8*il; ++ ++ const float d = float(data_a[ib].d); ++ const float m = float(data_a[ib].m); ++ ++ const uint q_idx = 8*il; ++ ++ [[unroll]] for (uint l = 0; l < 8; ++l) { ++ data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m); ++ data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +new file mode 100644 +index 00000000..987f113a +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +@@ -0,0 +1,68 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; ++ ++layout 
(binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { ++ const uint ib = gl_WorkGroupID.x * 256 + wgy; ++ if (ib >= p.M * p.K / QUANT_K) { ++ return; ++ } ++ ++ const uint tid = gl_LocalInvocationID.x; ++ const uint il = tid / 8; ++ const uint ir = tid % 8; ++ const uint is = 2 * il; ++ const uint n = 4; ++ ++ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); ++ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); ++ ++ const uint y_idx = ib * QUANT_K + 64 * il + n * ir; ++ const uint qs_idx = 32*il + n * ir; ++ ++ uint scidx0 = (is < 4) ? is : (is + 4); ++ uint scidx1 = (is < 4) ? is : (is - 4); ++ uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ uint scidxshift1 = (is < 4) ? 0 : 2; ++ uint mbidx0 = is + 4; ++ uint mbidx1 = (is < 4) ? is + 4 : is; ++ uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; ++ uint mbidxshift0 = (is < 4) ? 0 : 4; ++ uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ uint mbidxshift1 = (is < 4) ? 0 : 2; ++ ++ uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); ++ uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); ++ ++ const FLOAT_TYPE d1 = dall * sc; ++ const FLOAT_TYPE m1 = dmin * mbyte; ++ ++ scidx0 = (is < 4) ? is + 1 : (is + 5); ++ scidx1 = (is < 4) ? is + 1 : (is - 3); ++ scidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ scidxshift1 = (is < 4) ? 0 : 2; ++ mbidx0 = is + 5; ++ mbidx1 = (is < 4) ? is + 5 : is + 1; ++ mbidxmask0 = (is < 4) ? 0xF : 0xF0; ++ mbidxshift0 = (is < 4) ? 0 : 4; ++ mbidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ mbidxshift1 = (is < 4) ? 0 : 2; ++ ++ sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); ++ mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); ++ ++ const FLOAT_TYPE d2 = dall * sc; ++ const FLOAT_TYPE m2 = dmin * mbyte; ++ ++ [[unroll]] for (uint l = 0; l < n; ++l) { ++ data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1); ++ data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2); ++ } ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +new file mode 100644 +index 00000000..b20b8052 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +@@ -0,0 +1,34 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {block_q5_0 data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; ++ ++ const uint tid = gl_LocalInvocationID.x % 64; ++ const uint il = tid/32; ++ const uint ir = tid%32; ++ const uint ib = 32*i + ir; ++ if (ib >= p.nel / 32) { ++ return; ++ } ++ ++ const uint b_idx = 1024*i + 32*ir + 8*il; ++ ++ const float d = float(data_a[ib].d); ++ const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; ++ ++ const uint q_idx = 8*il; ++ ++ [[unroll]] for (uint l = 0; l < 8; ++l) { ++ const uint iqs = q_idx + l; ++ const uint vui = uint(data_a[ib].qs[iqs]); ++ data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> 
iqs) << 4) & 0x10)) - 16.0f)); ++ data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f)); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +new file mode 100644 +index 00000000..dc59fe3b +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +@@ -0,0 +1,35 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {block_q5_1 data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; ++ ++ const uint tid = gl_LocalInvocationID.x % 64; ++ const uint il = tid/32; ++ const uint ir = tid%32; ++ const uint ib = 32*i + ir; ++ if (ib >= p.nel / 32) { ++ return; ++ } ++ ++ const uint b_idx = 1024*i + 32*ir + 8*il; ++ ++ const float d = float(data_a[ib].d); ++ const float m = float(data_a[ib].m); ++ const uint qh = data_a[ib].qh; ++ ++ const uint q_idx = 8*il; ++ ++ [[unroll]] for (uint l = 0; l < 8; ++l) { ++ const uint iqs = q_idx + l; ++ const uint vui = uint(data_a[ib].qs[iqs]); ++ data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m); ++ data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +new file mode 100644 +index 00000000..6db5403b +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +@@ -0,0 +1,70 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { ++ const uint ib = gl_WorkGroupID.x * 256 + wgy; ++ if (ib >= p.M * p.K / QUANT_K) { ++ return; ++ } ++ ++ const uint tid = gl_LocalInvocationID.x; ++ const uint il = tid / 16; ++ const uint ir = tid % 16; ++ const uint is = 2 * il; ++ ++ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); ++ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); ++ ++ const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; ++ const uint qs_idx = 32*il + 2 * ir; ++ const uint qh_idx = 2 * ir; ++ ++ uint scidx0 = (is < 4) ? is : (is + 4); ++ uint scidx1 = (is < 4) ? is : (is - 4); ++ uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ uint scidxshift1 = (is < 4) ? 0 : 2; ++ uint mbidx0 = is + 4; ++ uint mbidx1 = (is < 4) ? is + 4 : is; ++ uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; ++ uint mbidxshift0 = (is < 4) ? 0 : 4; ++ uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ uint mbidxshift1 = (is < 4) ? 0 : 2; ++ ++ uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); ++ uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); ++ ++ const FLOAT_TYPE d1 = dall * sc; ++ const FLOAT_TYPE m1 = dmin * mbyte; ++ ++ scidx0 = (is < 4) ? is + 1 : (is + 5); ++ scidx1 = (is < 4) ? is + 1 : (is - 3); ++ scidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ scidxshift1 = (is < 4) ? 0 : 2; ++ mbidx0 = is + 5; ++ mbidx1 = (is < 4) ? 
is + 5 : is + 1; ++ mbidxmask0 = (is < 4) ? 0xF : 0xF0; ++ mbidxshift0 = (is < 4) ? 0 : 4; ++ mbidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ mbidxshift1 = (is < 4) ? 0 : 2; ++ ++ sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); ++ mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); ++ ++ const FLOAT_TYPE d2 = dall * sc; ++ const FLOAT_TYPE m2 = dmin * mbyte; ++ ++ const uint8_t hm1 = uint8_t(1 << (2 * il )); ++ const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); ++ data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); ++ data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); ++ data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); ++ data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +new file mode 100644 +index 00000000..0b913175 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +@@ -0,0 +1,33 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { ++ const uint i = gl_WorkGroupID.x * 256 + wgy; ++ if (i >= p.M * p.K / QUANT_K) { ++ return; ++ } ++ const uint tid = gl_LocalInvocationID.x; ++ const uint ip = tid / 32; ++ const uint il = tid - 32 * ip; ++ const uint is = 8 * ip + il / 16; ++ ++ const uint y_idx = i * QUANT_K + 128 * ip + il; ++ ++ const uint ql_idx = 64 * ip + il; ++ const uint8_t qh = data_a[i].qh[32 * ip + il]; ++ ++ const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); ++ ++ data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); ++ data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); ++ data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); ++ data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +new file mode 100644 +index 00000000..bd1344a8 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +@@ -0,0 +1,31 @@ ++#version 450 ++ ++#include "dequant_head.comp" ++ ++layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {block_q8_0 data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; ++ ++void main() { ++ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; ++ ++ const uint tid = gl_LocalInvocationID.x % 64; ++ const 
uint il = tid/32; ++ const uint ir = tid%32; ++ const uint ib = 32*i + ir; ++ if (ib >= p.nel / 32) { ++ return; ++ } ++ ++ const uint b_idx = 1024*i + 32*ir + 16*il; ++ ++ const float d = float(data_a[ib].d); ++ ++ const uint q_idx = 16*il; ++ ++ [[unroll]] for (uint l = 0; l < 16; l += 2) { ++ data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]); ++ data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +new file mode 100644 +index 00000000..4e68742b +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +@@ -0,0 +1,34 @@ ++#version 450 ++ ++#extension GL_EXT_shader_16bit_storage : require ++#extension GL_EXT_control_flow_attributes : enable ++ ++layout (push_constant) uniform parameter ++{ ++ uint ncols; ++ uint rows_per_channel; ++ uint n_past; ++} p; ++ ++#include "types.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const uint col = gl_GlobalInvocationID.y; ++ const uint row = gl_GlobalInvocationID.x; ++ ++ if (col >= p.ncols) { ++ return; ++ } ++ ++ const uint i = row*p.ncols + col; ++ if (col > p.n_past + row % p.rows_per_channel) { ++ data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000)); ++ } else { ++ data_d[i] = D_TYPE(data_a[i]); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +new file mode 100644 +index 00000000..9fb69c6c +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +@@ -0,0 +1,27 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_binary_head.comp" ++ ++const uint num_threads = 256; ++ ++layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ uint idx = get_idx(); ++ ++ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation ++ const uint num_iter = 2; ++ ++ [[unroll]] for (uint i = 0; i < num_iter; ++i) { ++ if (idx >= p.ne) { ++ continue; ++ } ++ uint i00, i01, i02, i03; ++ get_indices(idx, i00, i01, i02, i03); ++ ++ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); ++ ++ idx += num_threads; ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +new file mode 100644 +index 00000000..c5be8131 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +@@ -0,0 +1,289 @@ ++#version 450 ++ ++#extension GL_EXT_control_flow_attributes : enable ++#extension GL_EXT_shader_16bit_storage : require ++ ++#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require ++#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require ++#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require ++#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require ++ ++#extension GL_KHR_memory_scope_semantics : enable ++#extension GL_KHR_cooperative_matrix : enable ++#extension GL_NV_cooperative_matrix2 : enable ++#extension GL_EXT_buffer_reference : enable ++#extension 
GL_KHR_shader_subgroup_ballot : enable ++#extension GL_KHR_shader_subgroup_vote : enable ++#extension GL_EXT_null_initializer : enable ++ ++#include "types.comp" ++#include "dequant_funcs_cm2.comp" ++ ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++layout (constant_id = 1) const uint32_t Br = 32; ++layout (constant_id = 2) const uint32_t Bc = 32; ++layout (constant_id = 3) const uint32_t D = 32; ++layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; ++ ++layout (push_constant) uniform parameter { ++ uint32_t N; ++ uint32_t KV; ++ ++ uint32_t ne1; ++ uint32_t ne2; ++ uint32_t ne3; ++ ++ uint32_t neq2; ++ uint32_t neq3; ++ uint32_t nek2; ++ uint32_t nek3; ++ uint32_t nev2; ++ uint32_t nev3; ++ uint32_t nem1; ++ ++ uint32_t nb02; ++ uint32_t nb03; ++ uint32_t nb12; ++ uint32_t nb13; ++ uint32_t nb22; ++ uint32_t nb23; ++ uint32_t nb31; ++ ++ float scale; ++ float max_bias; ++ float logit_softcap; ++ ++ uint32_t mask; ++ uint32_t n_head_log2; ++ float m0; ++ float m1; ++} p; ++ ++layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; ++layout (binding = 1) readonly buffer K {uint8_t data_k[];}; ++layout (binding = 2) readonly buffer V {uint8_t data_v[];}; ++layout (binding = 3) readonly buffer M {uint8_t data_m[];}; ++layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; ++ ++#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) ++ ++ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { ++ return max(x, y); ++} ++ ++ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { ++ return x; ++} ++ ++// Replace matrix elements >= numRows or numCols with 'replace' ++ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) { ++ if (row >= numRows || col >= numCols) { ++ return replace; ++ } ++ return elem; ++} ++ ++ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) ++{ ++ return exp(elem); ++} ++ ++ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) ++{ ++ return max(elem0, elem1); ++} ++ ++#if defined(BLOCK_SIZE) ++#define DECODEFUNC , DEQUANTFUNC ++#else ++#define DECODEFUNC ++#endif ++ ++void main() { ++#if defined(DATA_A_IQ4_NL) ++ init_iq4nl_shmem(); ++#endif ++ ++ const uint32_t N = p.N; ++ const uint32_t KV = p.KV; ++ ++ const uint32_t Tr = CEIL_DIV(N, Br); ++ const uint32_t Tc = CEIL_DIV(KV, Bc); ++ ++ const uint32_t i = gl_WorkGroupID.x; ++ ++ const uint32_t iq2 = gl_WorkGroupID.y; ++ const uint32_t iq3 = gl_WorkGroupID.z; ++ ++ // broadcast factors ++ const uint32_t rk2 = p.neq2/p.nek2; ++ const uint32_t rk3 = p.neq3/p.nek3; ++ ++ const uint32_t rv2 = p.neq2/p.nev2; ++ const uint32_t rv3 = p.neq3/p.nev3; ++ ++ // k indices ++ const uint32_t ik3 = iq3 / rk3; ++ const uint32_t ik2 = iq2 / rk2; ++ ++ // v indices ++ const uint32_t iv3 = iq3 / rv3; ++ const uint32_t iv2 = iq2 / rv2; ++ ++ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); ++ tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); ++ tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp); ++ ++ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); ++ ++#if defined(BLOCK_SIZE) ++ tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); ++ tensorLayoutV = 
setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); ++#endif ++ ++ tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); ++ tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); ++ tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); ++ ++ coopmat Q; ++ coopmat Qf16; ++ ++ uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; ++ coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); ++ ++ Qf16 = coopmat(Q); ++ Qf16 *= float16_t(p.scale); ++ ++ coopmat O = coopmat(0); ++ ++ coopmat L, M; ++ ++ L = coopmat(0); ++ M = coopmat(-1.0/0.0); ++ ++ ACC_TYPE slope = ACC_TYPE(1.0); ++ ++ // ALiBi ++ if (p.max_bias > 0.0f) { ++ const uint32_t h = iq2; ++ ++ const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); ++ const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); ++ ++ slope = pow(base, ACC_TYPE(exph)); ++ } ++ ++ [[dont_unroll]] ++ for (uint32_t j = 0; j < Tc; ++j) { ++ ++ coopmat S = coopmat(0); ++ ++ coopmat K_T; ++ ++ uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; ++ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); ++ S = coopMatMulAdd(Qf16, K_T, S); ++ ++ if (p.logit_softcap != 0.0f) { ++ [[unroll]] ++ for (int k = 0; k < S.length(); ++k) { ++ S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); ++ } ++ } ++ ++ if (p.mask != 0) { ++ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); ++ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); ++ ++ coopmat mv; ++ ++ coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); ++ ++ S += slope*coopmat(mv); ++ } ++ ++ // Clear padding elements to -inf, so they don't contribute to rowmax ++ if (Clamp != 0 && ++ ((j + 1) * Bc > KV || ++ (i + 1) * Br > N)) { ++ ++ uint R = ((i + 1) * Br > N) ? (N % Br) : Br; ++ uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; ++ ++ coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); ++ } ++ ++ coopmat rowmax, P, rowsum, eM; ++ ++ coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); ++ ++ coopmat Mold = M; ++ ++ // M = max(rowmax, Mold) ++ // P = e^(S - M) ++ // eM = e^(Mold - M) ++ coopMatPerElementNV(M, rowmax, Max, Mold); ++ coopMatPerElementNV(P, S - M, Exp); ++ coopMatPerElementNV(eM, Mold - M, Exp); ++ ++ // Clear padding elements to 0, so they don't contribute to rowsum ++ if (Clamp != 0 && ++ ((j + 1) * Bc > KV || ++ (i + 1) * Br > N)) { ++ ++ uint R = ((i + 1) * Br > N) ? (N % Br) : Br; ++ uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; ++ ++ coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); ++ } ++ ++ coopmat P_A = coopmat(P); ++ ++ // compute rowsum by multiplying by matrix of all ones. 
++ coopmat One = coopmat(1.0); ++ ++ rowsum = coopmat(0.0); ++ rowsum = coopMatMulAdd(P_A, One, rowsum); ++ ++ coopmat V; ++ uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; ++ coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); ++ ++ L = eM*L + rowsum; ++ ++ // This is the "diagonal" matrix in the paper, but since we do componentwise ++ // multiply rather than matrix multiply it has the diagonal element smeared ++ // across the row ++ coopmat eMdiag; ++ ++ // resize eM by using smear/reduce ++ coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); ++ ++ O = eMdiag * O; ++ ++ O = coopMatMulAdd(P_A, V, O); ++ } ++ ++ coopmat Ldiag; ++ ++ // resize L by using smear/reduce ++ coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); ++ ++ [[unroll]] ++ for (int k = 0; k < Ldiag.length(); ++k) { ++ Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; ++ } ++ ++ O = Ldiag*O; ++ ++ tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); ++ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); ++ ++ // permute dimensions ++ tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); ++ uint32_t o_offset = iq3*p.ne2*p.ne1; ++ ++ coopmat O_D = coopmat(O); ++ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +new file mode 100644 +index 00000000..4cc7a68c +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +@@ -0,0 +1,25 @@ ++#version 450 ++ ++#include "generic_head.comp" ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const float GELU_COEF_A = 0.044715f; ++ const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; ++ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++ ++ if (i >= p.KX) { ++ return; ++ } ++ ++ const float xi = float(data_a[i]); ++ const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); ++ data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1))); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +new file mode 100644 +index 00000000..e6e6fcfd +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +@@ -0,0 +1,23 @@ ++#version 450 ++ ++#include "generic_head.comp" ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const float GELU_QUICK_COEF = -1.702f; ++ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++ ++ if (i >= p.KX) { ++ return; ++ } ++ ++ const float x = float(data_a[i]); ++ data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x)))); ++} +diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +new file mode 100644 +index 00000000..062e2a4c +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +@@ -0,0 +1,64 @@ ++#extension GL_EXT_shader_16bit_storage : require ++#extension GL_EXT_control_flow_attributes : require ++ ++layout (push_constant) uniform parameter ++{ ++ uint ne; ++ uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; ++ uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; ++ uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; ++ uint misalign_offsets; ++ float param1; float param2; int param3; ++} p; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; ++layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; ++ ++// true if src0/src1 are the same shape and the indices can be reused without additional modulus ++layout(constant_id = 0) const bool norepeat = false; ++ ++uint get_idx() { ++ return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++} ++ ++uint get_aoffset() { return p.misalign_offsets >> 16; } ++uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; } ++uint get_doffset() { return p.misalign_offsets & 0xFF; } ++ ++// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1 ++uint fastmod(uint a, uint b) { ++ if ((b & (b-1)) == 0) { ++ return a & (b-1); ++ } ++ return a % b; ++} ++ ++uint fastdiv(uint a, uint b) { ++ return (a < b) ? 0 : (a / b); ++} ++ ++void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) { ++ i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00)); ++ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; ++ i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00)); ++ const uint i02_offset = i02*p.ne01*p.ne00; ++ i01 = (idx - i03_offset - i02_offset) / p.ne00; ++ i00 = idx - i03_offset - i02_offset - i01*p.ne00; ++} ++ ++uint src0_idx(uint i00, uint i01, uint i02, uint i03) { ++ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; ++} ++ ++uint src1_idx(uint i00, uint i01, uint i02, uint i03) { ++ if (norepeat) { ++ return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10; ++ } else { ++ return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10; ++ } ++} ++ ++uint dst_idx(uint i00, uint i01, uint i02, uint i03) { ++ return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20; ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +new file mode 100644 +index 00000000..66e46ae6 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +@@ -0,0 +1,9 @@ ++#extension GL_EXT_shader_16bit_storage : require ++ ++layout (push_constant) uniform parameter ++{ ++ uint KX; ++ uint KY; ++ float param1; ++ float param2; ++} p; +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +new file mode 100644 +index 00000000..68d1bc9f +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +@@ -0,0 +1,56 @@ ++#extension 
GL_EXT_shader_16bit_storage : require ++#extension GL_EXT_control_flow_attributes : require ++ ++layout (push_constant) uniform parameter ++{ ++ uint ne; ++ uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; ++ uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; ++ uint misalign_offsets; ++ float param1; float param2; ++ ++ uint ne0_012mp; uint ne0_012L; ++ uint ne0_01mp; uint ne0_01L; ++ uint ne0_0mp; uint ne0_0L; ++ uint ne1_012mp; uint ne1_012L; ++ uint ne1_01mp; uint ne1_01L; ++ uint ne1_0mp; uint ne1_0L; ++} p; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++uint get_idx() { ++ return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++} ++ ++uint get_aoffset() { return p.misalign_offsets >> 16; } ++uint get_doffset() { return p.misalign_offsets & 0xFFFF; } ++ ++// see init_fastdiv_values in ggml-vulkan.cpp ++uint fastdiv(uint n, uint mp, uint L) { ++ uint msbs, lsbs; ++ // msbs = mulhi(n, mp) ++ umulExtended(n, mp, msbs, lsbs); ++ return (msbs + n) >> L; ++} ++ ++uint src0_idx(uint idx) { ++ const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); ++ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; ++ const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); ++ const uint i02_offset = i02*p.ne01*p.ne00; ++ const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); ++ const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; ++ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; ++} ++ ++uint dst_idx(uint idx) { ++ const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); ++ const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; ++ const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); ++ const uint i12_offset = i12*p.ne11*p.ne10; ++ const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); ++ const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; ++ return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +new file mode 100644 +index 00000000..e877ed77 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +@@ -0,0 +1,28 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_binary_head.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ const uint i00 = gl_GlobalInvocationID.x; ++ const uint i10 = gl_GlobalInvocationID.y; ++ const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; ++ const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; ++ ++ if (i00 >= p.ne00) { ++ return; ++ } ++ ++ const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; ++ ++ const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; ++ const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; ++ ++#ifndef OPTIMIZATION_ERROR_WORKAROUND ++ data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]); ++#else ++ data_d[d_offset + i00] = data_a[a_offset + i00]; ++#endif ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +new file mode 100644 +index 00000000..1426fde6 +--- /dev/null ++++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +@@ -0,0 +1,39 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_binary_head.comp" ++#include "dequant_funcs.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ const uint i00 = (gl_GlobalInvocationID.x)*2; ++ const uint i10 = gl_GlobalInvocationID.y; ++ const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; ++ const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; ++ ++#if defined(DATA_A_IQ4_NL) ++ init_iq4nl_shmem(); ++#endif ++ ++ if (i00 >= p.ne00) { ++ return; ++ } ++ ++ const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; ++ ++ const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; ++ const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; ++ ++ const uint ib = a_offset + i00/QUANT_K; // block index ++ const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index ++ const uint iybs = i00 - i00%QUANT_K; // dst block start index ++ const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; ++ ++ vec2 v = dequantize(ib, iqs, 0); ++ const vec2 dm = get_dm(ib, 0); ++ v = v * dm.x + dm.y; ++ ++ data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); ++ data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +new file mode 100644 +index 00000000..b6a0d564 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +@@ -0,0 +1,66 @@ ++#version 450 ++ ++#include "generic_head.comp" ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++#define BLOCK_SIZE 512 ++ ++layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++shared float tmp[BLOCK_SIZE]; ++ ++void main() { ++ const uint group_size = p.KX; ++ const float eps = p.param1; ++ ++ const uint tid = gl_LocalInvocationID.x; ++ const uint start = gl_WorkGroupID.x * group_size + tid; ++ const uint end = (gl_WorkGroupID.x + 1) * group_size; ++ ++ tmp[tid] = 0.0f; ++ ++ // Calculate mean ++ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { ++ tmp[tid] += float(data_a[col]); ++ } ++ ++ // tmp up partial tmps and write back result ++ barrier(); ++ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { ++ if (tid < s) { ++ tmp[tid] += tmp[tid + s]; ++ } ++ barrier(); ++ } ++ ++ const float mean = tmp[0] / group_size; ++ barrier(); ++ tmp[tid] = 0.0f; ++ ++ // Calculate variance ++ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { ++ const float xi = float(data_a[col]) - mean; ++ data_d[col] = D_TYPE(xi); ++ tmp[tid] += xi * xi; ++ } ++ ++ // sum up partial sums and write back result ++ barrier(); ++ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { ++ if (tid < s) { ++ tmp[tid] += tmp[tid + s]; ++ } ++ barrier(); ++ } ++ ++ const float variance = tmp[0] / group_size; ++ const float scale = inversesqrt(variance + eps); ++ ++ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { ++ data_d[col] *= D_TYPE(scale); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +new file mode 100644 +index 00000000..122b1e93 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +@@ -0,0 +1,87 @@ 
++#version 450 ++ ++#extension GL_EXT_shader_16bit_storage : require ++#extension GL_EXT_spirv_intrinsics: enable ++#extension GL_EXT_control_flow_attributes : require ++ ++#if RTE16 ++spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits ++#endif ++ ++layout (push_constant) uniform parameter ++{ ++ uint batch_offset; uint offset_delta; ++ uint IC; ++ uint IW; uint IH; ++ uint OW; uint OH; ++ uint KW; uint KH; ++ uint pelements; ++ uint CHW; ++ int s0; int s1; ++ int p0; int p1; ++ int d0; int d1; ++} p; ++ ++#include "types.comp" ++ ++layout(constant_id = 0) const uint BLOCK_SIZE = 32; ++ ++const uint NUM_ITER = 512 / BLOCK_SIZE; ++ ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const uint gidx = gl_GlobalInvocationID.x; ++ ++ const uint oh = gl_GlobalInvocationID.y; ++ const uint batch = gl_GlobalInvocationID.z / p.IC; ++ const uint ic = gl_GlobalInvocationID.z % p.IC; ++ ++ A_TYPE values[NUM_ITER]; ++ uint offset_dst[NUM_ITER]; ++ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { ++ values[idx] = A_TYPE(0); ++ } ++ ++ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { ++ ++ const uint i = gidx * NUM_ITER + idx; ++ ++ const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); ++ const uint kx = i / ksize; ++ const uint kd = kx * ksize; ++ const uint ky = (i - kd) / p.OW; ++ const uint ix = i % p.OW; ++ ++ const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; ++ const uint iih = oh * p.s1 + ky * p.d1 - p.p1; ++ ++ offset_dst[idx] = ++ ((batch * p.OH + oh) * p.OW + ix) * p.CHW + ++ (ic * (p.KW * p.KH) + ky * p.KW + kx); ++ ++ if (i >= p.pelements) { ++ continue; ++ } ++ ++ if (iih < p.IH && iiw < p.IW) { ++ const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; ++ values[idx] = data_a[offset_src + iih * p.IW + iiw]; ++ } ++ } ++ ++ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { ++ ++ const uint i = gidx * NUM_ITER + idx; ++ ++ if (i >= p.pelements) { ++ continue; ++ } ++ ++ data_d[offset_dst[idx]] = D_TYPE(values[idx]); ++ } ++ ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +new file mode 100644 +index 00000000..d90a99ae +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +@@ -0,0 +1,22 @@ ++#version 450 ++ ++#include "generic_head.comp" ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++ ++ if (i >= p.KX) { ++ return; ++ } ++ ++ const float val = float(data_a[i]); ++ data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +new file mode 100644 +index 00000000..43de19df +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +@@ -0,0 +1,27 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_binary_head.comp" ++ ++const uint num_threads = 256; ++ ++layout(local_size_x = num_threads, local_size_y = 
1, local_size_z = 1) in; ++ ++void main() { ++ uint idx = get_idx(); ++ ++ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation ++ const uint num_iter = 2; ++ ++ [[unroll]] for (uint i = 0; i < num_iter; ++i) { ++ if (idx >= p.ne) { ++ continue; ++ } ++ uint i00, i01, i02, i03; ++ get_indices(idx, i00, i01, i02, i03); ++ ++ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); ++ ++ idx += num_threads; ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +new file mode 100644 +index 00000000..4c64fd47 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +@@ -0,0 +1,48 @@ ++#version 450 ++ ++#extension GL_EXT_control_flow_attributes : enable ++ ++layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {float data_a[];}; ++layout (binding = 0) readonly buffer A4 {vec4 data_a4[];}; ++layout (binding = 1) writeonly buffer D {float data_d[];}; ++layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];}; ++ ++layout (push_constant) uniform parameter { ++ uint ne; ++ uint k_num; ++} p; ++ ++void main() { ++ // Each invocation handles four consecutive components ++ const uint idx = gl_GlobalInvocationID.x * 4; ++ ++ if (idx >= p.ne) { ++ return; ++ } ++ ++ // Check if all four components are in bounds and aligned, ++ // then use vector loads ++ if (idx + 3 < p.ne && (p.ne % 4) == 0) { ++ vec4 result = vec4(0.0f); ++ ++ [[unroll]] for (uint i = 0; i < p.k_num; i++) { ++ result += data_a4[(i * p.ne + idx) / 4]; ++ } ++ ++ data_d4[idx / 4] = result; ++ } else { ++ [[unroll]] for (uint j = 0; j < 4; ++j) { ++ if (idx + j < p.ne) { ++ float result = 0.0f; ++ ++ [[unroll]] for (uint i = 0; i < p.k_num; i++) { ++ result += data_a[i * p.ne + idx + j]; ++ } ++ ++ data_d[idx + j] = result; ++ } ++ } ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +new file mode 100644 +index 00000000..24875cdc +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +@@ -0,0 +1,152 @@ ++#version 450 ++ ++#ifdef FLOAT16 ++#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require ++#endif ++#extension GL_EXT_shader_explicit_arithmetic_types : require ++ ++#include "mul_mat_vec_base.comp" ++ ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++#if !defined(DATA_A_F32) && !defined(DATA_A_F16) ++#define K_PER_ITER 8 ++#else ++#define K_PER_ITER 2 ++#endif ++ ++ ++uint a_offset, b_offset, d_offset, y_offset; ++ ++void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) ++{ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; ++ const uint iqs = (col%QUANT_K)/QUANT_R; // quant index ++ const uint iybs = col - col%QUANT_K; // y block start index ++ ++#if K_PER_ITER == 8 ++#if QUANT_R == 2 ++ const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; ++ const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]; ++ const vec4 bv0 = vec4(bv02.x, bv13.x, 
bv02.y, bv13.y); ++ const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); ++#else ++ const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); ++ const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); ++#endif ++#else ++ // Check if the second of the pair of elements is OOB, and don't fetch B or ++ // accumulate it. We still fetch a pair of elements for A, which is fine for ++ // quantized formats since they'll be within the same block. We should ++ // probably skip fetching the second element for F16/F32, but as of now we ++ // still do. ++ const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); ++ ++ FLOAT_TYPE b0 = 0, b1 = 0; ++ b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); ++ if (!OOB) { ++ b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); ++ } ++#endif ++ uint ibi = first_row*p.ncols; ++ [[unroll]] for (uint n = 0; n < num_rows; ++n) { ++ const uint ib = (ibi + col)/QUANT_K; // block index ++ ibi += p.ncols; ++ ++#if K_PER_ITER == 8 ++ vec4 v = dequantize4(ib, iqs, a_offset); ++ vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); ++ ++ const vec2 dm = get_dm(ib, a_offset); ++ if (dm.y != 0) { // quant has min component ++ v = v * dm.x + dm.y; ++ v2 = v2 * dm.x + dm.y; ++ } ++ ++ // matrix multiplication ++ FLOAT_TYPE rowtmp = dot(bv0, v); ++ rowtmp += dot(bv1, v2); ++ ++ if (dm.y == 0) ++ rowtmp *= dm.x; ++ ++ temp[j][n] += rowtmp; ++#else ++ const vec2 v = dequantize(ib, iqs, a_offset); ++ ++ // matrix multiplication ++ temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); ++ if (!OOB) { ++ temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); ++ } ++#endif ++ } ++ } ++} ++ ++void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { ++ const uint tid = gl_LocalInvocationID.x; ++ ++ get_offsets(a_offset, b_offset, d_offset); ++ a_offset /= QUANT_K; ++ ++ y_offset = QUANT_R == 1 ? 
1 : QUANT_K/2; ++ ++ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { ++ temp[j][i] = FLOAT_TYPE(0); ++ } ++ } ++ ++ uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); ++ if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { ++ num_iters++; ++ } ++ int unroll_count = 4; ++ uint unrolled_iters = num_iters & ~(unroll_count - 1); ++ ++ uint i = 0; ++ while (i < unrolled_iters) { ++ // Manually partially unroll the loop ++ [[unroll]] for (uint k = 0; k < unroll_count; ++k) { ++ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); ++ i++; ++ } ++ } ++ unroll_count = 2; ++ unrolled_iters = num_iters & ~(unroll_count - 1); ++ while (i < unrolled_iters) { ++ // Manually partially unroll the loop ++ [[unroll]] for (uint k = 0; k < unroll_count; ++k) { ++ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); ++ i++; ++ } ++ } ++ while (i < num_iters) { ++ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); ++ i++; ++ } ++ ++ reduce_result(temp, d_offset, first_row, num_rows, tid); ++} ++ ++void main() { ++ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); ++ ++#if defined(DATA_A_IQ4_NL) ++ init_iq4nl_shmem(); ++#endif ++ ++ // do NUM_ROWS at a time, unless there aren't enough remaining rows ++ if (first_row + NUM_ROWS <= p.stride_d) { ++ compute_outputs(first_row, NUM_ROWS); ++ } else { ++ if (first_row >= p.stride_d) { ++ return; ++ } ++ compute_outputs(first_row, p.stride_d - first_row); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +new file mode 100644 +index 00000000..903753c7 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +@@ -0,0 +1,118 @@ ++#extension GL_EXT_control_flow_attributes : enable ++#extension GL_EXT_shader_16bit_storage : require ++#extension GL_EXT_shader_8bit_storage : require ++ ++#ifdef MUL_MAT_ID ++#define EXPERT_COUNT 8 ++#endif ++ ++#include "types.comp" ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; ++layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; ++layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; ++ ++layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; ++#ifdef MUL_MAT_ID ++layout (binding = 3) readonly buffer IDS {int data_ids[];}; ++#endif ++ ++#include "dequant_funcs.comp" ++ ++layout (push_constant) uniform parameter ++{ ++ uint ncols; ++ uint stride_a; ++ uint stride_b; ++ uint stride_d; ++ ++ uint batch_stride_a; ++ uint batch_stride_b; ++ uint batch_stride_d; ++ ++#ifdef MUL_MAT_ID ++ uint nei0; ++ uint ne11; ++#else ++ uint ne02; ++ uint ne12; ++ uint broadcast2; ++ uint broadcast3; ++#endif ++} p; ++ ++void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { ++#ifdef MUL_MAT_ID ++ const uint expert_idx = gl_GlobalInvocationID.y; ++#else ++ const uint batch_idx = gl_GlobalInvocationID.y; ++#endif ++ ++#ifndef MUL_MAT_ID ++ uint batch_idx_a = 0; ++ if (batch_idx != 0) { ++ const uint i13 = batch_idx / p.ne12; ++ const uint i12 = batch_idx % p.ne12; ++ ++ const uint i03 = i13 / p.broadcast3; ++ const uint i02 = i12 / p.broadcast2; ++ ++ batch_idx_a = i03 * p.ne02 + i02; ++ } ++#else ++ const uint expert_id = data_ids[expert_idx]; ++#endif ++ ++ a_offset = ++#ifdef MUL_MAT_ID ++ 
expert_id * p.batch_stride_a; ++#else ++ batch_idx_a * p.batch_stride_a; ++#endif ++ b_offset = ++#ifdef MUL_MAT_ID ++ (expert_idx % p.ne11) * p.stride_b; ++#else ++ batch_idx * p.batch_stride_b; ++#endif ++ d_offset = ++#ifdef MUL_MAT_ID ++ expert_idx * p.stride_d; ++#else ++ batch_idx * p.batch_stride_d; ++#endif ++} ++ ++layout (constant_id = 0) const uint BLOCK_SIZE = 32; ++layout (constant_id = 1) const uint NUM_ROWS = 1; ++layout (constant_id = 2) const uint NUM_COLS = 1; ++ ++shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; ++ ++void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { ++ // sum up partial sums and write back result ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ [[unroll]] for (uint n = 0; n < num_rows; ++n) { ++ tmpsh[j][n][tid] = temp[j][n]; ++ } ++ } ++ barrier(); ++ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { ++ if (tid < s) { ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ [[unroll]] for (uint n = 0; n < num_rows; ++n) { ++ tmpsh[j][n][tid] += tmpsh[j][n][tid + s]; ++ } ++ } ++ } ++ barrier(); ++ } ++ if (tid == 0) { ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ [[unroll]] for (uint n = 0; n < num_rows; ++n) { ++ data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); ++ } ++ } ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +new file mode 100644 +index 00000000..1cc4996d +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +@@ -0,0 +1,71 @@ ++#version 450 ++ ++#extension GL_EXT_control_flow_attributes : enable ++#extension GL_EXT_shader_16bit_storage : require ++ ++#define BLOCK_SIZE 32 ++#define FLOAT_TYPE float ++ ++layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; ++layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; ++ ++layout (push_constant) uniform parameter ++{ ++ uint ncols_x; ++ uint nrows_x; ++ uint row_stride_x; ++ uint channel_stride_x; ++ uint channel_x_divisor; ++ uint b_offset; ++ uint d_offset; ++} p; ++ ++shared FLOAT_TYPE tmp[BLOCK_SIZE]; ++ ++void main() { ++ const uint tid = gl_LocalInvocationID.x; ++ const uint row_x = gl_GlobalInvocationID.y; ++ const uint channel = gl_GlobalInvocationID.z; ++ const uint channel_x = channel / p.channel_x_divisor; ++ ++ const uint nrows_y = p.ncols_x; ++ const uint nrows_dst = p.nrows_x; ++ const uint row_dst = row_x; ++ ++ const uint idst = channel*nrows_dst + row_dst; ++ ++ tmp[tid] = 0.0f; ++ ++ for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { ++ const uint col_x = col_x0 + tid; ++ ++ if (col_x >= p.ncols_x) { ++ break; ++ } ++ ++ const uint row_y = col_x; ++ ++ const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; ++ const uint iy = channel*nrows_y + row_y; ++ ++ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); ++ ++ tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); ++ } ++ ++ // sum up partial sums and write back result ++ barrier(); ++ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { ++ if (tid < s) { ++ tmp[tid] += tmp[tid + s]; ++ } ++ barrier(); ++ } ++ ++ if (tid == 0) { ++ dst[idst] = tmp[0]; ++ } ++} +diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +new file mode 100644 +index 00000000..9b443807 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +@@ -0,0 +1,73 @@ ++#version 450 ++ ++#extension GL_EXT_control_flow_attributes : enable ++#extension GL_EXT_shader_16bit_storage : require ++ ++#define BLOCK_SIZE 32 ++#define FLOAT_TYPE float ++ ++layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; ++layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; ++ ++layout (push_constant) uniform parameter ++{ ++ uint ncols_x; ++ uint nrows_x; ++ uint nchannels_x; ++ uint nchannels_y; ++ uint b_offset; ++ uint d_offset; ++} p; ++ ++shared FLOAT_TYPE tmp[BLOCK_SIZE]; ++ ++void main() { ++ const uint tid = gl_LocalInvocationID.x; ++ const uint row_x = gl_GlobalInvocationID.y; ++ const uint channel = gl_GlobalInvocationID.z; ++ const uint channel_x = channel / (p.nchannels_y / p.nchannels_x); ++ ++ const uint nrows_y = p.ncols_x; ++ const uint nrows_dst = p.nrows_x; ++ const uint row_dst = row_x; ++ ++ tmp[tid] = FLOAT_TYPE(0.0f); ++ ++ for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { ++ const uint col_x = col_x0 + tid; ++ ++ if (col_x >= p.ncols_x) { ++ break; ++ } ++ ++ // x is transposed and permuted ++ const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; ++ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); ++ ++ const uint row_y = col_x; ++ ++ // y is not transposed but permuted ++ const uint iy = channel*nrows_y + row_y; ++ ++ tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); ++ } ++ ++ // dst is not transposed and not permuted ++ const uint idst = channel*nrows_dst + row_dst; ++ ++ // sum up partial sums and write back result ++ barrier(); ++ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { ++ if (tid < s) { ++ tmp[tid] += tmp[tid + s]; ++ } ++ barrier(); ++ } ++ ++ if (tid == 0) { ++ dst[idst] = tmp[0]; ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +new file mode 100644 +index 00000000..93421344 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +@@ -0,0 +1,115 @@ ++#version 450 ++#extension GL_EXT_shader_explicit_arithmetic_types : require ++ ++#include "mul_mat_vec_base.comp" ++ ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { ++ uint a_offset, b_offset, d_offset; ++ get_offsets(a_offset, b_offset, d_offset); ++ ++ const uint num_blocks_per_row = p.ncols / QUANT_K; ++ ++ // 16 threads are used to process each block ++ const uint it_size = gl_WorkGroupSize.x/16; ++ const uint tid = gl_LocalInvocationID.x; ++ const uint itid = tid%16; // 0...16 ++ const uint ix = tid/16; ++ ++ const uint step = 8; ++ ++ const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
++ const uint v_in = itid - step*v_im; // 0...15 or 0...7 ++ ++ const uint l0 = 2*v_in; // 0...15 ++ const uint q_offset = 32*v_im + l0; ++ const uint s_offset = 8*v_im; ++ const uint y_offset = 128*v_im + l0; ++ ++ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { ++ temp[j][i] = FLOAT_TYPE(0); ++ } ++ } ++ ++ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { ++ const uint y_idx = i * QUANT_K + y_offset; ++ ++ [[unroll]] for (uint n = 0; n < num_rows; ++n) { ++ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; ++ f16vec2 d = data_a[ib0 + i].d; ++ const FLOAT_TYPE dall = d.x; ++ const FLOAT_TYPE dmin = d.y; ++ ++ uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0]; ++ uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; ++ ++ uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; ++ uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F; ++ uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F; ++ uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F; ++ ++ uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32)); ++ uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32)); ++ uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); ++ uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); ++ ++ uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0]; ++ uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; ++ uvec2 qs0 = uvec2(unpack8(qs0_u16)); ++ uvec2 qs16 = uvec2(unpack8(qs16_u16)); ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; ++ B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; ++ B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; ++ B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; ++ B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; ++ B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; ++ B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; ++ B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; ++ ++ FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); ++ FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); ++ [[unroll]] for (int l = 0; l < 2; ++l) { ++ sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), ++ fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), ++ fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), ++ fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), ++ fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), ++ fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), ++ fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3), ++ fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); ++ sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), ++ fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), ++ fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), ++ fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), ++ fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), ++ fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), ++ fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), ++ fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); ++ } ++ temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); ++ } ++ } ++ } ++ ++ reduce_result(temp, 
d_offset, first_row, num_rows, tid); ++} ++ ++void main() { ++ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); ++ ++ // do NUM_ROWS at a time, unless there aren't enough remaining rows ++ if (first_row + NUM_ROWS <= p.stride_d) { ++ compute_outputs(first_row, NUM_ROWS); ++ } else { ++ if (first_row >= p.stride_d) { ++ return; ++ } ++ compute_outputs(first_row, p.stride_d - first_row); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +new file mode 100644 +index 00000000..86b0159d +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +@@ -0,0 +1,103 @@ ++#version 450 ++#extension GL_EXT_shader_explicit_arithmetic_types : require ++ ++#include "mul_mat_vec_base.comp" ++ ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { ++ uint a_offset, b_offset, d_offset; ++ get_offsets(a_offset, b_offset, d_offset); ++ ++ const uint num_blocks_per_row = p.ncols / QUANT_K; ++ ++ // 16 threads are used to process each block ++ const uint it_size = gl_WorkGroupSize.x/16; ++ const uint tid = gl_LocalInvocationID.x; ++ const uint itid = tid%16; // 0...16 ++ const uint ix = tid/16; ++ ++ const uint step = 8; ++ ++ const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... ++ const uint v_in = itid - step*v_im; // 0...15 or 0...7 ++ ++ const uint8_t m = uint8_t(1 << (4 * v_im)); ++ ++ const uint l0 = 2*v_in; // 0...15 ++ const uint q_offset = 32*v_im + l0; ++ const uint y_offset = 128*v_im + l0; ++ ++ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { ++ temp[j][i] = FLOAT_TYPE(0); ++ } ++ } ++ ++ const uint s_shift = 4 * v_im; ++ ++ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { ++ const uint y_idx = i * QUANT_K + y_offset; ++ ++ [[unroll]] for (uint n = 0; n < num_rows; ++n) { ++ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; ++ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); ++ ++ uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0]; ++ uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1]; ++ uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2]; ++ uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3]; ++ uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4]; ++ uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5]; ++ u8vec2 s0 = unpack8(s0_16); ++ u8vec2 s2 = unpack8(s2_16); ++ u8vec2 s4 = unpack8(s4_16); ++ u8vec2 s6 = unpack8(s6_16); ++ u8vec2 s8 = unpack8(s8_16); ++ u8vec2 s10 = unpack8(s10_16); ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ ++ B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; ++ B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; ++ B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; ++ B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; ++ B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; ++ B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; ++ B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; ++ B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; ++ ++ FLOAT_TYPE sum = FLOAT_TYPE(0.0); ++ [[unroll]] for 
(int l = 0; l < 2; ++l) { ++ sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), ++ fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), ++ fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), ++ fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), ++ fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), ++ fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), ++ fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), ++ fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 
0 : 4)), sum)))))))); ++ } ++ temp[j][n] = fma(d, sum, temp[j][n]); ++ } ++ } ++ } ++ ++ reduce_result(temp, d_offset, first_row, num_rows, tid); ++} ++ ++void main() { ++ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); ++ ++ // do NUM_ROWS at a time, unless there aren't enough remaining rows ++ if (first_row + NUM_ROWS <= p.stride_d) { ++ compute_outputs(first_row, NUM_ROWS); ++ } else { ++ if (first_row >= p.stride_d) { ++ return; ++ } ++ compute_outputs(first_row, p.stride_d - first_row); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +new file mode 100644 +index 00000000..cd1dd8e8 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +@@ -0,0 +1,133 @@ ++#version 450 ++ ++#extension GL_EXT_shader_explicit_arithmetic_types : require ++ ++#include "mul_mat_vec_base.comp" ++ ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { ++ uint a_offset, b_offset, d_offset; ++ get_offsets(a_offset, b_offset, d_offset); ++ ++ const uint num_blocks_per_row = p.ncols / QUANT_K; ++ ++ // 16 threads are used to process each block ++ const uint it_size = gl_WorkGroupSize.x/16; ++ const uint tid = gl_LocalInvocationID.x; ++ const uint itid = tid%16; // 0...16 ++ const uint ix = tid/16; ++ ++ const uint step = 4; ++ ++ const uint il = itid/step; // 0...3 ++ const uint ir = itid - step*il; // 0...7 or 0...3 ++ const uint n = 4; ++ ++ const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 ++ const uint v_in = il % 2; ++ ++ const uint l0 = n * (2 * ir + v_in); // 0...15 ++ const uint q_offset = 32*v_im + l0; ++ const uint y_offset = 64*v_im + l0; ++ ++ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { ++ temp[j][i] = FLOAT_TYPE(0); ++ } ++ } ++ ++ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { ++ const uint y1_idx = i * QUANT_K + y_offset; ++ const uint y2_idx = y1_idx + 128; ++ ++ [[unroll]] for (uint n = 0; n < num_rows; ++n) { ++ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; ++ f16vec2 d = data_a[ib0 + i].d; ++ const FLOAT_TYPE dall = FLOAT_TYPE(d.x); ++ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); ++ ++ uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; ++ uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; ++ uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; ++ uvec4 scale0 = uvec4(unpack8(scale0_u32)); ++ uvec4 scale4 = uvec4(unpack8(scale4_u32)); ++ uvec4 scale8 = uvec4(unpack8(scale8_u32)); ++ ++ const uint32_t sc0 = ( scale0.x & 0x3f); ++ const uint32_t sc1 = ( scale0.y & 0x3f); ++ const uint32_t sc2 = ( scale4.x & 0x3f); ++ const uint32_t sc3 = ( scale4.y & 0x3f); ++ const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); ++ const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); ++ const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); ++ const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); ++ ++ uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; ++ uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; ++ ++ uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; ++ uint32_t qs0_u32_hi4 = (qs0_u32 >> 
4) & 0x0F0F0F0F; ++ uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; ++ uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; ++ ++ uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4)); ++ uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4)); ++ uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4)); ++ uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4)); ++ ++ const uint32_t q4_0 = qs0_lo4.x; ++ const uint32_t q4_1 = qs0_lo4.y; ++ const uint32_t q4_2 = qs0_lo4.z; ++ const uint32_t q4_3 = qs0_lo4.w; ++ const uint32_t q4_4 = qs0_hi4.x; ++ const uint32_t q4_5 = qs0_hi4.y; ++ const uint32_t q4_6 = qs0_hi4.z; ++ const uint32_t q4_7 = qs0_hi4.w; ++ const uint32_t q4_8 = qs64_lo4.x; ++ const uint32_t q4_9 = qs64_lo4.y; ++ const uint32_t q4_10 = qs64_lo4.z; ++ const uint32_t q4_11 = qs64_lo4.w; ++ const uint32_t q4_12 = qs64_hi4.x; ++ const uint32_t q4_13 = qs64_hi4.y; ++ const uint32_t q4_14 = qs64_hi4.z; ++ const uint32_t q4_15 = qs64_hi4.w; ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4]; ++ B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]; ++ B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4]; ++ B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]; ++ ++ const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); ++ const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); ++ const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); ++ const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); ++ const FLOAT_TYPE smin = ++ fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, ++ fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, ++ fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, ++ fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); ++ temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); ++ } ++ } ++ } ++ ++ reduce_result(temp, d_offset, first_row, num_rows, tid); ++} ++ ++void main() { ++ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); ++ ++ // do NUM_ROWS at a time, unless there aren't enough remaining rows ++ if (first_row + NUM_ROWS <= p.stride_d) { ++ compute_outputs(first_row, NUM_ROWS); ++ } else { ++ if (first_row >= p.stride_d) { ++ return; ++ } ++ compute_outputs(first_row, p.stride_d - first_row); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +new file mode 100644 +index 00000000..0a68891c +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +@@ -0,0 +1,162 @@ ++#version 450 ++ ++#extension GL_EXT_shader_explicit_arithmetic_types : require ++ ++#include "mul_mat_vec_base.comp" ++ ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++void 
compute_outputs(const uint32_t first_row, const uint32_t num_rows) { ++ uint a_offset, b_offset, d_offset; ++ get_offsets(a_offset, b_offset, d_offset); ++ ++ const uint num_blocks_per_row = p.ncols / QUANT_K; ++ ++ // 16 threads are used to process each block ++ const uint it_size = gl_WorkGroupSize.x/16; ++ const uint tid = gl_LocalInvocationID.x; ++ const uint itid = tid%16; // 0...16 ++ const uint ix = tid/16; ++ ++ const uint il = itid/4; // 0...3 ++ const uint ir = itid - 4*il; // 0...7 or 0...3 ++ ++ const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 ++ const uint v_in = il % 2; ++ ++ const uint l0 = 4*ir + 2*v_in; // 0...15 ++ const uint q_offset = 32*v_im + l0; ++ const uint y_offset = 64*v_im + l0; ++ ++ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { ++ temp[j][i] = FLOAT_TYPE(0); ++ } ++ } ++ ++ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { ++ const uint y1_idx = i * QUANT_K + y_offset; ++ const uint y2_idx = y1_idx + 128; ++ ++ [[unroll]] for (uint n = 0; n < num_rows; ++n) { ++ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; ++ f16vec2 d = data_a[ib0 + i].d; ++ const FLOAT_TYPE dall = FLOAT_TYPE(d.x); ++ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); ++ ++ uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; ++ uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; ++ uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; ++ uvec4 scale0 = uvec4(unpack8(scale0_u32)); ++ uvec4 scale4 = uvec4(unpack8(scale4_u32)); ++ uvec4 scale8 = uvec4(unpack8(scale8_u32)); ++ ++ const uint32_t sc0 = ( scale0.x & 0x3f); ++ const uint32_t sc1 = ( scale0.y & 0x3f); ++ const uint32_t sc2 = ( scale4.x & 0x3f); ++ const uint32_t sc3 = ( scale4.y & 0x3f); ++ const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); ++ const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); ++ const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); ++ const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); ++ ++ uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); ++ uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); ++ ++ uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; ++ uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; ++ uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; ++ uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; ++ ++ uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); ++ ++ uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; ++ uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; ++ uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0; ++ uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; ++ ++ qs0_16_u32_lo4 += qs0_16_lo4_offset16; ++ qs0_16_u32_hi4 += qs0_16_hi4_offset16; ++ qs64_80_u32_lo4 += qs64_80_lo4_offset16; ++ qs64_80_u32_hi4 += qs64_80_hi4_offset16; ++ ++ uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); ++ uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); ++ uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); ++ uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4)); ++ ++ const uint32_t q4_0 = 
qs0_16_lo4.x; ++ const uint32_t q4_1 = qs0_16_lo4.y; ++ const uint32_t q4_2 = qs0_16_lo4.z; ++ const uint32_t q4_3 = qs0_16_lo4.w; ++ const uint32_t q4_4 = qs0_16_hi4.x; ++ const uint32_t q4_5 = qs0_16_hi4.y; ++ const uint32_t q4_6 = qs0_16_hi4.z; ++ const uint32_t q4_7 = qs0_16_hi4.w; ++ const uint32_t q4_8 = qs64_80_lo4.x; ++ const uint32_t q4_9 = qs64_80_lo4.y; ++ const uint32_t q4_10 = qs64_80_lo4.z; ++ const uint32_t q4_11 = qs64_80_lo4.w; ++ const uint32_t q4_12 = qs64_80_hi4.x; ++ const uint32_t q4_13 = qs64_80_hi4.y; ++ const uint32_t q4_14 = qs64_80_hi4.z; ++ const uint32_t q4_15 = qs64_80_hi4.w; ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2]; ++ B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]; ++ B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]; ++ B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]; ++ B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2]; ++ B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]; ++ B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]; ++ B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]; ++ ++ const FLOAT_TYPE sx = ++ fma(FLOAT_TYPE(by10.x), q4_0, ++ fma(FLOAT_TYPE(by10.y), q4_1, ++ fma(FLOAT_TYPE(by116.x), q4_2, ++ FLOAT_TYPE(by116.y) * q4_3))); ++ const FLOAT_TYPE sy = ++ fma(FLOAT_TYPE(by132.x), q4_4, ++ fma(FLOAT_TYPE(by132.y), q4_5, ++ fma(FLOAT_TYPE(by148.x), q4_6, ++ FLOAT_TYPE(by148.y) * q4_7))); ++ const FLOAT_TYPE sz = ++ fma(FLOAT_TYPE(by20.x), q4_8, ++ fma(FLOAT_TYPE(by20.y), q4_9, ++ fma(FLOAT_TYPE(by216.x), q4_10, ++ FLOAT_TYPE(by216.y) * q4_11))); ++ const FLOAT_TYPE sw = ++ fma(FLOAT_TYPE(by232.x), q4_12, ++ fma(FLOAT_TYPE(by232.y), q4_13, ++ fma(FLOAT_TYPE(by248.x), q4_14, ++ FLOAT_TYPE(by248.y) * q4_15))); ++ const FLOAT_TYPE smin = ++ fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, ++ fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, ++ fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, ++ (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); ++ temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); ++ } ++ } ++ } ++ ++ reduce_result(temp, d_offset, first_row, num_rows, tid); ++} ++ ++void main() { ++ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); ++ ++ // do NUM_ROWS at a time, unless there aren't enough remaining rows ++ if (first_row + NUM_ROWS <= p.stride_d) { ++ compute_outputs(first_row, NUM_ROWS); ++ } else { ++ if (first_row >= p.stride_d) { ++ return; ++ } ++ compute_outputs(first_row, p.stride_d - first_row); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +new file mode 100644 +index 00000000..70e13a56 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +@@ -0,0 +1,112 @@ ++#version 450 ++ ++#extension GL_EXT_shader_explicit_arithmetic_types : require ++ ++#include "mul_mat_vec_base.comp" ++ ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++void compute_outputs(const 
uint32_t first_row, const uint32_t num_rows) { ++ uint a_offset, b_offset, d_offset; ++ get_offsets(a_offset, b_offset, d_offset); ++ ++ const uint num_blocks_per_row = p.ncols / QUANT_K; ++ ++ // 16 threads are used to process each block ++ const uint it_size = gl_WorkGroupSize.x/16; ++ const uint tid = gl_LocalInvocationID.x; ++ const uint itid = tid%16; // 0...16 ++ const uint ix = tid/16; ++ ++ const uint step = 8; ++ ++ const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... ++ const uint v_in = itid - step*v_im; // 0...15 or 0...7 ++ ++ const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 ++ const uint is = v_in / 4; ++ ++ const uint ql_offset = 64*v_im + l0; ++ const uint qh_offset = 32*v_im + l0; ++ const uint s_offset = 8*v_im + is; ++ const uint y_offset = 128*v_im + l0; ++ ++ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { ++ temp[j][i] = FLOAT_TYPE(0); ++ } ++ } ++ ++ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { ++ const uint y_idx = i * QUANT_K + y_offset; ++ ++ [[unroll]] for (uint n = 0; n < num_rows; ++n) { ++ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; ++ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); ++ ++ FLOAT_TYPE scales[4]; ++ scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); ++ scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); ++ scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); ++ scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); ++ ++ uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); ++ uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); ++ ++ uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; ++ uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; ++ uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; ++ uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; ++ ++ uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); ++ uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; ++ uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; ++ uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; ++ uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; ++ ++ uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; ++ uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; ++ uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; ++ uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; ++ ++ uvec4 q0 = uvec4(unpack8(q0_u32)); ++ uvec4 q1 = uvec4(unpack8(q1_u32)); ++ uvec4 q2 = uvec4(unpack8(q2_u32)); ++ uvec4 q3 = uvec4(unpack8(q3_u32)); ++ ++ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { ++ B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4]; ++ B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]; ++ B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]; ++ B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]; ++ ++ FLOAT_TYPE sum = FLOAT_TYPE(0.0); ++ [[unroll]] for (int l = 0; l < 4; ++l) { ++ sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), ++ fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), ++ fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), ++ fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); ++ } ++ 
temp[j][n] += sum * d; ++ } ++ } ++ } ++ ++ reduce_result(temp, d_offset, first_row, num_rows, tid); ++} ++ ++void main() { ++ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); ++ ++ // do NUM_ROWS at a time, unless there aren't enough remaining rows ++ if (first_row + NUM_ROWS <= p.stride_d) { ++ compute_outputs(first_row, NUM_ROWS); ++ } else { ++ if (first_row >= p.stride_d) { ++ return; ++ } ++ compute_outputs(first_row, p.stride_d - first_row); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +new file mode 100644 +index 00000000..48122cbe +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +@@ -0,0 +1,631 @@ ++#version 450 ++ ++#extension GL_EXT_control_flow_attributes : enable ++#extension GL_EXT_shader_16bit_storage : require ++ ++#ifdef FLOAT16 ++#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require ++#endif ++ ++#ifdef COOPMAT ++#extension GL_KHR_cooperative_matrix : enable ++#extension GL_KHR_memory_scope_semantics : enable ++#extension GL_KHR_shader_subgroup_basic : enable ++#endif ++ ++#ifdef MUL_MAT_ID ++#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require ++#endif ++ ++#include "types.comp" ++ ++#ifndef LOAD_VEC_A ++#define LOAD_VEC_A 1 ++#endif ++#ifndef LOAD_VEC_B ++#define LOAD_VEC_B 1 ++#endif ++ ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; ++layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; ++ ++#ifdef MUL_MAT_ID ++layout (binding = 3) readonly buffer IDS {int data_ids[];}; ++#endif ++ ++layout (push_constant) uniform parameter ++{ ++ uint M; ++ uint N; ++ uint K; ++ uint stride_a; ++ uint stride_b; ++ uint stride_d; ++ ++ uint batch_stride_a; ++ uint batch_stride_b; ++ uint batch_stride_d; ++ ++#ifdef MUL_MAT_ID ++ uint nei0; ++ uint nei1; ++ uint nbi1; ++ uint ne11; ++#else ++ uint k_split; ++ uint ne02; ++ uint ne12; ++ uint broadcast2; ++ uint broadcast3; ++#endif ++} p; ++ ++layout (constant_id = 0) const uint BLOCK_SIZE = 64; ++layout (constant_id = 1) const uint BM = 64; ++layout (constant_id = 2) const uint BN = 64; ++layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant ++layout (constant_id = 4) const uint WM = 32; ++layout (constant_id = 5) const uint WN = 32; ++layout (constant_id = 6) const uint WMITER = 2; ++layout (constant_id = 7) const uint TM = 4; ++layout (constant_id = 8) const uint TN = 2; ++layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat ++layout (constant_id = 10) const uint WARP = 32; ++ ++#ifdef COOPMAT ++#define SHMEM_STRIDE (BK + 8) ++#else ++#define SHMEM_STRIDE (BK + 1) ++#endif ++ ++shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; ++shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; ++ ++#ifdef MUL_MAT_ID ++shared u16vec2 row_ids[3072]; ++#endif // MUL_MAT_ID ++ ++#define NUM_WARPS (BLOCK_SIZE / WARP) ++ ++#ifdef COOPMAT ++shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; ++#endif ++ ++void main() { ++#if defined(DATA_A_IQ4_NL) ++ init_iq4nl_shmem(); ++#endif ++ ++#ifdef MUL_MAT_ID ++ const uint expert_idx = gl_GlobalInvocationID.z; ++#else ++ const uint batch_idx = gl_GlobalInvocationID.z; ++ ++ const uint i13 = batch_idx / p.ne12; ++ const uint i12 = batch_idx % p.ne12; ++ ++ const uint i03 = i13 / p.broadcast3; ++ const uint 
i02 = i12 / p.broadcast2; ++ ++ const uint batch_idx_a = i03 * p.ne02 + i02; ++#endif ++ ++ const uint blocks_m = (p.M + BM - 1) / BM; ++ const uint ir = gl_WorkGroupID.x % blocks_m; ++ const uint ik = gl_WorkGroupID.x / blocks_m; ++ const uint ic = gl_WorkGroupID.y; ++ ++ const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); ++ const uint WSUBM = WM / WMITER; ++ const uint WSUBN = WN / WNITER; ++ ++#ifdef COOPMAT ++ const uint warp_i = gl_SubgroupID; ++ ++ const uint tiw = gl_SubgroupInvocationID; ++ ++ const uint cms_per_row = WM / TM; ++ const uint cms_per_col = WN / TN; ++ ++ const uint storestride = WARP / TM; ++ const uint store_r = tiw % TM; ++ const uint store_c = tiw / TM; ++#else ++ const uint warp_i = gl_LocalInvocationID.x / WARP; ++ ++ const uint tiw = gl_LocalInvocationID.x % WARP; ++ ++ const uint tiwr = tiw % (WSUBM / TM); ++ const uint tiwc = tiw / (WSUBM / TM); ++#endif ++ ++ const uint warp_r = warp_i % (BM / WM); ++ const uint warp_c = warp_i / (BM / WM); ++ ++ const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); ++ const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); ++ const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); ++ const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); ++ ++ const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK; ++ const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; ++ ++#ifdef MUL_MAT_ID ++ uint _ne1 = 0; ++ for (uint ii1 = 0; ii1 < p.nei1; ii1++) { ++ for (uint ii0 = 0; ii0 < p.nei0; ii0++) { ++ if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { ++ row_ids[_ne1] = u16vec2(ii0, ii1); ++ _ne1++; ++ } ++ } ++ } ++ ++ barrier(); ++ ++ // Workgroup has no work ++ if (ic * BN >= _ne1) return; ++#endif ++ ++#ifdef MUL_MAT_ID ++ const uint start_k = 0; ++ const uint end_k = p.K; ++#else ++ const uint start_k = ik * p.k_split; ++ const uint end_k = min(p.K, (ik + 1) * p.k_split); ++#endif ++ ++ uint pos_a = ( ++#ifdef MUL_MAT_ID ++ expert_idx * p.batch_stride_a + ++#else ++ batch_idx_a * p.batch_stride_a + ++#endif ++ ir * BM * p.stride_a + start_k) / LOAD_VEC_A; ++#ifdef MUL_MAT_ID ++ uint pos_b = 0; ++#else ++ uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; ++#endif ++ ++#ifdef COOPMAT ++ coopmat cache_a; ++ coopmat cache_b; ++ coopmat sums[cms_per_row * cms_per_col]; ++ ++ [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { ++ sums[i] = coopmat(0.0f); ++ } ++#else ++ ACC_TYPE sums[WMITER * TM * WNITER * TN]; ++ FLOAT_TYPE cache_a[WMITER * TM]; ++ FLOAT_TYPE cache_b[WNITER * TN]; ++ ++ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { ++ sums[i] = ACC_TYPE(0.0f); ++ } ++#endif ++ ++ for (uint block = start_k; block < end_k; block += BK) { ++ [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { ++ ++#if defined(DATA_A_F32) || defined(DATA_A_F16) ++#if LOAD_VEC_A == 8 ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; ++ buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); ++ buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); ++ buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); ++ buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); ++ buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); ++ buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); ++ buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); ++ buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); ++#elif LOAD_VEC_A == 4 ++ const uint idx = pos_a + (loadc_a + l) * 
p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; ++ buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); ++ buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); ++ buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); ++ buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); ++#else ++ if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { ++ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); ++ } else { ++ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); ++ } ++#endif ++#elif defined(DATA_A_Q4_0) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; ++ ++ const uint ib = idx / 16; ++ const uint iqs = idx & 0xF; ++ ++ const float d = float(data_a[ib].d); ++ const uint vui = uint(data_a[ib].qs[iqs]); ++ const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(v.x); ++ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); ++#elif defined(DATA_A_Q4_1) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; ++ ++ const uint ib = idx / 16; ++ const uint iqs = idx & 0xF; ++ ++ const float d = float(data_a[ib].d); ++ const float m = float(data_a[ib].m); ++ const uint vui = uint(data_a[ib].qs[iqs]); ++ const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(v.x); ++ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); ++#elif defined(DATA_A_Q5_0) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; ++ ++ const uint ib = idx / 16; ++ const uint iqs = idx & 0xF; ++ ++ const float d = float(data_a[ib].d); ++ const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; ++ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); ++ const uint vui = uint(data_a[ib].qs[iqs]); ++ const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(v.x); ++ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); ++#elif defined(DATA_A_Q5_1) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; ++ ++ const uint ib = idx / 16; ++ const uint iqs = idx & 0xF; ++ ++ const float d = float(data_a[ib].d); ++ const float m = float(data_a[ib].m); ++ const uint uint_qh = data_a[ib].qh; ++ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); ++ const uint vui = uint(data_a[ib].qs[iqs]); ++ const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(v.x); ++ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); ++#elif defined(DATA_A_Q8_0) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; ++ ++ const uint ib = idx / 16; ++ const uint iqs = (idx & 0xF) * 2; ++ ++ const float d = float(data_a[ib].d); ++ const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(v.x); ++ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); ++#elif defined(DATA_A_Q2_K) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; ++ ++ const uint ib = idx / 128; // 2 values per 
idx ++ const uint iqs = idx % 128; // 0..127 ++ ++ const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 ++ const uint scalesi = iqs / 8; // 0..15 ++ const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 ++ ++ const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); ++ const uint scales = data_a[ib].scales[scalesi]; ++ const vec2 d = vec2(data_a[ib].d); ++ ++ const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(v.x); ++ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); ++#elif defined(DATA_A_Q3_K) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; ++ ++ const uint ib = idx / 128; // 2 values per idx ++ const uint iqs = idx % 128; // 0..127 ++ ++ const uint n = iqs / 64; // 0,1 ++ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 ++ const uint hmi = (iqs % 16) * 2; // 0,2,4..30 ++ const uint j = (iqs % 64) / 4; // 0..3 ++ const uint is = iqs / 8; // 0..15 ++ const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 ++ const uint qsshift = halfsplit * 2; // 0,2,4,6 ++ const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 ++ ++ const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) : ++ is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) : ++ is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) : ++ (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4)); ++ const float dl = float(data_a[ib].d) * float(us - 32); ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); ++ buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); ++#elif defined(DATA_A_Q4_K) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; ++ ++ const uint ib = idx / 128; // 2 values per idx ++ const uint iqs = idx % 128; // 0..127 ++ ++ const uint n = iqs / 32; // 0,1,2,3 ++ const uint b = (iqs % 32) / 16; // 0,1 ++ const uint is = 2 * n + b; // 0..7 ++ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 ++ ++ const vec2 loadd = vec2(data_a[ib].d); ++ ++ const uint scidx0 = (is < 4) ? is : (is + 4); ++ const uint scidx1 = (is < 4) ? is : (is - 4); ++ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ const uint scidxshift1 = (is < 4) ? 0 : 2; ++ const uint mbidx0 = is + 4; ++ const uint mbidx1 = (is < 4) ? is + 4 : is; ++ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; ++ const uint mbidxshift0 = (is < 4) ? 0 : 4; ++ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ const uint mbidxshift1 = (is < 4) ? 
0 : 2; ++ ++ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); ++ const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); ++ ++ const float d = loadd.x * sc; ++ const float m = -loadd.y * mbyte; ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); ++ buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); ++#elif defined(DATA_A_Q5_K) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; ++ ++ const uint ib = idx / 128; // 2 values per idx ++ const uint iqs = idx % 128; // 0..127 ++ ++ const uint n = iqs / 32; // 0,1,2,3 ++ const uint b = (iqs % 32) / 16; // 0,1 ++ const uint is = 2 * n + b; // 0..7 ++ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 ++ const uint qhi = (iqs % 16) * 2; // 0,2,4..30 ++ ++ const uint8_t hm = uint8_t(1 << (iqs / 16)); ++ ++ const vec2 loadd = vec2(data_a[ib].d); ++ ++ const uint scidx0 = (is < 4) ? is : (is + 4); ++ const uint scidx1 = (is < 4) ? is : (is - 4); ++ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ const uint scidxshift1 = (is < 4) ? 0 : 2; ++ const uint mbidx0 = is + 4; ++ const uint mbidx1 = (is < 4) ? is + 4 : is; ++ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; ++ const uint mbidxshift0 = (is < 4) ? 0 : 4; ++ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; ++ const uint mbidxshift1 = (is < 4) ? 0 : 2; ++ ++ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); ++ const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); ++ ++ const float d = loadd.x * sc; ++ const float m = -loadd.y * mbyte; ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); ++ buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 
16 : 0), m)); ++#elif defined(DATA_A_Q6_K) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; ++ ++ const uint ib = idx / 128; // 2 values per idx ++ const uint iqs = idx % 128; // 0..127 ++ ++ const uint n = iqs / 64; // 0,1 ++ const uint b = (iqs % 64) / 32; // 0,1 ++ const uint is_b = (iqs % 16) / 8; // 0,1 ++ const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 ++ const uint is = 8 * n + qhshift + is_b; // 0..15 ++ const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 ++ const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 ++ ++ const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); ++ buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); ++#elif defined(DATA_A_IQ4_NL) ++ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; ++ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; ++ ++ const uint ib = idx / 16; ++ const uint iqs = idx & 0xF; ++ ++ const float d = float(data_a[ib].d); ++ const uint vui = uint(data_a[ib].qs[iqs]); ++ const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; ++ ++ buf_a[buf_idx ] = FLOAT_TYPE(v.x); ++ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); ++#endif ++ } ++ [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { ++#if LOAD_VEC_B == 8 ++#ifdef MUL_MAT_ID ++ const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; ++ const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; ++#else ++ const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; ++#endif ++ const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; ++ buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); ++ buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); ++ buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); ++ buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); ++ buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); ++ buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); ++ buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); ++ buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); ++#elif LOAD_VEC_B == 4 ++#ifdef MUL_MAT_ID ++ const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; ++ const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; ++#else ++ const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; ++#endif ++ const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; ++ buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); ++ buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); ++ buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); ++ buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); ++#elif !MUL_MAT_ID ++ if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { ++ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); ++ } else { ++ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); ++ } ++#else ++ const uint row_i = ic * BN + loadc_b + l; ++ if (row_i < _ne1) { ++ const u16vec2 row_idx = row_ids[row_i]; ++ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * 
p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); ++ } else { ++ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); ++ } ++#endif ++ } ++ ++ barrier(); ++ ++ pos_a += BK / LOAD_VEC_A; ++ pos_b += BK / LOAD_VEC_B; ++ ++#ifdef COOPMAT ++ [[unroll]] for (uint i = 0; i < BK; i += TK) { ++ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { ++ // Load from shared into cache ++ coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); ++ ++ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { ++ coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); ++ ++ sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); ++ } ++ } ++ } ++#else ++ [[unroll]] for (uint i = 0; i < BK; i++) { ++ // Load from shared into cache ++ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { ++ [[unroll]] for (uint j = 0; j < TM; j++) { ++ cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; ++ } ++ } ++ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { ++ [[unroll]] for (uint j = 0; j < TN; j++) { ++ cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; ++ } ++ } ++ ++ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { ++ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { ++ [[unroll]] for (uint cc = 0; cc < TN; cc++) { ++ [[unroll]] for (uint cr = 0; cr < TM; cr++) { ++ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; ++ sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]); ++ } ++ } ++ } ++ } ++ } ++#endif ++ ++ barrier(); ++ } ++ ++ const uint dr = ir * BM + warp_r * WM; ++ const uint dc = ic * BN + warp_c * WN; ++ ++#ifndef MUL_MAT_ID ++ const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; ++#endif ++ ++#ifdef COOPMAT ++#ifdef MUL_MAT_ID ++ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { ++ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { ++ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); ++ ++ [[unroll]] for (uint col = 0; col < BN; col += storestride) { ++ const uint row_i = dc + cm_col * TN + col + store_c; ++ if (row_i >= _ne1) break; ++ ++ const u16vec2 row_idx = row_ids[row_i]; ++ ++ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); ++ } ++ } ++ } ++#else ++ const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float ++ ++ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { ++ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { ++ const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; ++ ++ if (is_aligned && is_in_bounds) { ++ // Full coopMat is within bounds and stride_d is aligned with 16B ++ coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); ++ coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); ++ } else if (is_in_bounds) { ++ // Full coopMat is within bounds, but stride_d is not aligned ++ coopMatStore(sums[cm_col * 
cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); ++ ++ [[unroll]] for (uint col = 0; col < TN; col += storestride) { ++ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); ++ } ++ } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { ++ // Partial coopMat is within bounds ++ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); ++ ++ [[unroll]] for (uint col = 0; col < TN; col += storestride) { ++ if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { ++ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); ++ } ++ } ++ } ++ } ++ } ++#endif // MUL_MAT_ID ++#else ++ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { ++ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { ++ ++ const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; ++ const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; ++ [[unroll]] for (uint cc = 0; cc < TN; cc++) { ++#ifdef MUL_MAT_ID ++ const uint row_i = dc_warp + cc; ++ if (row_i >= _ne1) break; ++ ++ const u16vec2 row_idx = row_ids[row_i]; ++#endif // MUL_MAT_ID ++ [[unroll]] for (uint cr = 0; cr < TM; cr++) { ++#ifdef MUL_MAT_ID ++ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); ++#else ++ if (dr_warp + cr < p.M && dc_warp + cc < p.N) { ++ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); ++ } ++#endif // MUL_MAT_ID ++ } ++ } ++ } ++ } ++#endif // COOPMAT ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +new file mode 100644 +index 00000000..cbfa5dce +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +@@ -0,0 +1,328 @@ ++#version 450 ++ ++#extension GL_EXT_control_flow_attributes : enable ++#extension GL_EXT_shader_16bit_storage : require ++ ++#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require ++#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require ++#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require ++#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require ++ ++#extension GL_KHR_memory_scope_semantics : enable ++#extension GL_KHR_cooperative_matrix : enable ++#extension GL_NV_cooperative_matrix2 : enable ++#extension GL_EXT_buffer_reference : enable ++#extension GL_KHR_shader_subgroup_ballot : enable ++#extension GL_KHR_shader_subgroup_vote : enable ++ ++#include "types.comp" ++ ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++layout (constant_id = 1) const uint BM = 64; ++layout (constant_id = 2) const uint BN = 64; ++layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant ++ ++layout (push_constant) uniform parameter ++{ ++ uint M; ++ uint N; ++ uint K; ++ uint stride_a; ++ uint stride_b; ++ uint stride_d; ++ ++ uint batch_stride_a; ++ uint batch_stride_b; ++ uint batch_stride_d; ++ ++#ifdef MUL_MAT_ID ++ uint nei0; ++ uint nei1; ++ uint nbi1; ++ uint ne11; ++#else ++ uint k_split; ++ uint ne02; ++ uint ne12; ++ uint broadcast2; ++ uint 
broadcast3; ++#endif ++} p; ++ ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; ++layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; ++ ++#if QUANT_K > 1 ++#define DECODEFUNCA , dequantFuncA ++#define MAT_A_TYPE float16_t ++ ++#include "dequant_funcs_cm2.comp" ++ ++#else ++#define DECODEFUNCA ++#define MAT_A_TYPE A_TYPE ++#endif ++ ++#define MAT_B_TYPE B_TYPE ++ ++#ifdef MUL_MAT_ID ++layout (binding = 3) readonly buffer IDS {int data_ids[];}; ++ ++shared u16vec4 row_ids[3072]; ++ ++layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { ++ B_TYPE b[]; ++}; ++ ++uint _ne1; ++shared uint _ne1_sh; ++ ++B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) ++{ ++ const uint row_i = blockCoords[0]; ++ ++ if (row_i >= _ne1) { ++ return B_TYPE(0.0); ++ } ++ ++ const u16vec4 row_idx = row_ids[row_i]; ++ B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; ++ ++ return ret; ++} ++ ++D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) ++{ ++ uint dr = ir * BM + r; ++ uint dc = ic * BN + c; ++ ++ if (dr < p.M && dc < _ne1) { ++ uint row_i = dc; ++ const u16vec4 row_idx = row_ids[row_i]; ++ data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; ++ } ++ return elem; ++} ++ ++#endif ++ ++void main() { ++#if defined(DATA_A_IQ4_NL) ++ init_iq4nl_shmem(); ++#endif ++ ++#ifdef MUL_MAT_ID ++ const uint expert_idx = gl_GlobalInvocationID.z; ++#else ++ const uint batch_idx = gl_GlobalInvocationID.z; ++ ++ const uint i13 = batch_idx / p.ne12; ++ const uint i12 = batch_idx % p.ne12; ++ ++ const uint i03 = i13 / p.broadcast3; ++ const uint i02 = i12 / p.broadcast2; ++ ++ const uint batch_idx_a = i03 * p.ne02 + i02; ++#endif ++ ++ const uint blocks_m = (p.M + BM - 1) / BM; ++ const uint ir = gl_WorkGroupID.x % blocks_m; ++ const uint ik = gl_WorkGroupID.x / blocks_m; ++ const uint ic = gl_WorkGroupID.y; ++ ++#ifdef MUL_MAT_ID ++ // Spread the search across all elements in the first subgroup ++ if (gl_SubgroupID == 0) { ++ _ne1 = 0; ++ uint num_elements = p.nei1 * p.nei0; ++ ++ for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { ++ bool in_range = i < num_elements; ++ uint ii0 = i % p.nei0; ++ uint ii1 = i / p.nei0; ++ uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; ++ uvec4 ballot = subgroupBallot(in_range && id == expert_idx); ++ uint idx = subgroupBallotExclusiveBitCount(ballot); ++ if (in_range && id == expert_idx) { ++ row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); ++ } ++ _ne1 += subgroupBallotBitCount(ballot); ++ } ++ _ne1_sh = _ne1; ++ } ++ ++ barrier(); ++ ++ _ne1 = _ne1_sh; ++ ++ // Workgroup has no work ++ if (ic * BN >= _ne1) return; ++#endif ++ ++#ifdef MUL_MAT_ID ++ uint start_k = 0; ++ const uint end_k = p.K; ++#else ++ uint start_k = ik * p.k_split; ++ const uint end_k = min(p.K, (ik + 1) * p.k_split); ++#endif ++ ++ coopmat sum; ++ sum = coopmat(0.0); ++ ++#ifdef MUL_MAT_ID ++ uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; ++ uint pos_b = 0; ++#else ++ uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; ++ uint pos_b = batch_idx * p.batch_stride_b; ++#endif ++ ++ uint stride_a = p.stride_a / QUANT_K; ++ uint stride_b = p.stride_b; ++ ++ // Hint to the compiler that values are aligned (want 16B alignment). 
++ // Quants are always block-aligned, no alignment needed. ++#if ALIGNED ++#if QUANT_K == 1 ++ stride_a &= ~7; ++#endif ++ stride_b &= ~7; ++#endif ++ ++ // Create layouts for both clamped and unclamped accesses ++ tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); ++ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); ++ tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); ++ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); ++ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); ++ ++#if QUANT_K > 1 ++ tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); ++ tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); ++#endif ++ ++ // Use end_k rather than p.K as the dimension because that's what ++ // we need to bound check against when using split_k ++ tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); ++ tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k); ++ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); ++ tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); ++ tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k); ++ ++ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); ++ ++#if !defined(MUL_MAT_ID) ++ // Detect a fast path where all loads are entirely in bounds and no clamping is required ++ if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 && ++#if QUANT_K == 1 ++ (stride_a % 8) == 0 && ++#endif ++ (stride_b % 8) == 0 && (start_k % 8) == 0) { ++ // Hint to the compiler that values are aligned (want 16B alignment) ++ start_k &= ~7; ++ stride_b &= ~7; ++#if QUANT_K == 1 ++ stride_a &= ~7; ++#endif ++ ++ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); ++ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); ++ ++ uint k_iters = (end_k - start_k + BK - 1) / BK; ++ ++ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { ++ ++ coopmat mat_a; ++ coopmat mat_b; ++ ++ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); ++ coopmat mat_a_ft = coopmat(mat_a); ++ ++ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); ++ coopmat mat_b_ft = coopmat(mat_b); ++ ++ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); ++ } ++ } else ++#endif // !defined(MUL_MAT_ID) ++ { ++ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); ++ ++ tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1); ++ ++ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); ++ ++ tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); ++ ++ [[dont_unroll]] ++ for (uint block_k = start_k; block_k < end_k; block_k += BK) { ++ ++ coopmat mat_a; ++ coopmat mat_b; ++ coopmat mat_a_ft; ++ coopmat mat_b_ft; ++ ++ // Clamping is expensive, so detect different code paths for each combination ++ // of A and B needing clamping. 
++ bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0; ++#ifdef MUL_MAT_ID ++ bool unclampedB = true; ++#else ++ bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0; ++#endif ++ if (unclampedA && unclampedB) { ++ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); ++#ifdef MUL_MAT_ID ++ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); ++#else ++ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); ++#endif ++ mat_a_ft = coopmat(mat_a); ++ mat_b_ft = coopmat(mat_b); ++ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); ++ } else if (unclampedA && !unclampedB) { ++ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); ++ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); ++ ++ mat_a_ft = coopmat(mat_a); ++ mat_b_ft = coopmat(mat_b); ++ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); ++ } else if (!unclampedA && unclampedB) { ++ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); ++#ifdef MUL_MAT_ID ++ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); ++#else ++ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); ++#endif ++ mat_a_ft = coopmat(mat_a); ++ mat_b_ft = coopmat(mat_b); ++ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); ++ } else if (!unclampedA && !unclampedB) { ++ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); ++ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); ++ ++ mat_a_ft = coopmat(mat_a); ++ mat_b_ft = coopmat(mat_b); ++ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); ++ } ++ } ++ } ++ ++ // Convert from ACC_TYPE to D_TYPE ++ coopmat mat_d; ++ mat_d = coopmat(sum); ++ ++#ifdef MUL_MAT_ID ++ // Call callback to store each element, remapping row through shared memory ++ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); ++#else ++ tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); ++ ++ uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; ++ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); ++#endif ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +new file mode 100644 +index 00000000..6627a50b +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +@@ -0,0 +1,44 @@ ++#version 450 ++ ++#include "generic_head.comp" ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++#define BLOCK_SIZE 512 ++ ++layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++shared vec2 sum[BLOCK_SIZE]; ++ ++void main() { ++ const uint row = 
gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; ++ const uint tid = gl_LocalInvocationID.x; ++ ++ sum[tid] = vec2(0.0f, 0.0f); ++ ++ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { ++ const float xi = float(data_a[row*p.KX + col]); ++ sum[tid].x += xi; ++ sum[tid].y += xi * xi; ++ } ++ ++ // sum up partial sums and write back result ++ barrier(); ++ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { ++ if (tid < s) { ++ sum[tid] += sum[tid + s]; ++ } ++ barrier(); ++ } ++ ++ const float mean = sum[0].x / p.KX; ++ const float var = sum[0].y / p.KX - mean * mean; ++ const float inv_std = inversesqrt(var + p.param1); ++ ++ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { ++ data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +new file mode 100644 +index 00000000..450b67fc +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +@@ -0,0 +1,28 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_unary_head.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++ ++ if (idx >= p.ne) { ++ return; ++ } ++ ++ const uint i3 = idx / (p.ne12*p.ne11*p.ne10); ++ const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; ++ const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10); ++ const uint i2_offset = i2*p.ne11*p.ne10; ++ const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; ++ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; ++ ++ const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; ++ const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; ++ ++ const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; ++ ++ data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? 
data_a[get_aoffset() + src0_idx] : 0.0f); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +new file mode 100644 +index 00000000..b6124411 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +@@ -0,0 +1,74 @@ ++#version 450 ++ ++#include "types.comp" ++ ++#extension GL_EXT_shader_16bit_storage : require ++ ++layout(push_constant) uniform parameter { ++ uint IW; uint IH; ++ uint OW; uint OH; ++ uint OC; ++ uint pelements; ++ uint op; ++ int k0; int k1; ++ int s0; int s1; ++ int p0; int p1; ++} p; ++ ++#define BLOCK_SIZE 512 ++#define FLT_MAX 3.402823466e+38F ++#define OP_POOL_MAX 0u ++#define OP_POOL_AVG 1u ++ ++layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; ++ ++layout(binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout(binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const uint idx = gl_GlobalInvocationID.x; ++ if (idx >= p.pelements) { ++ return; ++ } ++ ++ const uint O_HW = p.OW * p.OH; ++ ++ const uint nc = idx / O_HW; ++ const uint cur_oh = (idx % O_HW) / p.OW; ++ const uint cur_ow = (idx % O_HW) % p.OW; ++ ++ const int start_h = int(cur_oh) * p.s0 - p.p0; ++ const uint bh = max(start_h, 0); ++ const uint eh = min(start_h + p.k0, p.IH); ++ ++ const int start_w = int(cur_ow) * p.s1 - p.p1; ++ const uint bw = max(start_w, 0); ++ const uint ew = min(start_w + p.k1, p.IW); ++ ++ const float scale = 1.0 / float(p.k0 * p.k1); ++ float res; ++ ++ if (p.op == OP_POOL_AVG) { ++ res = 0.0; ++ } else if (p.op == OP_POOL_MAX) { ++ res = -FLT_MAX; ++ } else { ++ return; ++ } ++ ++ #pragma unroll ++ for (uint i = bh; i < eh; i++) { ++ #pragma unroll ++ for (uint j = bw; j < ew; j++) { ++ const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]); ++ ++ if (p.op == OP_POOL_AVG) { ++ res += cur * scale; ++ } else if (p.op == OP_POOL_MAX) { ++ res = max(res, cur); ++ } ++ } ++ } ++ ++ data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res; ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +new file mode 100644 +index 00000000..52a19b62 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +@@ -0,0 +1,21 @@ ++#version 450 ++ ++#include "generic_head.comp" ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++ ++ if (i >= p.KX) { ++ return; ++ } ++ ++ data_d[i] = max(float(data_a[i]), 0); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +new file mode 100644 +index 00000000..1568b141 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +@@ -0,0 +1,26 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_unary_head.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++uint src0_idx_mod(uint idx) { ++ const uint i13 = idx / (p.ne12*p.ne11*p.ne10); ++ const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; ++ const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); ++ const uint 
i12_offset = i12*p.ne11*p.ne10; ++ const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; ++ const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; ++ return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00; ++} ++ ++void main() { ++ const uint idx = get_idx(); ++ ++ if (idx >= p.ne) { ++ return; ++ } ++ ++ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +new file mode 100644 +index 00000000..b554400b +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +@@ -0,0 +1,42 @@ ++#version 450 ++ ++#include "generic_head.comp" ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++#define BLOCK_SIZE 512 ++ ++layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++shared FLOAT_TYPE sum[BLOCK_SIZE]; ++ ++void main() { ++ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; ++ const uint tid = gl_LocalInvocationID.x; ++ ++ sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp ++ ++ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { ++ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); ++ sum[tid] += xi * xi; ++ } ++ ++ // sum up partial sums and write back result ++ barrier(); ++ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { ++ if (tid < s) { ++ sum[tid] += sum[tid + s]; ++ } ++ barrier(); ++ } ++ ++ const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX); ++ const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); ++ ++ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { ++ data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +new file mode 100644 +index 00000000..574b51ca +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +@@ -0,0 +1,49 @@ ++#include "types.comp" ++ ++#extension GL_EXT_shader_16bit_storage : require ++#extension GL_EXT_spirv_intrinsics: enable ++ ++#if RTE16 ++spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits ++#endif ++ ++layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) readonly buffer Y {int data_pos[];}; ++layout (binding = 2) readonly buffer Z {float data_ff[];}; ++layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; ++ ++layout (push_constant) uniform parameter { ++ uint ncols; ++ uint n_dims; ++ float freq_scale; ++ uint p_delta_rows; ++ float freq_base; ++ float ext_factor; ++ float attn_factor; ++ float corr_dims[2]; ++ float theta_scale; ++ uint has_ff; ++} p; ++ ++float rope_yarn_ramp(const float low, const float high, const uint i0) { ++ const float y = (i0 / 2 - low) / max(0.001f, high - low); ++ return 1.0f - min(1.0f, max(0.0f, y)); ++} ++ ++void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { ++ float mscale = p.attn_factor; ++ // Get n-d rotational scaling corrected for extrapolation ++ float theta_interp = 
p.freq_scale * theta_extrap; ++ float theta = theta_interp; ++ if (p.ext_factor != 0.0f) { ++ float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; ++ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; ++ ++ // Get n-d magnitude scaling corrected for interpolation ++ mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); ++ } ++ cos_theta = cos(theta) * mscale; ++ sin_theta = sin(theta) * mscale; ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +new file mode 100644 +index 00000000..83b46b69 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +@@ -0,0 +1,37 @@ ++#version 450 ++ ++#include "rope_head.comp" ++ ++void main() { ++ const uint col = gl_GlobalInvocationID.y * 2; ++ const uint row = gl_GlobalInvocationID.x; ++ ++ if (col >= p.ncols) { ++ return; ++ } ++ ++ if (col >= p.n_dims) { ++ const uint i = row*p.ncols + col; ++ ++ data_d[i + 0] = data_a[i + 0]; ++ data_d[i + 1] = data_a[i + 1]; ++ ++ return; ++ } ++ ++ const uint i = row*p.ncols + col/2; ++ const uint i2 = row/p.p_delta_rows; ++ ++ const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); ++ ++ const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; ++ ++ float cos_theta, sin_theta; ++ rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); ++ ++ const float x0 = float(data_a[i + 0]); ++ const float x1 = float(data_a[i + p.n_dims/2]); ++ ++ data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); ++ data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +new file mode 100644 +index 00000000..e416ad93 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +@@ -0,0 +1,37 @@ ++#version 450 ++ ++#include "rope_head.comp" ++ ++void main() { ++ const uint col = gl_GlobalInvocationID.y * 2; ++ const uint row = gl_GlobalInvocationID.x; ++ ++ if (col >= p.ncols) { ++ return; ++ } ++ ++ if (col >= p.n_dims) { ++ const uint i = row*p.ncols + col; ++ ++ data_d[i + 0] = data_a[i + 0]; ++ data_d[i + 1] = data_a[i + 1]; ++ ++ return; ++ } ++ ++ const uint i = row*p.ncols + col; ++ const uint i2 = row/p.p_delta_rows; ++ ++ const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); ++ ++ const float freq_factor = p.has_ff != 0 ? 
data_ff[col/2] : 1.0f; ++ ++ float cos_theta, sin_theta; ++ rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); ++ ++ const float x0 = float(data_a[i + 0]); ++ const float x1 = float(data_a[i + 1]); ++ ++ data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); ++ data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +new file mode 100644 +index 00000000..4663428d +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +@@ -0,0 +1,24 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_unary_head.comp" ++ ++const uint num_threads = 128; ++ ++layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ uint idx = get_idx(); ++ ++ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation ++ const uint num_iter = 4; ++ ++ [[unroll]] for (uint i = 0; i < num_iter; ++i) { ++ if (idx >= p.ne) { ++ continue; ++ } ++ ++ data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); ++ idx += num_threads; ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +new file mode 100644 +index 00000000..4d36f88e +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +@@ -0,0 +1,22 @@ ++#version 450 ++ ++#include "generic_head.comp" ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++ ++ if (i >= p.KX) { ++ return; ++ } ++ ++ const float xi = float(data_a[i]); ++ data_d[i] = D_TYPE(xi / (1.0f + exp(-xi))); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +new file mode 100644 +index 00000000..d7c15a16 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +@@ -0,0 +1,17 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_unary_head.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ const uint idx = get_idx(); ++ ++ if (idx >= p.ne) { ++ return; ++ } ++ ++ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); ++ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +new file mode 100644 +index 00000000..a25808e1 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +@@ -0,0 +1,174 @@ ++#version 450 ++ ++#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require ++#extension GL_EXT_control_flow_attributes : enable ++ ++layout (push_constant) uniform parameter ++{ ++ uint KX; ++ uint KY; ++ float scale; ++ float max_bias; ++ float m0; ++ float m1; ++ uint n_head_log2; ++ uint nrows_x; ++} p; ++ ++#include "types.comp" ++ ++layout(constant_id = 0) const uint BLOCK_SIZE = 32; ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ 
++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; ++layout (binding = 2) buffer D {D_TYPE data_d[];}; ++ ++shared FLOAT_TYPE vals[BLOCK_SIZE]; ++ ++// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate ++// over all the columns. The main function tries to pass a constant here, ++// as if it were a template function, to allow unrolling. ++void soft_max(uint num_iters) { ++ const uint tid = gl_LocalInvocationID.x; ++ const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; ++ const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; ++ ++ if (rowx >= p.nrows_x) { ++ return; ++ } ++ ++ float slope = 1.0f; ++ ++ // ALiBi ++ if (p.max_bias > 0.0f) { ++ const uint h = rowx/p.KY; // head index ++ ++ const float base = h < p.n_head_log2 ? p.m0 : p.m1; ++ const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; ++ ++ slope = pow(base, exp); ++ } ++ ++ // Find max ++ FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); ++ ++ // Cache values while we compute the max, so we don't need to read them ++ // again when we're ready to compute exp(x-max). ++ const uint DATA_CACHE_SIZE = 16; ++ FLOAT_TYPE data_cache[DATA_CACHE_SIZE]; ++ ++ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { ++ const uint col = col0 + tid; ++ ++ FLOAT_TYPE a = FLOAT_TYPE(0); ++ if (col < p.KX) { ++ a = data_a[rowx * p.KX + col]; ++ } ++ ++ FLOAT_TYPE b = FLOAT_TYPE(0); ++ if (p.KY > 0 && col < p.KX) { ++ b = data_b[rowy * p.KX + col]; ++ } ++ ++ FLOAT_TYPE v = a * p.scale + slope * b; ++ ++ if (col < p.KX) { ++ max_val = max(max_val, v); ++ } ++ ++ if (idx < DATA_CACHE_SIZE) { ++ data_cache[idx] = v; ++ } ++ } ++ ++ // reduce across the workgroup ++ vals[tid] = max_val; ++ barrier(); ++ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { ++ if (tid < s) { ++ vals[tid] = max(vals[tid], vals[tid + s]); ++ } ++ barrier(); ++ } ++ ++ max_val = vals[0]; ++ barrier(); ++ ++ FLOAT_TYPE sum = FLOAT_TYPE(0.0f); ++ ++ // Compute sum{exp(x - max)} ++ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { ++ const uint col = col0 + tid; ++ ++ if (col >= p.KX) { ++ break; ++ } ++ ++ // compute exp(a*scale+b*slope), add it to sum, and cache the new value ++ // in data_cache if possible. ++ const uint i = rowx * p.KX + col; ++ FLOAT_TYPE val; ++ if (idx < DATA_CACHE_SIZE) { ++ val = exp(data_cache[idx] - max_val); ++ } else { ++ val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? 
slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); ++ } ++ sum += val; ++ if (idx < DATA_CACHE_SIZE) { ++ data_cache[idx] = val; ++ } else { ++ data_d[i] = D_TYPE(val); ++ } ++ } ++ ++ // reduce across the workgroup ++ vals[tid] = sum; ++ barrier(); ++ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { ++ if (tid < s) { ++ vals[tid] += vals[tid + s]; ++ } ++ barrier(); ++ } ++ sum = vals[0]; ++ ++ FLOAT_TYPE rcpdivisor = 1.0/sum; ++ ++ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { ++ const uint col = col0 + tid; ++ ++ if (col >= p.KX) { ++ continue; ++ } ++ ++ if (idx < DATA_CACHE_SIZE) { ++ data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor); ++ } else { ++ data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); ++ } ++ } ++} ++ ++void main() { ++ // instantiate the soft_max function for several different ++ // dimensions, to allow loop unrolling ++ uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; ++ if (num_blocks > 32) { ++ soft_max(num_blocks); ++ } else if (num_blocks > 16) { ++ soft_max(32); ++ } else if (num_blocks > 8) { ++ soft_max(16); ++ } else if (num_blocks > 4) { ++ soft_max(8); ++ } else if (num_blocks == 4) { ++ soft_max(4); ++ } else if (num_blocks == 3) { ++ soft_max(3); ++ } else if (num_blocks == 2) { ++ soft_max(2); ++ } else if (num_blocks == 1) { ++ soft_max(1); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +new file mode 100644 +index 00000000..ef43598b +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +@@ -0,0 +1,17 @@ ++#version 450 ++ ++#include "types.comp" ++#include "generic_unary_head.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++void main() { ++ const uint idx = get_idx(); ++ ++ if (idx >= p.ne) { ++ return; ++ } ++ ++ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); ++ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +new file mode 100644 +index 00000000..961e5ffa +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +@@ -0,0 +1,37 @@ ++#version 450 ++ ++#include "generic_head.comp" ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++layout (constant_id = 0) const uint BLOCK_SIZE = 32; ++ ++shared FLOAT_TYPE tmp[BLOCK_SIZE]; ++ ++void main() { ++ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; ++ const uint col = gl_LocalInvocationID.x; ++ ++ tmp[col] = FLOAT_TYPE(0.0f); ++ ++ for (uint i = col; i < p.KX; i += BLOCK_SIZE) { ++ tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]); ++ } ++ ++ barrier(); ++ [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { ++ if (col < s) { ++ tmp[col] += tmp[col + s]; ++ } ++ barrier(); ++ } ++ ++ if (col == 0) { ++ data_d[row] = D_TYPE(tmp[0]); ++ } ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +new file mode 100644 +index 00000000..495f966b +--- /dev/null ++++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +@@ -0,0 +1,20 @@ ++#version 450 ++ ++#include "generic_head.comp" ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++ ++ if (i >= p.KX) { ++ return; ++ } ++ data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.)); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +new file mode 100644 +index 00000000..28eb24e1 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +@@ -0,0 +1,7 @@ ++#version 460 ++ ++#extension GL_NV_cooperative_matrix2 : require ++ ++void main() ++{ ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +new file mode 100644 +index 00000000..79e065a9 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +@@ -0,0 +1,41 @@ ++#version 450 ++ ++#extension GL_EXT_shader_16bit_storage : require ++ ++layout (push_constant) uniform parameter ++{ ++ uint nb1; ++ uint dim; ++ uint max_period; ++} p; ++ ++#include "types.comp" ++ ++#extension GL_EXT_control_flow_attributes : enable ++#define BLOCK_SIZE 256 ++ ++layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const uint i = gl_WorkGroupID.y; ++ const uint j = gl_GlobalInvocationID.x; ++ const uint d_offset = i * p.nb1; ++ ++ if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) { ++ data_d[d_offset + p.dim] = 0.f; ++ } ++ ++ const uint half_dim = p.dim / 2; ++ if (j >= half_dim) { ++ return; ++ } ++ ++ const float timestep = float(data_a[i]); ++ const float freq = float(exp(-log(p.max_period) * j / half_dim)); ++ const float arg = timestep * freq; ++ data_d[d_offset + j] = D_TYPE(cos(arg)); ++ data_d[d_offset + j + half_dim] = D_TYPE(sin(arg)); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +new file mode 100644 +index 00000000..eecc47f3 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +@@ -0,0 +1,323 @@ ++ ++#if !defined(GGML_TYPES_COMP) ++#define GGML_TYPES_COMP ++ ++#extension GL_EXT_shader_explicit_arithmetic_types : require ++ ++#if defined(DATA_A_F32) ++#define QUANT_K 1 ++#define QUANT_R 1 ++ ++#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 ++#define A_TYPE float ++#elif LOAD_VEC_A == 4 ++#define A_TYPE vec4 ++#elif LOAD_VEC_A == 8 ++#define A_TYPE mat2x4 ++#endif ++#endif ++ ++#if defined(DATA_A_F16) ++#define QUANT_K 1 ++#define QUANT_R 1 ++ ++#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 ++#define A_TYPE float16_t ++#elif LOAD_VEC_A == 4 ++#define A_TYPE f16vec4 ++#elif LOAD_VEC_A == 8 ++#define A_TYPE f16mat2x4 ++#endif ++#endif ++ ++#define QUANT_K_Q4_0 32 ++#define QUANT_R_Q4_0 2 ++ ++struct block_q4_0 ++{ ++ float16_t d; ++ uint8_t qs[16]; ++}; ++struct block_q4_0_packed16 ++{ ++ 
float16_t d; ++ uint16_t qs[16/2]; ++}; ++ ++#if defined(DATA_A_Q4_0) ++#define QUANT_K QUANT_K_Q4_0 ++#define QUANT_R QUANT_R_Q4_0 ++#define A_TYPE block_q4_0 ++#define A_TYPE_PACKED16 block_q4_0_packed16 ++#endif ++ ++#define QUANT_K_Q4_1 32 ++#define QUANT_R_Q4_1 2 ++ ++struct block_q4_1 ++{ ++ float16_t d; ++ float16_t m; ++ uint8_t qs[16]; ++}; ++ ++struct block_q4_1_packed16 ++{ ++ float16_t d; ++ float16_t m; ++ uint16_t qs[16/2]; ++}; ++ ++#if defined(DATA_A_Q4_1) ++#define QUANT_K QUANT_K_Q4_1 ++#define QUANT_R QUANT_R_Q4_1 ++#define A_TYPE block_q4_1 ++#define A_TYPE_PACKED16 block_q4_1_packed16 ++#endif ++ ++#define QUANT_K_Q5_0 32 ++#define QUANT_R_Q5_0 2 ++ ++struct block_q5_0 ++{ ++ float16_t d; ++ uint16_t qh[2]; ++ uint8_t qs[16]; ++}; ++ ++struct block_q5_0_packed16 ++{ ++ float16_t d; ++ uint16_t qh[2]; ++ uint16_t qs[16/2]; ++}; ++ ++#if defined(DATA_A_Q5_0) ++#define QUANT_K QUANT_K_Q5_0 ++#define QUANT_R QUANT_R_Q5_0 ++#define A_TYPE block_q5_0 ++#define A_TYPE_PACKED16 block_q5_0_packed16 ++#endif ++ ++#define QUANT_K_Q5_1 32 ++#define QUANT_R_Q5_1 2 ++ ++struct block_q5_1 ++{ ++ float16_t d; ++ float16_t m; ++ uint qh; ++ uint8_t qs[16]; ++}; ++ ++struct block_q5_1_packed16 ++{ ++ float16_t d; ++ float16_t m; ++ uint qh; ++ uint16_t qs[16/2]; ++}; ++ ++#if defined(DATA_A_Q5_1) ++#define QUANT_K QUANT_K_Q5_1 ++#define QUANT_R QUANT_R_Q5_1 ++#define A_TYPE block_q5_1 ++#define A_TYPE_PACKED16 block_q5_1_packed16 ++#endif ++ ++#define QUANT_K_Q8_0 32 ++#define QUANT_R_Q8_0 1 ++ ++struct block_q8_0 ++{ ++ float16_t d; ++ int8_t qs[32]; ++}; ++struct block_q8_0_packed16 ++{ ++ float16_t d; ++ uint16_t qs[32/2]; ++}; ++ ++#if defined(DATA_A_Q8_0) ++#define QUANT_K QUANT_K_Q8_0 ++#define QUANT_R QUANT_R_Q8_0 ++#define A_TYPE block_q8_0 ++#define A_TYPE_PACKED16 block_q8_0_packed16 ++#endif ++ ++// K-quants ++#define QUANT_K_Q2_K 256 ++ ++struct block_q2_K ++{ ++ uint8_t scales[QUANT_K_Q2_K/16]; ++ uint8_t qs[QUANT_K_Q2_K/4]; ++ f16vec2 d; ++}; ++ ++struct block_q2_K_packed16 ++{ ++ uint16_t scales[QUANT_K_Q2_K/16/2]; ++ uint16_t qs[QUANT_K_Q2_K/4/2]; ++ f16vec2 d; ++}; ++ ++struct block_q2_K_packed32 ++{ ++ uint32_t scales[QUANT_K_Q2_K/16/4]; ++ uint32_t qs[QUANT_K_Q2_K/4/4]; ++ f16vec2 d; ++}; ++ ++#if defined(DATA_A_Q2_K) ++#define QUANT_K QUANT_K_Q2_K ++#define A_TYPE block_q2_K ++#define A_TYPE_PACKED16 block_q2_K_packed16 ++#define A_TYPE_PACKED32 block_q2_K_packed32 ++#endif ++ ++#define QUANT_K_Q3_K 256 ++ ++struct block_q3_K ++{ ++ uint8_t hmask[QUANT_K_Q3_K/8]; ++ uint8_t qs[QUANT_K_Q3_K/4]; ++ uint8_t scales[12]; ++ float16_t d; ++}; ++ ++struct block_q3_K_packed16 ++{ ++ uint16_t hmask[QUANT_K_Q3_K/8/2]; ++ uint16_t qs[QUANT_K_Q3_K/4/2]; ++ uint16_t scales[12/2]; ++ float16_t d; ++}; ++ ++#if defined(DATA_A_Q3_K) ++#define QUANT_K QUANT_K_Q3_K ++#define A_TYPE block_q3_K ++#define A_TYPE_PACKED16 block_q3_K_packed16 ++#endif ++ ++#define QUANT_K_Q4_K 256 ++ ++struct block_q4_K ++{ ++ f16vec2 d; ++ uint8_t scales[3*QUANT_K_Q4_K/64]; ++ uint8_t qs[QUANT_K_Q4_K/2]; ++}; ++ ++struct block_q4_K_packed16 ++{ ++ f16vec2 d; ++ uint16_t scales[3*QUANT_K_Q4_K/64/2]; ++ uint16_t qs[QUANT_K_Q4_K/2/2]; ++}; ++ ++struct block_q4_K_packed32 ++{ ++ f16vec2 d; ++ uint32_t scales[3*QUANT_K_Q4_K/64/4]; ++ uint32_t qs[QUANT_K_Q4_K/2/4]; ++}; ++ ++#if defined(DATA_A_Q4_K) ++#define QUANT_K QUANT_K_Q4_K ++#define A_TYPE block_q4_K ++#define A_TYPE_PACKED16 block_q4_K_packed16 ++#define A_TYPE_PACKED32 block_q4_K_packed32 ++#endif ++ ++#define QUANT_K_Q5_K 256 ++ ++struct block_q5_K 
++{ ++ f16vec2 d; ++ uint8_t scales[12]; ++ uint8_t qh[QUANT_K_Q5_K/8]; ++ uint8_t qs[QUANT_K_Q5_K/2]; ++}; ++ ++struct block_q5_K_packed16 ++{ ++ f16vec2 d; ++ uint16_t scales[12/2]; ++ uint16_t qh[QUANT_K_Q5_K/8/2]; ++ uint16_t qs[QUANT_K_Q5_K/2/2]; ++}; ++ ++#if defined(DATA_A_Q5_K) ++#define QUANT_K QUANT_K_Q5_K ++#define A_TYPE block_q5_K ++#define A_TYPE_PACKED16 block_q5_K_packed16 ++#endif ++ ++#define QUANT_K_Q6_K 256 ++ ++struct block_q6_K ++{ ++ uint8_t ql[QUANT_K_Q6_K/2]; ++ uint8_t qh[QUANT_K_Q6_K/4]; ++ int8_t scales[QUANT_K_Q6_K/16]; ++ float16_t d; ++}; ++ ++struct block_q6_K_packed16 ++{ ++ uint16_t ql[QUANT_K_Q6_K/2/2]; ++ uint16_t qh[QUANT_K_Q6_K/4/2]; ++ int8_t scales[QUANT_K_Q6_K/16]; ++ float16_t d; ++}; ++ ++#if defined(DATA_A_Q6_K) ++#define QUANT_K QUANT_K_Q6_K ++#define A_TYPE block_q6_K ++#define A_TYPE_PACKED16 block_q6_K_packed16 ++#endif ++ ++// IQuants ++ ++#define QUANT_K_IQ4_NL 32 ++#define QUANT_R_IQ4_NL 2 ++ ++struct block_iq4_nl ++{ ++ float16_t d; ++ uint8_t qs[QUANT_K_IQ4_NL/2]; ++}; ++ ++struct block_iq4_nl_packed16 ++{ ++ float16_t d; ++ uint16_t qs[QUANT_K_IQ4_NL/2/2]; ++}; ++ ++#if defined(DATA_A_IQ4_NL) ++ ++const int8_t kvalues_iq4nl_const[16] = { ++ int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), ++ int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) ++}; ++ ++shared FLOAT_TYPE kvalues_iq4nl[16]; ++ ++void init_iq4nl_shmem() ++{ ++ // copy the table into shared memory and sync ++ if (gl_LocalInvocationIndex.x < 16) { ++ kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]); ++ } ++ barrier(); ++} ++ ++#define QUANT_K QUANT_K_IQ4_NL ++#define QUANT_R QUANT_R_IQ4_NL ++#define A_TYPE block_iq4_nl ++#define A_TYPE_PACKED16 block_iq4_nl_packed16 ++#endif ++ ++#endif // !defined(GGML_TYPES_COMP) +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +new file mode 100644 +index 00000000..6f607380 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +@@ -0,0 +1,36 @@ ++#version 450 ++ ++layout (push_constant) uniform parameter ++{ ++ uint ne; uint a_offset; uint d_offset; ++ uint nb00; uint nb01; uint nb02; uint nb03; ++ uint ne10; uint ne11; uint ne12; uint ne13; ++ float sf0; float sf1; float sf2; float sf3; ++} p; ++ ++#include "types.comp" ++ ++layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; ++ ++layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; ++layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; ++ ++void main() { ++ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; ++ ++ if (idx >= p.ne) { ++ return; ++ } ++ ++ const uint i10 = idx % p.ne10; ++ const uint i11 = (idx / p.ne10) % p.ne11; ++ const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; ++ const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; ++ ++ const uint i00 = uint(i10 / p.sf0); ++ const uint i01 = uint(i11 / p.sf1); ++ const uint i02 = uint(i12 / p.sf2); ++ const uint i03 = uint(i13 / p.sf3); ++ ++ data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +new file mode 100644 +index 
00000000..8111c063 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +@@ -0,0 +1,594 @@ ++ ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++#ifdef _WIN32 ++ #include ++ #include // For _mkdir on Windows ++ #include // For std::replace on w64devkit ++#else ++ #include ++ #include ++ #include ++#endif ++ ++#include ++ ++#define ASYNCIO_CONCURRENCY 64 ++ ++std::mutex lock; ++std::vector> shader_fnames; ++ ++std::string GLSLC = "glslc"; ++std::string input_dir = "vulkan-shaders"; ++std::string output_dir = "/tmp"; ++std::string target_hpp = "ggml-vulkan-shaders.hpp"; ++std::string target_cpp = "ggml-vulkan-shaders.cpp"; ++bool no_clean = false; ++ ++const std::vector type_names = { ++ "f32", ++ "f16", ++ "q4_0", ++ "q4_1", ++ "q5_0", ++ "q5_1", ++ "q8_0", ++ "q2_k", ++ "q3_k", ++ "q4_k", ++ "q5_k", ++ "q6_k", ++ "iq4_nl" ++}; ++ ++namespace { ++void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { ++#ifdef _WIN32 ++ HANDLE stdout_read, stdout_write; ++ HANDLE stderr_read, stderr_write; ++ SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; ++ ++ if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) || ++ !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) { ++ throw std::runtime_error("Failed to create stdout pipe"); ++ } ++ ++ if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) || ++ !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) { ++ throw std::runtime_error("Failed to create stderr pipe"); ++ } ++ ++ PROCESS_INFORMATION pi; ++ STARTUPINFOA si = {}; ++ si.cb = sizeof(STARTUPINFOA); ++ si.dwFlags = STARTF_USESTDHANDLES; ++ si.hStdOutput = stdout_write; ++ si.hStdError = stderr_write; ++ ++ std::vector cmd(command.begin(), command.end()); ++ cmd.push_back('\0'); ++ ++ if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) { ++ throw std::runtime_error("Failed to create process"); ++ } ++ ++ CloseHandle(stdout_write); ++ CloseHandle(stderr_write); ++ ++ std::array buffer; ++ DWORD bytes_read; ++ ++ while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { ++ stdout_str.append(buffer.data(), bytes_read); ++ } ++ ++ while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { ++ stderr_str.append(buffer.data(), bytes_read); ++ } ++ ++ CloseHandle(stdout_read); ++ CloseHandle(stderr_read); ++ WaitForSingleObject(pi.hProcess, INFINITE); ++ CloseHandle(pi.hProcess); ++ CloseHandle(pi.hThread); ++#else ++int stdout_pipe[2]; ++ int stderr_pipe[2]; ++ ++ if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { ++ throw std::runtime_error("Failed to create pipes"); ++ } ++ ++ pid_t pid = fork(); ++ if (pid < 0) { ++ throw std::runtime_error("Failed to fork process"); ++ } ++ ++ if (pid == 0) { ++ close(stdout_pipe[0]); ++ close(stderr_pipe[0]); ++ dup2(stdout_pipe[1], STDOUT_FILENO); ++ dup2(stderr_pipe[1], STDERR_FILENO); ++ close(stdout_pipe[1]); ++ close(stderr_pipe[1]); ++ execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr); ++ _exit(EXIT_FAILURE); ++ } else { ++ close(stdout_pipe[1]); ++ close(stderr_pipe[1]); ++ ++ std::array buffer; ++ ssize_t bytes_read; ++ ++ while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) { ++ 
stdout_str.append(buffer.data(), bytes_read); ++ } ++ ++ while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) { ++ stderr_str.append(buffer.data(), bytes_read); ++ } ++ ++ close(stdout_pipe[0]); ++ close(stderr_pipe[0]); ++ waitpid(pid, nullptr, 0); ++ } ++#endif ++} ++ ++bool directory_exists(const std::string& path) { ++ struct stat info; ++ if (stat(path.c_str(), &info) != 0) { ++ return false; // Path doesn't exist or can't be accessed ++ } ++ return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory ++} ++ ++bool create_directory(const std::string& path) { ++#ifdef _WIN32 ++ return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists ++#else ++ return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions ++#endif ++} ++ ++std::string to_uppercase(const std::string& input) { ++ std::string result = input; ++ for (char& c : result) { ++ c = std::toupper(c); ++ } ++ return result; ++} ++ ++bool string_ends_with(const std::string& str, const std::string& suffix) { ++ if (suffix.size() > str.size()) { ++ return false; ++ } ++ return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); ++} ++ ++static const char path_separator = '/'; ++ ++std::string join_paths(const std::string& path1, const std::string& path2) { ++ return path1 + path_separator + path2; ++} ++ ++std::string basename(const std::string &path) { ++ return path.substr(path.find_last_of("/\\") + 1); ++} ++ ++// variables to track number of compiles in progress ++static uint32_t compile_count = 0; ++static std::mutex compile_count_mutex; ++static std::condition_variable compile_count_cond; ++ ++void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { ++ std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); ++ std::string out_fname = join_paths(output_dir, name + ".spv"); ++ std::string in_path = join_paths(input_dir, in_fname); ++ ++ std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; ++ ++ // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 ++ std::string opt_level = coopmat ? 
"" : "-O"; ++ ++ #ifdef _WIN32 ++ std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; ++ #else ++ std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname}; ++ #endif ++ ++ #ifdef GGML_VULKAN_SHADER_DEBUG_INFO ++ cmd.push_back("-g"); ++ #endif ++ ++ for (const auto& define : defines) { ++ cmd.push_back("-D" + define.first + "=" + define.second); ++ } ++ ++ std::string command; ++ for (const auto& part : cmd) { ++ command += part + " "; ++ } ++ ++ std::string stdout_str, stderr_str; ++ try { ++ // std::cout << "Executing command: "; ++ // for (const auto& part : cmd) { ++ // std::cout << part << " "; ++ // } ++ // std::cout << std::endl; ++ ++ execute_command(command, stdout_str, stderr_str); ++ if (!stderr_str.empty()) { ++ std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl; ++ return; ++ } ++ ++ std::lock_guard guard(lock); ++ shader_fnames.push_back(std::make_pair(name, out_fname)); ++ } catch (const std::exception& e) { ++ std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; ++ } ++ { ++ std::lock_guard guard(compile_count_mutex); ++ assert(compile_count > 0); ++ compile_count--; ++ } ++ compile_count_cond.notify_all(); ++} ++ ++std::map merge_maps(const std::map& a, const std::map& b) { ++ std::map result = a; ++ result.insert(b.begin(), b.end()); ++ return result; ++} ++ ++static std::vector> compiles; ++void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { ++ { ++ // wait until fewer than N compiles are in progress. ++ // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. ++ uint32_t N = 16; ++ std::unique_lock guard(compile_count_mutex); ++ while (compile_count >= N) { ++ compile_count_cond.wait(guard); ++ } ++ compile_count++; ++ } ++ compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); ++} ++ ++void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) { ++ std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; ++ std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; ++ std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; ++ ++ std::map base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}}; ++ std::string shader_name = "matmul"; ++ ++ if (matmul_id) { ++ base_dict["MUL_MAT_ID"] = "1"; ++ shader_name = "matmul_id"; ++ } ++ ++ if (fp16) { ++ base_dict["FLOAT16"] = "1"; ++ } ++ ++ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; ++ ++ if (coopmat) { ++ base_dict["COOPMAT"] = "1"; ++ } ++ ++ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; ++ ++ std::string source_name = coopmat2 ? 
"mul_mm_cm2.comp" : "mul_mm.comp"; ++ ++ // Shaders with f16 B_TYPE ++ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); ++ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); ++ ++ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); ++ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); ++ ++ for (const auto& tname : type_names) { ++ std::string data_a_key = "DATA_A_" + to_uppercase(tname); ++ // For unaligned, load one at a time for f32/f16, or two at a time for quants ++ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2"; ++ // For aligned matmul loads ++ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; ++ ++ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); ++ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); ++ ++ if (tname != "f16" && tname != "f32") { ++ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); ++ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); ++ } ++ } ++} ++ ++void process_shaders() { ++ std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; ++ std::map base_dict = {{"FLOAT_TYPE", "float"}}; ++ ++ // matmul ++ for (const auto& matmul_id : {false, true}) { ++ // No coopmats ++ // fp32 ++ matmul_shaders(false, matmul_id, false, false, false); ++ ++ // fp16, fp32acc and fp16acc ++ matmul_shaders(true, matmul_id, false, false, false); ++ matmul_shaders(true, matmul_id, false, false, true); ++ ++ // Coopmat, fp32acc and fp16acc ++ matmul_shaders(true, matmul_id, true, false, false); ++ matmul_shaders(true, matmul_id, true, false, true); ++ ++#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) ++ // Coopmat2, fp32acc and fp16acc ++ matmul_shaders(true, matmul_id, false, true, false); ++ matmul_shaders(true, matmul_id, false, true, true); ++#endif ++ } ++ ++#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) ++ // flash attention ++ for (const auto& f16acc : {false, true}) { ++ std::string acctype = f16acc ? 
"float16_t" : "float"; ++ ++ for (const auto& tname : type_names) { ++ if (tname == "f32") { ++ continue; ++ } ++ ++ if (tname == "f16") { ++ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", ++ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); ++ } else { ++ std::string data_a_key = "DATA_A_" + to_uppercase(tname); ++ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", ++ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); ++ } ++ } ++ } ++#endif ++ ++ for (const auto& tname : type_names) { ++ // mul mat vec ++ std::string data_a_key = "DATA_A_" + to_uppercase(tname); ++ std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; ++ ++ string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); ++ string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); ++ ++ string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); ++ ++ // Dequant shaders ++ if (tname != "f16") { ++ string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); ++ } ++ ++ if (!string_ends_with(tname, "_k")) { ++ shader = (tname == "f32" || tname == "f16") ? 
"get_rows.comp" : "get_rows_quant.comp"; ++ ++ if (tname == "f16") { ++ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); ++ } else { ++ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); ++ } ++ string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); ++ } ++ } ++ ++ string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); ++ ++ // Norms ++ string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); ++ string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); ++ string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); ++ ++ string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); ++ string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); ++ string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); ++ string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); ++ ++ string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); ++ string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); ++ ++ string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); ++ ++ string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); ++ ++ string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); ++ ++ string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); ++ ++ string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ ++ string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); ++ ++ string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); ++ ++ string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); ++ ++ string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); ++ ++ string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); ++ ++ string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ ++ string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("concat_f16", "concat.comp", 
{{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); ++ string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); ++ ++ string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); ++ ++ string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ ++ string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ ++ string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); ++ string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); ++ ++ string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); ++ string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); ++ ++ string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); ++ string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); ++ string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); ++ ++ string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); ++ ++ string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); ++ ++ string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); ++ string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); ++ string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); ++ ++ string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); ++ ++ string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); ++ ++ string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); ++ ++ for (auto &c : compiles) { ++ c.wait(); ++ } ++} ++ ++void write_output_files() { ++ FILE* hdr = fopen(target_hpp.c_str(), "w"); ++ FILE* src = fopen(target_cpp.c_str(), "w"); ++ ++ fprintf(hdr, "#include \n\n"); ++ fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); ++ ++ for (const auto& pair : shader_fnames) { ++ const std::string& name = pair.first; ++ #ifdef _WIN32 ++ std::string path = pair.second; ++ std::replace(path.begin(), path.end(), '/', '\\' ); ++ #else ++ const std::string& path = pair.second; ++ #endif ++ ++ FILE* spv = fopen(path.c_str(), "rb"); ++ if (!spv) { ++ std::cerr << "Error opening SPIR-V file: " << path << " (" << 
strerror(errno) << ")\n"; ++ continue; ++ } ++ ++ fseek(spv, 0, SEEK_END); ++ size_t size = ftell(spv); ++ fseek(spv, 0, SEEK_SET); ++ ++ std::vector data(size); ++ size_t read_size = fread(data.data(), 1, size, spv); ++ fclose(spv); ++ if (read_size != size) { ++ std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; ++ continue; ++ } ++ ++ fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size); ++ fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size); ++ ++ fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size); ++ for (size_t i = 0; i < size; ++i) { ++ fprintf(src, "0x%02x,", data[i]); ++ if ((i + 1) % 12 == 0) fprintf(src, "\n"); ++ } ++ fprintf(src, "\n};\n\n"); ++ ++ if (!no_clean) { ++ std::remove(path.c_str()); ++ } ++ } ++ ++ fclose(hdr); ++ fclose(src); ++} ++} ++ ++int main(int argc, char** argv) { ++ std::map args; ++ for (int i = 1; i < argc; ++i) { ++ std::string arg = argv[i]; ++ if (arg.rfind("--", 0) == 0) { ++ if (i + 1 < argc && argv[i + 1][0] != '-') { ++ args[arg] = argv[i + 1]; ++ ++i; ++ } else { ++ args[arg] = ""; ++ } ++ } ++ } ++ ++ if (args.find("--glslc") != args.end()) { ++ GLSLC = args["--glslc"]; // Path to glslc ++ } ++ if (args.find("--input-dir") != args.end()) { ++ input_dir = args["--input-dir"]; // Directory containing shader sources ++ } ++ if (args.find("--output-dir") != args.end()) { ++ output_dir = args["--output-dir"]; // Directory for containing SPIR-V output ++ } ++ if (args.find("--target-hpp") != args.end()) { ++ target_hpp = args["--target-hpp"]; // Path to generated header file ++ } ++ if (args.find("--target-cpp") != args.end()) { ++ target_cpp = args["--target-cpp"]; // Path to generated cpp file ++ } ++ if (args.find("--no-clean") != args.end()) { ++ no_clean = true; // Keep temporary SPIR-V files in output-dir after build ++ } ++ ++ if (!directory_exists(input_dir)) { ++ std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl; ++ return EXIT_FAILURE; ++ } ++ ++ if (!directory_exists(output_dir)) { ++ if (!create_directory(output_dir)) { ++ std::cerr << "Error creating output directory: " << output_dir << "\n"; ++ return EXIT_FAILURE; ++ } ++ } ++ ++ process_shaders(); ++ ++ write_output_files(); ++ ++ return EXIT_SUCCESS; ++} +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +new file mode 100644 +index 00000000..35cc6c45 +--- /dev/null ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +@@ -0,0 +1,87 @@ ++#version 450 ++ ++#extension GL_EXT_control_flow_attributes : require ++ ++#define BLOCK_SIZE 64 ++layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; ++ ++layout(push_constant) uniform Parameters { ++ uint B; ++ uint T; ++ uint C; ++ uint H; ++}; ++ ++layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; }; ++layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; }; ++layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; }; ++layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; }; ++layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; }; ++layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; }; ++layout(binding = 6) buffer DstBuf { A_TYPE dst[]; }; ++ ++shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE]; ++ ++void main() { ++ const uint head_size = BLOCK_SIZE; ++ const uint batch_id = gl_WorkGroupID.x / H; ++ const uint 
head_id = gl_WorkGroupID.x % H; ++ const uint tid = gl_LocalInvocationID.x; ++ ++ const uint state_size = C * head_size; ++ const uint n_seq_tokens = T / B; ++ ++ if (batch_id >= B || head_id >= H) { ++ return; ++ } ++ ++ A_TYPE state[BLOCK_SIZE]; ++ [[unroll]] for (uint i = 0; i < head_size; i++) { ++ state[i] = state_in[batch_id * state_size + head_id * head_size * head_size ++ + i * head_size + tid]; ++ } ++ ++ barrier(); ++ _tf[tid] = tf[head_id * head_size + tid]; ++ barrier(); ++ ++ const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; ++ const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; ++ ++ for (uint t = start_t; t < end_t; t += C) { ++ barrier(); ++ _k[tid] = k[t]; ++ _r[tid] = r[t]; ++ _td[tid] = td[t]; ++ barrier(); ++ ++ const A_TYPE v_val = v[t]; ++ A_TYPE y = 0.0; ++ ++ [[unroll]] for (uint j = 0; j < head_size; j += 4) { ++ vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); ++ vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); ++ vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); ++ vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); ++ vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); ++ ++ vec4 kv = k_vec * v_val; ++ ++ vec4 temp = tf_vec * kv + s_vec; ++ y += dot(r_vec, temp); ++ ++ s_vec = s_vec * td_vec + kv; ++ state[j] = s_vec.x; ++ state[j+1] = s_vec.y; ++ state[j+2] = s_vec.z; ++ state[j+3] = s_vec.w; ++ } ++ ++ dst[t] = y; ++ } ++ ++ [[unroll]] for (uint i = 0; i < head_size; i++) { ++ dst[T * C + batch_id * state_size + head_id * head_size * head_size ++ + i * head_size + tid] = state[i]; ++ } ++} +-- +2.43.0 + From 98f699773aad00ba60c0cda4ca1a0ce7940d0a09 Mon Sep 17 00:00:00 2001 From: Vadim Grinco Date: Mon, 10 Mar 2025 12:34:37 +0100 Subject: [PATCH 025/172] Applied 00-fix-vulkan-building.patch Work done by McBane87 here: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871 Signed-off-by: Vadim Grinco --- CMakePresets.json | 13 +- discover/gpu.go | 7 +- .../ggml/ggml/src/ggml-vulkan/CMakeLists.txt | 92 + .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8745 +++++++++++++++++ .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 9 + .../src/ggml-vulkan/vulkan-shaders/acc.comp | 29 + .../src/ggml-vulkan/vulkan-shaders/add.comp | 29 + .../ggml-vulkan/vulkan-shaders/argsort.comp | 69 + .../src/ggml-vulkan/vulkan-shaders/clamp.comp | 17 + .../ggml-vulkan/vulkan-shaders/concat.comp | 41 + .../vulkan-shaders/contig_copy.comp | 42 + .../src/ggml-vulkan/vulkan-shaders/copy.comp | 20 + .../src/ggml-vulkan/vulkan-shaders/cos.comp | 17 + .../vulkan-shaders/dequant_f32.comp | 20 + .../vulkan-shaders/dequant_funcs.comp | 118 + .../vulkan-shaders/dequant_funcs_cm2.comp | 325 + .../vulkan-shaders/dequant_head.comp | 13 + .../vulkan-shaders/dequant_iq4_nl.comp | 32 + .../vulkan-shaders/dequant_q2_k.comp | 34 + .../vulkan-shaders/dequant_q3_k.comp | 42 + .../vulkan-shaders/dequant_q4_0.comp | 30 + .../vulkan-shaders/dequant_q4_1.comp | 32 + .../vulkan-shaders/dequant_q4_k.comp | 68 + .../vulkan-shaders/dequant_q5_0.comp | 34 + .../vulkan-shaders/dequant_q5_1.comp | 35 + .../vulkan-shaders/dequant_q5_k.comp | 70 + .../vulkan-shaders/dequant_q6_k.comp | 33 + .../vulkan-shaders/dequant_q8_0.comp | 31 + .../vulkan-shaders/diag_mask_inf.comp | 34 + .../src/ggml-vulkan/vulkan-shaders/div.comp | 27 + .../vulkan-shaders/flash_attn_cm2.comp | 289 + .../src/ggml-vulkan/vulkan-shaders/gelu.comp | 25 + .../vulkan-shaders/gelu_quick.comp | 23 + .../vulkan-shaders/generic_binary_head.comp | 64 
+ .../vulkan-shaders/generic_head.comp | 9 + .../vulkan-shaders/generic_unary_head.comp | 56 + .../ggml-vulkan/vulkan-shaders/get_rows.comp | 28 + .../vulkan-shaders/get_rows_quant.comp | 39 + .../vulkan-shaders/group_norm.comp | 66 + .../ggml-vulkan/vulkan-shaders/im2col.comp | 87 + .../vulkan-shaders/leaky_relu.comp | 22 + .../src/ggml-vulkan/vulkan-shaders/mul.comp | 27 + .../mul_mat_split_k_reduce.comp | 48 + .../vulkan-shaders/mul_mat_vec.comp | 152 + .../vulkan-shaders/mul_mat_vec_base.comp | 118 + .../vulkan-shaders/mul_mat_vec_nc.comp | 71 + .../vulkan-shaders/mul_mat_vec_p021.comp | 73 + .../vulkan-shaders/mul_mat_vec_q2_k.comp | 115 + .../vulkan-shaders/mul_mat_vec_q3_k.comp | 103 + .../vulkan-shaders/mul_mat_vec_q4_k.comp | 133 + .../vulkan-shaders/mul_mat_vec_q5_k.comp | 162 + .../vulkan-shaders/mul_mat_vec_q6_k.comp | 112 + .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 631 ++ .../vulkan-shaders/mul_mm_cm2.comp | 328 + .../src/ggml-vulkan/vulkan-shaders/norm.comp | 44 + .../src/ggml-vulkan/vulkan-shaders/pad.comp | 28 + .../ggml-vulkan/vulkan-shaders/pool2d.comp | 74 + .../src/ggml-vulkan/vulkan-shaders/relu.comp | 21 + .../ggml-vulkan/vulkan-shaders/repeat.comp | 26 + .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 42 + .../ggml-vulkan/vulkan-shaders/rope_head.comp | 49 + .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 37 + .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 37 + .../src/ggml-vulkan/vulkan-shaders/scale.comp | 24 + .../src/ggml-vulkan/vulkan-shaders/silu.comp | 22 + .../src/ggml-vulkan/vulkan-shaders/sin.comp | 17 + .../ggml-vulkan/vulkan-shaders/soft_max.comp | 174 + .../ggml-vulkan/vulkan-shaders/square.comp | 17 + .../ggml-vulkan/vulkan-shaders/sum_rows.comp | 37 + .../src/ggml-vulkan/vulkan-shaders/tanh.comp | 20 + .../vulkan-shaders/test_coopmat2_support.comp | 7 + .../vulkan-shaders/timestep_embedding.comp | 41 + .../src/ggml-vulkan/vulkan-shaders/types.comp | 323 + .../ggml-vulkan/vulkan-shaders/upscale.comp | 36 + .../vulkan-shaders/vulkan-shaders-gen.cpp | 594 ++ .../src/ggml-vulkan/vulkan-shaders/wkv6.comp | 87 + 76 files changed, 14642 insertions(+), 4 deletions(-) create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp create mode 100644 
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp create mode 100644 
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp diff --git a/CMakePresets.json b/CMakePresets.json index 442cb2a6d..09e924011 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -58,7 +58,11 @@ "cacheVariables": { "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" } - } + }, + { + "name": "Vulkan", + "inherits": [ "Default" ] + } ], "buildPresets": [ { @@ -105,6 +109,11 @@ "name": "ROCm 6", "inherits": [ "ROCm" ], "configurePreset": "ROCm 6" - } + }, + { + "name": "Vulkan", + "targets": [ "ggml-vulkan" ], + "configurePreset": "Vulkan" + } ] } diff --git a/discover/gpu.go b/discover/gpu.go index 791d6b199..2494469a7 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -196,7 +196,10 @@ func initVulkanHandles() *vulkanHandles { libcapPaths := FindLibCapLibs() if len(vulkanPaths) > 0 && len(libcapPaths) > 0 { + slog.Info("vulkan: load libvulkan and libcap ok") vHandles.deviceCount, vHandles.vulkan, vulkanLibPath, libcapLibPath = LoadVulkanMgmt(vulkanPaths, libcapPaths) + } else { + slog.Info("vulkan: failed to load libvulkan or libcap") } return vHandles @@ -425,7 +428,7 @@ func GetGPUInfo() GpuInfoList { gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) gpuInfo.MinimumMemory = 0 - gpuInfo.DependencyPath = depPaths + gpuInfo.DependencyPath = []string{LibOllamaPath} gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) gpuInfo.DriverMajor = int(memInfo.major) gpuInfo.DriverMinor = int(memInfo.minor) @@ -767,7 +770,7 @@ func LoadVulkanMgmt(vulkanLibPaths []string, capLibPaths []string) (int, *C.vk_h C.vk_init(vkLib, capLib, &resp) if resp.err != nil { - slog.Debug("Unable to load vulkan", "library", vkLibPath, capLibPath, "error", C.GoString(resp.err)) + slog.Error("Unable to load vulkan", "library", vkLibPath, capLibPath, "error", C.GoString(resp.err)) C.free(unsafe.Pointer(resp.err)) } else { return int(resp.num_devices), &resp.ch, vkLibPath, capLibPath diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt new file mode 100644 index 000000000..9501de736 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt @@ -0,0 +1,92 @@ +find_package(Vulkan COMPONENTS glslc REQUIRED) + +if (Vulkan_FOUND) + message(STATUS "Vulkan found") + + ggml_add_backend_library(ggml-vulkan + ggml-vulkan.cpp + ../../include/ggml-vulkan.h + ) + + # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. + # If it's not, there will be an error to stderr. + # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") + message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") + else() + message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + endif() + + target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) + target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + + # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build + # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector + if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) + endif() + + if (GGML_VULKAN_CHECK_RESULTS) + add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) + endif() + + if (GGML_VULKAN_DEBUG) + add_compile_definitions(GGML_VULKAN_DEBUG) + endif() + + if (GGML_VULKAN_MEMORY_DEBUG) + add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG) + endif() + + if (GGML_VULKAN_SHADER_DEBUG_INFO) + add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + endif() + + if (GGML_VULKAN_PERF) + add_compile_definitions(GGML_VULKAN_PERF) + endif() + + if (GGML_VULKAN_VALIDATE) + add_compile_definitions(GGML_VULKAN_VALIDATE) + endif() + + if (GGML_VULKAN_RUN_TESTS) + add_compile_definitions(GGML_VULKAN_RUN_TESTS) + endif() + + add_subdirectory(vulkan-shaders) + + set (_ggml_vk_genshaders_cmd vulkan-shaders-gen) + set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) + set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) + set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) + set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) + + file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") + + add_custom_command( + OUTPUT ${_ggml_vk_header} + ${_ggml_vk_source} + + COMMAND "$/${_ggml_vk_genshaders_cmd}" + --glslc ${Vulkan_GLSLC_EXECUTABLE} + --input-dir ${_ggml_vk_input_dir} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + --target-cpp ${_ggml_vk_source} + --no-clean + + DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd} + COMMENT "Generate vulkan shaders" + ) + + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header}) + +else() + message(WARNING "Vulkan not found") +endif() diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp new file mode 100644 index 000000000..d75cd6d61 --- /dev/null +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -0,0 +1,8745 @@ +#include "ggml-vulkan.h" +#include +#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS) +#include +#include "ggml-cpu.h" +#endif + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-vulkan-shaders.hpp" + +#define VK_API_VERSION VK_API_VERSION_1_2 + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +#define VK_VENDOR_ID_AMD 0x1002 +#define VK_VENDOR_ID_APPLE 0x106b +#define VK_VENDOR_ID_INTEL 0x8086 +#define VK_VENDOR_ID_NVIDIA 0x10de + +#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32 + +#define GGML_VK_MAX_NODES 8192 + +#define MAX_VK_BUFFERS 256 + +#define VK_CHECK(err, msg) \ + do { \ + vk::Result err_ = (err); \ + if (err_ != vk::Result::eSuccess) { \ + fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \ + #err, to_string(err_).c_str(), __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + +#ifdef GGML_VULKAN_DEBUG +#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl +#else +#define VK_LOG_DEBUG(msg) ((void) 0) +#endif // GGML_VULKAN_DEBUG + +struct ggml_backend_vk_context; + +struct vk_queue { + uint32_t queue_family_index; + vk::Queue queue; + vk::CommandPool pool; + uint32_t cmd_buffer_idx; + std::vector cmd_buffers; + + vk::PipelineStageFlags stage_flags; + + bool transfer_only; +}; + +struct vk_pipeline_struct { + std::string name; + vk::ShaderModule shader_module; + vk::DescriptorSetLayout dsl; + std::vector descriptor_pools; + std::vector descriptor_sets; + uint32_t descriptor_set_idx; + vk::PipelineLayout layout; + vk::Pipeline pipeline; + uint32_t push_constant_size; + uint32_t parameter_count; + std::array wg_denoms; + uint32_t align; +}; + +typedef std::shared_ptr vk_pipeline; +typedef std::weak_ptr vk_pipeline_ref; + +static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); + +struct vk_matmul_pipeline_struct { + vk_pipeline l, m, s; + vk_pipeline a_l, a_m, a_s; +}; + +typedef std::shared_ptr vk_matmul_pipeline; + +struct vk_matmul_pipeline2 { + vk_matmul_pipeline2() { + f16acc = std::make_shared(); + f32acc = std::make_shared(); + } + vk_matmul_pipeline f32acc; + vk_matmul_pipeline f16acc; +}; + +struct vk_device_struct; +typedef std::shared_ptr vk_device; +typedef std::weak_ptr vk_device_ref; + +struct vk_buffer_struct; +typedef std::shared_ptr vk_buffer; +typedef std::weak_ptr vk_buffer_ref; + +struct ggml_backend_vk_buffer_type_context { + std::string name; + vk_device device; +}; + +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); +static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { + /* .get_name = */ ggml_backend_vk_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size, + /* .get_alloc_size = 
*/ ggml_backend_vk_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +#ifdef GGML_VULKAN_MEMORY_DEBUG +class vk_memory_logger; +#endif +#ifdef GGML_VULKAN_PERF +class vk_perf_logger; +#endif +static void ggml_vk_destroy_buffer(vk_buffer& buf); + +static constexpr uint32_t mul_mat_vec_max_cols = 8; + +struct vk_device_struct { + std::mutex mutex; + + vk::PhysicalDevice physical_device; + vk::PhysicalDeviceProperties properties; + std::string name; + uint64_t max_memory_allocation_size; + bool fp16; + bool pipeline_robustness; + vk::Device device; + uint32_t vendor_id; + vk_queue compute_queue; + vk_queue transfer_queue; + bool single_queue; + uint32_t subgroup_size; + uint32_t shader_core_count; + bool uma; + bool float_controls_rte_fp16; + + bool subgroup_size_control; + uint32_t subgroup_min_size; + uint32_t subgroup_max_size; + bool subgroup_require_full_support; + + bool coopmat_support; + bool coopmat_acc_f32_support; + bool coopmat_acc_f16_support; + uint32_t coopmat_m; + uint32_t coopmat_n; + uint32_t coopmat_k; + bool coopmat2; + + size_t idx; + + bool mul_mat_l; + bool mul_mat_m; + bool mul_mat_s; + bool mul_mat_id_l; + bool mul_mat_id_m; + bool mul_mat_id_s; + + vk_matmul_pipeline pipeline_matmul_f32; + vk_matmul_pipeline pipeline_matmul_f32_f16; + vk_matmul_pipeline2 pipeline_matmul_f16; + vk_matmul_pipeline2 pipeline_matmul_f16_f32; + vk_pipeline pipeline_matmul_split_k_reduce; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; + + vk_matmul_pipeline pipeline_matmul_id_f32; + vk_matmul_pipeline2 pipeline_matmul_id_f16; + vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; + vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; + vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; + vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_acc_f32; + vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; + vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; + vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; + vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; + vk_pipeline pipeline_upscale_f32; + vk_pipeline pipeline_scale_f32; + vk_pipeline pipeline_sqr_f32; + vk_pipeline pipeline_sin_f32; + vk_pipeline pipeline_cos_f32; + vk_pipeline pipeline_clamp_f32; + vk_pipeline pipeline_pad_f32; + vk_pipeline pipeline_repeat_f32; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; + vk_pipeline pipeline_norm_f32; + vk_pipeline pipeline_group_norm_f32; + vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_gelu_f32; + vk_pipeline pipeline_gelu_quick_f32; + vk_pipeline pipeline_silu_f32; + vk_pipeline pipeline_relu_f32; + vk_pipeline pipeline_leaky_relu_f32; + vk_pipeline pipeline_tanh_f32; + vk_pipeline pipeline_diag_mask_inf_f32; + vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; + vk_pipeline 
pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; + vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; + vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; + vk_pipeline pipeline_argsort_f32; + vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; + vk_pipeline pipeline_timestep_embedding_f32; + vk_pipeline pipeline_pool2d_f32; + vk_pipeline pipeline_rwkv_wkv6_f32; + + // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + + std::unordered_map pipelines; + std::unordered_map pipeline_descriptor_set_requirements; + + std::vector> pinned_memory; + + vk::Fence fence; + vk_buffer sync_staging; + + ggml_backend_buffer_type buffer_type; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + std::unique_ptr memory_logger; +#endif +#ifdef GGML_VULKAN_PERF + std::unique_ptr perf_logger; +#endif + + ~vk_device_struct() { + VK_LOG_DEBUG("destroy device " << name); + + device.destroyFence(fence); + + ggml_vk_destroy_buffer(sync_staging); + + device.destroyCommandPool(compute_queue.pool); + if (!single_queue) { + device.destroyCommandPool(transfer_queue.pool); + } + + for (auto& pipeline : pipelines) { + if (pipeline.second.expired()) { + continue; + } + + vk_pipeline pl = pipeline.second.lock(); + ggml_vk_destroy_pipeline(device, pl); + } + pipelines.clear(); + + device.destroy(); + } +}; + +struct vk_buffer_struct { + vk::Buffer buffer = VK_NULL_HANDLE; + vk::DeviceMemory device_memory = VK_NULL_HANDLE; + vk::MemoryPropertyFlags memory_property_flags; + void * ptr; + size_t size = 0; + + vk_device device; + + ~vk_buffer_struct() { + if (size == 0) { + return; + } + VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")"); + + device->device.freeMemory(device_memory); + device->device.destroyBuffer(buffer); + } +}; + +struct vk_subbuffer { + vk_buffer buffer; + uint64_t offset; + uint64_t size; + + operator vk::DescriptorBufferInfo() const { + return { buffer->buffer, offset, size }; + } +}; + +struct vk_semaphore { + vk::Semaphore s; + uint64_t value; +}; + +struct vk_submission { + vk::CommandBuffer buffer; + std::vector wait_semaphores; + std::vector signal_semaphores; +}; + +typedef std::vector vk_sequence; + +struct vk_mat_mat_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t k_split; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; +}; +struct vk_mat_vec_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; +}; + +struct vk_mat_mat_id_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; +}; +struct 
vk_mat_vec_id_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t nei0; uint32_t ne11; +}; + +struct vk_flash_attn_push_constants { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb02; + uint32_t nb03; + uint32_t nb12; + uint32_t nb13; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; +}; + +struct vk_op_push_constants { + uint32_t KX; + uint32_t KY; + float param1; + float param2; +}; + +struct vk_op_unary_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + float param1; float param2; + uint32_t ne0_012mp; uint32_t ne0_012L; + uint32_t ne0_01mp; uint32_t ne0_01L; + uint32_t ne0_0mp; uint32_t ne0_0L; + uint32_t ne1_012mp; uint32_t ne1_012L; + uint32_t ne1_01mp; uint32_t ne1_01L; + uint32_t ne1_0mp; uint32_t ne1_0L; +}; +static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) +{ + // compute L = ceil(log2(d)); + L = 0; + while (L < 32 && (uint32_t{1} << L) < d) { + L++; + } + + mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1); +} + +template void init_pushconst_fastdiv(T &p) { + GGML_UNUSED(p); + static_assert(!std::is_const::value, "unexpected type"); +} + +template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) { + // Compute magic values to divide by these six numbers. 
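+    // Example: for d = 6, L = ceil(log2(6)) = 3 and
+    // mp = floor(2^32 * (2^3 - 6) / 6) + 1 = 0x55555556;
+    // then n = 100 gives mulhi(100, mp) = 33 and (33 + 100) >> 3 = 16 = 100 / 6.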
+ init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L); + init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L); + init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L); + init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L); + init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L); + init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L); +} + +struct vk_op_binary_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23; + uint32_t misalign_offsets; + float param1; float param2; int32_t param3; +}; + +struct vk_op_diag_mask_push_constants { + uint32_t ncols; + uint32_t rows_per_channel; + int32_t n_past; +}; + +struct vk_op_rope_push_constants { + uint32_t ncols; + uint32_t n_dims; + float freq_scale; + uint32_t p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint32_t has_ff; +}; + +struct vk_op_soft_max_push_constants { + uint32_t KX; + uint32_t KY; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; + uint32_t nrows_x; +}; + +struct vk_op_argsort_push_constants { + uint32_t ncols; + uint32_t ncols_pad; + int32_t order; +}; + +struct vk_op_im2col_push_constants { + uint32_t batch_offset; uint32_t offset_delta; + uint32_t IC; + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t KW; uint32_t KH; + uint32_t pelements; + uint32_t CHW; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; + int32_t d0; int32_t d1; +}; + +struct vk_op_timestep_embedding_push_constants { + uint32_t nb1; + uint32_t dim; + uint32_t max_period; +}; + +struct vk_op_pool2d_push_constants { + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t OC; + uint32_t pelements; + uint32_t op; + int32_t k0; int32_t k1; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; +}; + +struct vk_op_rwkv_wkv6_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; + +// Allow pre-recording command buffers +struct vk_staging_memcpy { + vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} + + void * dst; + const void * src; + size_t n; +}; + +struct vk_op_upscale_push_constants { + uint32_t ne; uint32_t a_offset; uint32_t d_offset; + uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; + float sf0; float sf1; float sf2; float sf3; +}; + +struct vk_context_struct { + vk_submission * s; + std::vector seqs; + + int exit_tensor_idx; + + std::vector in_memcpys; + std::vector out_memcpys; + + vk_queue * q; +}; +typedef std::shared_ptr vk_context; +typedef std::weak_ptr vk_context_ref; + +struct ggml_vk_garbage_collector { + std::vector tl_semaphores; + std::vector semaphores; + std::vector events; + std::vector temp_buffers; + std::vector contexts; +}; + +#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG) +#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl + +static std::string format_size(size_t size) { + const size_t kib = 1024; + const size_t mib = kib * 1024; + const size_t gib = mib * 1024; + + std::ostringstream oss; + oss << std::fixed << std::setprecision(2); + + if (size >= gib) 
{ + oss << static_cast(size) / gib << " GiB"; + } else if (size >= mib) { + oss << static_cast(size) / mib << " MiB"; + } else if (size >= kib) { + oss << static_cast(size) / kib << " KiB"; + } else { + oss << size << " B"; + } + + return oss.str(); +} + +static std::mutex log_mutex; + +class vk_memory_logger { +public: + vk_memory_logger(): total_device(0), total_host(0) {} + void log_allocation(vk_buffer_ref buf_ref, size_t size); + void log_deallocation(vk_buffer_ref buf_ref); + +private: + std::map allocations; // Track allocations + size_t total_device; + size_t total_host; +}; +#else +#define VK_LOG_MEMORY(msg) ((void) 0) +#endif // GGML_VULKAN_MEMORY_DEBUG + +#if defined(GGML_VULKAN_PERF) + +class vk_perf_logger { +public: + void print_timings() { + std::cerr << "----------------\nVulkan Timings:" << std::endl; + for (const auto& t : timings) { + uint64_t total = 0; + for (const auto& time : t.second) { + total += time; + } + std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl; + } + + timings.clear(); + } + + void log_timing(const ggml_tensor * node, uint64_t time) { + if (node->op == GGML_OP_UNARY) { + timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); + return; + } + if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { + const uint64_t m = node->src[0]->ne[1]; + const uint64_t n = node->src[1]->ne[1]; + const uint64_t k = node->src[1]->ne[0]; + std::string name = ggml_op_name(node->op); + if (n == 1) { + name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k); + } else { + name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); + } + timings[name].push_back(time); + return; + } + timings[ggml_op_name(node->op)].push_back(time); + } +private: + std::map> timings; +}; +#endif // GGML_VULKAN_PERF + +struct ggml_backend_vk_context { + std::string name; + + vk_device device; + + size_t semaphore_idx, event_idx; + ggml_vk_garbage_collector gc; + size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; + vk_buffer prealloc_x, prealloc_y, prealloc_split_k; + vk::Fence fence; + + vk_buffer buffer_pool[MAX_VK_BUFFERS]; + + vk_context_ref compute_ctx; + vk_context_ref transfer_ctx; + + std::vector tensor_ctxs; +}; + +static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT + +static uint64_t vk_tensor_offset(const ggml_tensor * tensor) { + if (tensor->view_src) { + return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base; + } + return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; +} + +struct ggml_backend_vk_buffer_context { + vk_device_ref device; + vk_buffer dev_buffer; + std::string name; + + ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) : + device(device), + dev_buffer(dev_buffer), + name(name) { + } + + ~ggml_backend_vk_buffer_context() { + ggml_vk_destroy_buffer(dev_buffer); + } +}; + +#ifdef GGML_VULKAN_MEMORY_DEBUG +void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { + std::lock_guard guard(log_mutex); + vk_buffer buf = buf_ref.lock(); + const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); + const std::string type = device ? "device" : "host"; + allocations[buf->buffer] = size; + total_device += device ? size : 0; + total_host += device ? 0 : size; + VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". 
Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); +} + +void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { + if (buf_ref.expired() || buf_ref.lock()->size == 0) { + return; + } + + std::lock_guard guard(log_mutex); + vk_buffer buf = buf_ref.lock(); + const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); + std::string type = device ? "device" : "host"; + auto it = allocations.find(buf->buffer); + total_device -= device ? it->second : 0; + total_host -= device ? 0 : it->second; + if (it != allocations.end()) { + VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); + allocations.erase(it); + } else { + VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer); + } +} +#endif // GGML_VULKAN_MEMORY_DEBUG + +struct vk_instance_t { + vk::Instance instance; + + std::vector device_indices; + vk_device devices[GGML_VK_MAX_DEVICES]; +}; + +static bool vk_instance_initialized = false; +static vk_instance_t vk_instance; + +#ifdef GGML_VULKAN_CHECK_RESULTS +static size_t vk_skip_checks; +static size_t vk_output_tensor; + +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); +static void ggml_vk_check_results_0(ggml_tensor * tensor); +static void ggml_vk_check_results_1(ggml_tensor * tensor); +#endif + +typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +static void ggml_backend_vk_free(ggml_backend_t backend); + +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; + +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, + uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector specialization_constants, + uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { + VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << + ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << + ", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); + GGML_ASSERT(parameter_count > 0); + GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT + + pipeline = std::make_shared(); + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + + vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); + + std::vector dsl_binding; + std::vector dsl_binding_flags; + for (uint32_t i = 0; i < parameter_count; i++) { + dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); + dsl_binding_flags.push_back({}); + } + + 
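+    // Each of the pipeline's parameter_count parameters is bound as a storage buffer
+    // (bindings 0..parameter_count-1) visible to the compute stage; the per-binding flags stay empty.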
vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; + + vk::PushConstantRange pcr( + vk::ShaderStageFlagBits::eCompute, + 0, + pipeline->push_constant_size + ); + + vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( + {}, + dsl_binding); + descriptor_set_layout_create_info.setPNext(&dslbfci); + pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); + + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); + + pipeline->descriptor_set_idx = 0; + + vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr); + pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info); + + std::vector specialization_entries(specialization_constants.size()); + + for (size_t i = 0; i < specialization_constants.size(); i++) { + specialization_entries[i].constantID = i; + specialization_entries[i].offset = i * sizeof(uint32_t); + specialization_entries[i].size = sizeof(uint32_t); + } + + vk::SpecializationInfo specialization_info( + specialization_entries.size(), + specialization_entries.data(), + specialization_constants.size() * sizeof(uint32_t), + specialization_constants.data() + ); + + vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{}; + + if (device->subgroup_require_full_support && require_full_subgroups) { + pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT; + } + + vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( + pipeline_shader_stage_create_flags, + vk::ShaderStageFlagBits::eCompute, + pipeline->shader_module, + entrypoint.c_str(), + &specialization_info); + + vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info; + pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size; + if (device->subgroup_size_control && required_subgroup_size > 0) { + GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size); + pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info); + } + + vk::ComputePipelineCreateInfo compute_pipeline_create_info( + vk::PipelineCreateFlags{}, + pipeline_shader_create_info, + pipeline->layout); + + vk::PipelineRobustnessCreateInfoEXT rci; + + if (device->pipeline_robustness && disable_robustness) { + rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + compute_pipeline_create_info.setPNext(&rci); + } + + pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + + { + std::lock_guard guard(device->mutex); + device->pipelines.insert({ pipeline->name, pipeline }); + } + + { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + + // "Progress bar" for shader compiles + static uint32_t total_compile_count = 0; + if ((total_compile_count++ % 10) == 0) { + std::cerr << "."; + } + } + compile_count_cond.notify_all(); +} + +static void 
ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")"); + for (auto& pool : pipeline->descriptor_pools) { + device.destroyDescriptorPool(pool); + } + pipeline->descriptor_pools.clear(); + pipeline->descriptor_sets.clear(); + pipeline->descriptor_set_idx = 0; + + device.destroyDescriptorSetLayout(pipeline->dsl); + + device.destroyPipelineLayout(pipeline->layout); + + device.destroyShaderModule(pipeline->shader_module); + + device.destroyPipeline(pipeline->pipeline); +} + +static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { + VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); + device->pipeline_descriptor_set_requirements[pipeline->name] += n; +} + +static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { + std::lock_guard guard(device->mutex); + + for (auto& pair : device->pipeline_descriptor_set_requirements) { + vk_pipeline pipeline = device->pipelines.at(pair.first).lock(); + const uint64_t n = pair.second; + + VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")"); + + if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) { + // Enough descriptors are available + continue; + } + + uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size(); + uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; + uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + while (to_alloc > 0) { + const uint32_t alloc_count = std::min(pool_remaining, to_alloc); + to_alloc -= alloc_count; + pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + if (pool_idx >= pipeline->descriptor_pools.size()) { + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); + } + + std::vector layouts(alloc_count); + for (uint32_t i = 0; i < alloc_count; i++) { + layouts[i] = pipeline->dsl; + } + vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data()); + std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); + pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end()); + + pool_idx++; + } + } +} + +static void ggml_pipeline_cleanup(vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")"); + pipeline->descriptor_set_idx = 0; +} + +static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) { + VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); + std::lock_guard guard(device->mutex); + + if (q.cmd_buffers.size() > q.cmd_buffer_idx) { + // Reuse command buffer + return q.cmd_buffers[q.cmd_buffer_idx++]; + } + + vk::CommandBufferAllocateInfo command_buffer_alloc_info( + q.pool, + vk::CommandBufferLevel::ePrimary, + 1); + const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); + auto buf = cmd_buffers.front(); + + q.cmd_buffers.push_back(buf); + q.cmd_buffer_idx++; + + return buf; +} + +static vk_submission 
ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector wait_semaphores, std::vector signal_semaphores) { + VK_LOG_DEBUG("ggml_vk_create_submission()"); + vk_submission s; + s.buffer = ggml_vk_create_cmd_buffer(device, q); + s.wait_semaphores = std::move(wait_semaphores); + s.signal_semaphores = std::move(signal_semaphores); + return s; +} + +static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { + if (ctx->seqs.empty()) { + if (fence) { + ctx->q->queue.submit({}, fence); + } + return; + } + VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")"); + + std::vector> tl_wait_vals; + std::vector> tl_signal_vals; + std::vector> tl_wait_semaphores; + std::vector> tl_signal_semaphores; + std::vector tl_submit_infos; + std::vector submit_infos; + int idx = -1; + std::vector> stage_flags; + + size_t reserve = 0; + + for (const auto& sequence : ctx->seqs) { + reserve += sequence.size(); + } + + // Pre-reserve vectors to prevent reallocation, which invalidates pointers + tl_wait_semaphores.reserve(reserve); + tl_wait_vals.reserve(reserve); + tl_signal_semaphores.reserve(reserve); + tl_signal_vals.reserve(reserve); + tl_submit_infos.reserve(reserve); + submit_infos.reserve(reserve); + stage_flags.reserve(reserve); + + for (const auto& sequence : ctx->seqs) { + for (const auto& submission : sequence) { + stage_flags.push_back({}); + idx++; + tl_wait_vals.push_back({}); + tl_wait_semaphores.push_back({}); + tl_signal_vals.push_back({}); + tl_signal_semaphores.push_back({}); + for (size_t i = 0; i < submission.wait_semaphores.size(); i++) { + stage_flags[idx].push_back(ctx->q->stage_flags); + tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value); + tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s); + } + for (size_t i = 0; i < submission.signal_semaphores.size(); i++) { + tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value); + tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s); + } + tl_submit_infos.push_back({ + (uint32_t) submission.wait_semaphores.size(), + tl_wait_vals[idx].data(), + (uint32_t) submission.signal_semaphores.size(), + tl_signal_vals[idx].data(), + }); + tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo; + tl_submit_infos[idx].pNext = nullptr; + vk::SubmitInfo si{ + (uint32_t) submission.wait_semaphores.size(), + tl_wait_semaphores[idx].data(), + stage_flags[idx].data(), + 1, + &submission.buffer, + (uint32_t) submission.signal_semaphores.size(), + tl_signal_semaphores[idx].data(), + }; + si.setPNext(&tl_submit_infos[idx]); + submit_infos.push_back(si); + } + } + + ctx->q->queue.submit(submit_infos, fence); + + ctx->seqs.clear(); +} + +static uint32_t ggml_vk_find_queue_family_index(std::vector& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) { + VK_LOG_DEBUG("ggml_vk_find_queue_family_index()"); + const uint32_t qfsize = queue_family_props.size(); + + // Try with avoid preferences first + for (uint32_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) { + return i; + } + } + + // Fall back to only required + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) { + 
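+            // The first pass found no family without the 'avoid' capabilities, so accept this one.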
return i; + } + } + + // Fall back to reusing compute queue + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) { + return i; + } + } + + // Fall back to ignoring min_num_queries + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueFlags & required) { + return i; + } + } + + // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations. + // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional. + if (compute_index >= 0) { + return compute_index; + } + + std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl; + + for(auto &q_family : queue_family_props) { + std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl; + } + abort(); +} + +static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) { + VK_LOG_DEBUG("ggml_vk_create_queue()"); + std::lock_guard guard(device->mutex); + + q.queue_family_index = queue_family_index; + q.transfer_only = transfer_only; + + vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index); + q.pool = device->device.createCommandPool(command_pool_create_info_compute); + + q.cmd_buffer_idx = 0; + + q.queue = device->device.getQueue(queue_family_index, queue_index); + + q.stage_flags = stage_flags; +} + +static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); + ctx->gc.contexts.emplace_back(result); + result->q = &q; + return result; +} + +static vk_context ggml_vk_create_temporary_context(vk_queue& q) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); + result->q = &q; + return result; +} + +static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); + ctx->gc.semaphores.push_back({ semaphore, 0 }); + return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1]; +} + +static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); + if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) { + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); + ctx->gc.tl_semaphores.push_back({ semaphore, 0 }); + } + return &ctx->gc.tl_semaphores[ctx->semaphore_idx++]; +} + +static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { + if (ctx->event_idx >= ctx->gc.events.size()) { + ctx->gc.events.push_back(ctx->device->device.createEvent({})); + } + return ctx->gc.events[ctx->event_idx++]; +} + +static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) { 
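+    // Recycle this queue's command buffers: the pool is reset in place and cmd_buffer_idx rewound,
+    // so ggml_vk_create_cmd_buffer can hand the same command buffers out again.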
+ VK_LOG_DEBUG("ggml_vk_queue_cleanup()"); + std::lock_guard guard(device->mutex); + + // Requires command buffers to be done + device->device.resetCommandPool(q.pool); + q.cmd_buffer_idx = 0; +} + +static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { + for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { + vk::MemoryType memory_type = mem_props->memoryTypes[i]; + if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) && + (flags & memory_type.propertyFlags) == flags && + mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) { + return static_cast(i); + } + } + return UINT32_MAX; +} + +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { + VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")"); + if (size > device->max_memory_allocation_size) { + throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); + } + + std::lock_guard guard(device->mutex); + + vk_buffer buf = std::make_shared(); + + if (size == 0) { + buf->size = 0; + return buf; + } + + vk::BufferCreateInfo buffer_create_info{ + vk::BufferCreateFlags(), + size, + vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst, + vk::SharingMode::eExclusive, + 0, + nullptr, + }; + + buf->buffer = device->device.createBuffer(buffer_create_info); + + vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); + + vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); + + uint32_t memory_type_index = UINT32_MAX; + + memory_type_index = find_properties(&mem_props, &mem_req, req_flags); + buf->memory_property_flags = req_flags; + + if (memory_type_index == UINT32_MAX && fallback_flags) { + memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); + buf->memory_property_flags = fallback_flags; + } + + if (memory_type_index == UINT32_MAX) { + device->device.destroyBuffer(buf->buffer); + throw vk::OutOfDeviceMemoryError("No suitable memory type found"); + } + + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); + } catch (const vk::SystemError& e) { + if (buf->memory_property_flags != fallback_flags) { + // Try again with fallback flags + memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); + buf->memory_property_flags = fallback_flags; + + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); + } + catch (const vk::SystemError& e) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } else { + // Out of Host/Device memory, clean up buffer + device->device.destroyBuffer(buf->buffer); + throw e; + } + } + buf->ptr = nullptr; + + if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + } + + device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); + + buf->device = device; + buf->size = size; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + device->memory_logger->log_allocation(buf, size); +#endif + + return buf; +} + +static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags 
req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { + try { + return ggml_vk_create_buffer(device, size, req_flags, fallback_flags); + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } +} + +static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { + vk_buffer buf; + try { + if (device->uma) { + // Fall back to host memory type + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + } else { + // use rebar if available, otherwise fallback to device only visible memory + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); + } + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } + + return buf; +} + +static void ggml_vk_destroy_buffer(vk_buffer& buf) { + if (buf == nullptr) { + return; + } + +#ifdef GGML_VULKAN_MEMORY_DEBUG + if (buf->device != nullptr) { + buf->device->memory_logger->log_deallocation(buf); + } +#endif + + buf.reset(); +} + +static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { + return { buf, 0, VK_WHOLE_SIZE }; +} + +static void ggml_vk_sync_buffers(vk_context& ctx) { + VK_LOG_DEBUG("ggml_vk_sync_buffers()"); + + const bool transfer_queue = ctx->q->transfer_only; + + ctx->s->buffer.pipelineBarrier( + ctx->q->stage_flags, + ctx->q->stage_flags, + {}, + { { + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) } + } }, + {}, + {} + ); +} + +static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events) { + VK_LOG_DEBUG("ggml_vk_wait_events()"); + if (events.empty()) { + return; + } + + ctx->s->buffer.waitEvents( + events, + ctx->q->stage_flags, + ctx->q->stage_flags, + {}, + {}, + {} + ); +} + +// number of rows/cols for flash attention shader +static constexpr uint32_t flash_attention_num_small_rows = 32; +static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { + GGML_UNUSED(clamp); + + // small rows, large cols + if (small_rows) { + return {flash_attention_num_small_rows, 128}; + } + // small cols to reduce register count + if (ggml_is_quantized(type) || D == 256) { + return {64, 32}; + } + return {64, 64}; +}; + +static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id) { + // Needs to be kept up to date on shader changes + const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; + const uint32_t type_size = device->fp16 ? 
sizeof(ggml_fp16_t) : sizeof(float); + const uint32_t warps = warptile[0] / warptile[10]; + + const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; + const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; + const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; + + return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize; +} + +static void ggml_vk_load_shaders(vk_device& device) { + VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); + + std::cerr << "ggml_vulkan: Compiling shaders"; + + // some shaders have a minimum subgroup size + const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); + const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); + + // mulmat + std::vector l_warptile, m_warptile, s_warptile, + l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, + l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, + l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; + std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, + l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, + l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, + l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; + + uint32_t l_align, m_align, s_align; + if (device->coopmat2) { + // spec constants and tile sizes for non-quant matmul/matmul_id + l_warptile = { 256, 128, 256, 64 }; + m_warptile = { 256, 128, 128, 64 }; + s_warptile = { 128, 64, 64, 64 }; + l_wg_denoms = {128, 256, 1 }; + m_wg_denoms = {128, 128, 1 }; + s_wg_denoms = { 64, 64, 1 }; + + // spec constants and tile sizes for quant matmul (non-Qi_K) + l_warptile_mmq = { 256, 128, 256, 64 }; + m_warptile_mmq = { 256, 128, 128, 64 }; + s_warptile_mmq = { 256, 128, 128, 64 }; + l_mmq_wg_denoms = { 128, 256, 1 }; + m_mmq_wg_denoms = { 128, 128, 1 }; + s_mmq_wg_denoms = { 128, 128, 1 }; + + // spec constants and tile sizes for quant matmul (Qi_K) + l_warptile_mmq_k = { 256, 128, 512, 16 }; + m_warptile_mmq_k = { 256, 128, 256, 16 }; + s_warptile_mmq_k = { 256, 32, 128, 64 }; + l_mmq_wg_denoms_k = { 128, 512, 1 }; + m_mmq_wg_denoms_k = { 128, 256, 1 }; + s_mmq_wg_denoms_k = { 32, 128, 1 }; + + // spec constants and tile sizes for quant matmul_id + l_warptile_mmqid = { 256, 128, 128, 16 }; + m_warptile_mmqid = { 256, 128, 64, 16 }; + s_warptile_mmqid = { 256, 64, 64, 16 }; + l_mmqid_wg_denoms = { 128, 128, 1 }; + m_mmqid_wg_denoms = { 128, 64, 1 }; + s_mmqid_wg_denoms = { 64, 64, 1 }; + + l_align = 128; + m_align = 64; + s_align = 32; + } else { + // Matrix cores require different warp group sizes + const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2; + const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4; + const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_s = device->coopmat_support ? 
device->coopmat_k : 1; + + l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; + m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; + + l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; + m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; + + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; + m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; + s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; + l_align = 128; + m_align = 64; + s_align = 32; + + // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders + // and tile sizes, this should handle 16KB, 32KB, and 48KB+. + // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders. + // But the numbers happen to work out for 32KB shared memory size that when using the medium + // size there's enough room for everything, and we assert for this. + uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); + if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { + l_warptile = m_warptile; + l_wg_denoms = m_wg_denoms; + shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); + GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); + } + if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { + // assert mul_mat_mat_id shaders will fit. + GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); + } + + shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); + if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { + if (device->properties.limits.maxComputeSharedMemorySize == 32768) { + l_warptile_mmq = m_warptile_mmq; + l_mmq_wg_denoms = m_mmq_wg_denoms; + } else { + l_warptile_mmq = s_warptile_mmq; + l_mmq_wg_denoms = s_mmq_wg_denoms; + } + shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); + GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); + } + if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { + // assert mul_mat_mat_id shaders will fit. + GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); + } + // Disable medium and large matrix multiplication if not enough shared memory is available + // Check mmq warptiles as the largest configuration + // Throw an error if not enough for any matrix multiplication is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) { + std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." 
<< std::endl; + throw std::runtime_error("Shared memory size too small for matrix multiplication."); + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) { + device->mul_mat_m = false; + device->mul_mat_l = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) { + device->mul_mat_l = false; + } + + // Disable mul_mat_id if not enough shared memory is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) { + device->mul_mat_id_s = false; + device->mul_mat_id_m = false; + device->mul_mat_id_l = false; + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) { + device->mul_mat_id_m = false; + device->mul_mat_id_l = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) { + device->mul_mat_id_l = false; + } + } + + device->pipeline_matmul_f32 = std::make_shared(); + device->pipeline_matmul_f32_f16 = std::make_shared(); + + device->pipeline_matmul_id_f32 = std::make_shared(); + + std::vector> compiles; + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, + uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, + uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + { + // wait until fewer than N compiles are in progress + uint32_t N = std::max(1u, std::thread::hardware_concurrency()); + std::unique_lock guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, + parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size)); + }; + +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + + auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; + }; + + auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + uint32_t wg_size = (small_rows && (D % 32) == 0) ? 
256 : 128; + auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); + return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; + }; + +#define CREATE_FA2(TYPE, NAMELC, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ + +#define CREATE_FA(TYPE, NAMELC) \ + CREATE_FA2(TYPE, NAMELC, 64) \ + CREATE_FA2(TYPE, NAMELC, 80) \ + CREATE_FA2(TYPE, NAMELC, 96) \ + CREATE_FA2(TYPE, NAMELC, 112) \ + CREATE_FA2(TYPE, NAMELC, 128) \ + CREATE_FA2(TYPE, NAMELC, 256) + + CREATE_FA(GGML_TYPE_F16, f16) + CREATE_FA(GGML_TYPE_Q4_0, q4_0) + CREATE_FA(GGML_TYPE_Q4_1, q4_1) + CREATE_FA(GGML_TYPE_Q5_0, q5_0) + 
CREATE_FA(GGML_TYPE_Q5_1, q5_1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0) + // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently + //CREATE_FA(GGML_TYPE_Q2_K, q2_k) + //CREATE_FA(GGML_TYPE_Q3_K, q3_k) + //CREATE_FA(GGML_TYPE_Q4_K, q4_k) + //CREATE_FA(GGML_TYPE_Q5_K, q5_k) + //CREATE_FA(GGML_TYPE_Q6_K, q6_k) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) +#undef CREATE_FA + + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) + + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, 
, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) +#undef CREATE_MM +#undef CREATE_MM2 + } else +#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat_support) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->coopmat_acc_f16_support) { \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + if (device->coopmat_acc_f32_support) { \ + CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + if (device->coopmat_acc_f16_support) { + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } else { + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, 
matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } + + // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. + if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + if (device->coopmat_acc_f16_support) { + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } else { + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + 
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } + } +#undef CREATE_MM2 +#undef CREATE_MM + } else if (device->fp16) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. 
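+    // Editor's note (added for clarity, not part of the upstream change): the row_ids buffer
+    // referred to here is the 3072-entry uint32_t scratch accounted for as mmid_row_ids in
+    // ggml_vk_matmul_shmem_support above, i.e. 3072 * 4 bytes = 12 KiB, the same figure used
+    // by the GGML_ASSERT(shmem_needed + 3072*4 <= maxComputeSharedMemorySize) checks earlier
+    // in this function.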
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM2 +#undef CREATE_MM + } else { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + if (device->mul_mat 
## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. 
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM + } + + // mul mat vec + + // the number of rows computed per shader depends on GPU model and quant + uint32_t rm_stdq = 1; + uint32_t rm_kq = 2; + if (device->vendor_id == VK_VENDOR_ID_AMD) { + if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN + rm_stdq = 2; + rm_kq = 4; + } + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) + rm_stdq = 2; + + for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, 
true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], 
"mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + } + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, 
mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); + + // dequant shaders + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + + // get_rows + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + 
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + 
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + 
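+    // Editor's note, an illustrative sketch rather than part of the upstream patch: the
+    // {N, 1, 1} triples passed to ggml_vk_create_pipeline above are the pipeline's workgroup
+    // denominators (wg_denoms). At dispatch time the op's element counts are divided by them,
+    // rounding up, to obtain the workgroup counts. The hypothetical lambda below exists only
+    // to demonstrate that relationship; nothing in this function calls it.
+    [[maybe_unused]] auto sketch_dispatch_size = [](const std::array<uint32_t, 3>& elements,
+                                                    const std::array<uint32_t, 3>& wg_denoms) -> std::array<uint32_t, 3> {
+        auto ceil_div = [](uint32_t a, uint32_t b) { return (a + b - 1) / b; };
+        // e.g. a 1,000,000-element unary op with wg_denoms {512, 1, 1} dispatches {1954, 1, 1} workgroups
+        return { ceil_div(elements[0], wg_denoms[0]),
+                 ceil_div(elements[1], wg_denoms[1]),
+                 ceil_div(elements[2], wg_denoms[2]) };
+    };
+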
ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + + ggml_vk_create_pipeline(device, 
device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } + + ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + } + + ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + + for (auto &c : compiles) { + c.wait(); + } + std::cerr << "Done!" 
<< std::endl; +} + +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); + +static vk_device ggml_vk_get_device(size_t idx) { + VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); + + if (vk_instance.devices[idx] == nullptr) { + VK_LOG_DEBUG("Initializing new vk_device"); + vk_device device = std::make_shared(); + vk_instance.devices[idx] = device; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + device->memory_logger = std::unique_ptr(new vk_memory_logger()); +#endif +#ifdef GGML_VULKAN_PERF + device->perf_logger = std::unique_ptr(new vk_perf_logger()); +#endif + + size_t dev_num = vk_instance.device_indices[idx]; + + std::vector physical_devices = vk_instance.instance.enumeratePhysicalDevices(); + + if (dev_num >= physical_devices.size()) { + std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl; + throw std::runtime_error("Device not found"); + } + + device->physical_device = physical_devices[dev_num]; + const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + + bool fp16_storage = false; + bool fp16_compute = false; + bool maintenance4_support = false; + bool sm_builtins = false; + bool amd_shader_core_properties2 = false; + bool pipeline_robustness = false; + bool coopmat2_support = false; + device->coopmat_support = false; + + // Check if maintenance4 is supported + for (const auto& properties : ext_props) { + if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { + maintenance4_support = true; + } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { + fp16_storage = true; + } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { + fp16_compute = true; + } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { + sm_builtins = true; + } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) { + amd_shader_core_properties2 = true; + } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { + pipeline_robustness = true; + } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + device->subgroup_size_control = true; + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + device->coopmat_support = true; + device->coopmat_m = 0; + device->coopmat_n = 0; + device->coopmat_k = 0; + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; + } + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceMaintenance4Properties props4; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; + vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; + vk::PhysicalDeviceVulkan12Properties vk12_props; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + driver_props.pNext = &vk12_props; + + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; + + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&props4; + last_struct = (VkBaseOutStructure *)&props4; + } + if 
(sm_builtins) { + last_struct->pNext = (VkBaseOutStructure *)&sm_props; + last_struct = (VkBaseOutStructure *)&sm_props; + } + if (amd_shader_core_properties2) { + last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + } + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_props; + } + +#if defined(VK_NV_cooperative_matrix2) + vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props; + last_struct = (VkBaseOutStructure *)&coopmat2_props; + } +#endif + + device->physical_device.getProperties2(&props2); + device->properties = props2.properties; + + const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); + + if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { + device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); + } else if (maintenance4_support) { + device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); + } else { + device->max_memory_allocation_size = props3.maxMemoryAllocationSize; + } + + device->vendor_id = device->properties.vendorID; + device->subgroup_size = subgroup_props.subgroupSize; + device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + if (sm_builtins) { + device->shader_core_count = sm_props.shaderSMCount; + } else if (amd_shader_core_properties2) { + device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else { + device->shader_core_count = 0; + } + device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; + + device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + + if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) { + device->coopmat_support = false; + } + + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); + + // Try to find a non-graphics compute queue and transfer-focused queues + const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); + const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); + + const float priorities[] = { 1.0f, 1.0f }; + device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; + + std::vector device_queue_create_infos; + if (compute_queue_family_index != transfer_queue_family_index) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1}); + } else if(!device->single_queue) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities}); + } else { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); + } + vk::DeviceCreateInfo device_create_info; + 
std::vector device_extensions; + vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); + + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = nullptr; + device_features2.features = (VkPhysicalDeviceFeatures)device_features; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + VkPhysicalDeviceVulkan12Features vk12_features; + vk12_features.pNext = nullptr; + vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; + vk11_features.pNext = &vk12_features; + + last_struct = (VkBaseOutStructure *)&vk12_features; + + VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; + pl_robustness_features.pNext = nullptr; + pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; + pl_robustness_features.pipelineRobustness = VK_FALSE; + + if (pipeline_robustness) { + last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features; + last_struct = (VkBaseOutStructure *)&pl_robustness_features; + device_extensions.push_back("VK_EXT_pipeline_robustness"); + } + + VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; + subgroup_size_control_features.pNext = nullptr; + subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; + subgroup_size_control_features.computeFullSubgroups = false; + subgroup_size_control_features.subgroupSizeControl = false; + + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; + } + + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (device->coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } + +#if defined(VK_NV_cooperative_matrix2) + VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; + coopmat2_features.pNext = nullptr; + coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; + last_struct = (VkBaseOutStructure *)&coopmat2_features; + device_extensions.push_back("VK_NV_cooperative_matrix2"); + } +#endif + + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); + + device->fp16 = device->fp16 && vk12_features.shaderFloat16; + + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; + + if (device->subgroup_size_control) { + device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; + device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; + } + + device->subgroup_size_control = device->subgroup_size_control && + (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) && + subgroup_size_control_features.subgroupSizeControl; + + if (device->subgroup_size_control) { + device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; + 
device_extensions.push_back("VK_EXT_subgroup_size_control"); + } + + device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; + + if (coopmat2_support) { +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (coopmat2_features.cooperativeMatrixWorkgroupScope && + coopmat2_features.cooperativeMatrixFlexibleDimensions && + coopmat2_features.cooperativeMatrixReductions && + coopmat2_features.cooperativeMatrixConversions && + coopmat2_features.cooperativeMatrixPerElementOperations && + coopmat2_features.cooperativeMatrixTensorAddressing && + coopmat2_features.cooperativeMatrixBlockLoads && + vk12_features.bufferDeviceAddress) { + + std::vector flexible_dimensions; + uint32_t count = 0; + + PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV = + (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV) + vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV"); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr); + + VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {}; + empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV; + flexible_dimensions.resize(count, empty_prop); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data()); + + bool found_fp16_128 = false, + found_fp16_256 = false, + found_fp32_128 = false, + found_fp32_256 = false; + // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 + // with 32x16x16 and 256 with 32x32x16. + for (auto &prop : flexible_dimensions) { + if (prop.saturatingAccumulation == VK_FALSE && + prop.scope == VK_SCOPE_WORKGROUP_KHR && + prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } + } + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } + } + } + } + if (found_fp16_128 && found_fp16_256 && + found_fp32_128 && found_fp32_256 && + coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { + device->coopmat2 = true; + } + } +#endif + } + + if (!vk11_features.storageBuffer16BitAccess) { + std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." 
<< std::endl; + throw std::runtime_error("Unsupported device"); + } + + device_extensions.push_back("VK_KHR_16bit_storage"); + +#ifdef GGML_VULKAN_VALIDATE + device_extensions.push_back("VK_KHR_shader_non_semantic_info"); +#endif + + if (device->fp16) { + device_extensions.push_back("VK_KHR_shader_float16_int8"); + } + + if (device->coopmat_support) { + // Query supported shapes + std::vector cm_props; + + PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR = + (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"); + + uint32_t cm_props_num; + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr); + + cm_props.resize(cm_props_num); + + for (auto& prop : cm_props) { + prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; + } + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data()); + + VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size()); + + for (auto& prop : cm_props) { + VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope)); + + if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f32_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f32_support = true; + } + } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f16_support = true; + } + } + } + } + + if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { + // No suitable matmul mode found + GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. 
Disabling matrix cores.\n"); + device->coopmat_support = false; + } + } + + if (device->coopmat_support) { + device_extensions.push_back("VK_KHR_cooperative_matrix"); + } + + device->name = GGML_VK_NAME + std::to_string(idx); + + device_create_info = { + vk::DeviceCreateFlags(), + device_queue_create_infos, + {}, + device_extensions + }; + device_create_info.setPNext(&device_features2); + device->device = device->physical_device.createDevice(device_create_info); + + // Queues + ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false); + + // Shaders + // Disable matmul tile sizes early if performance low or not supported + switch (device->vendor_id) { +#ifndef GGML_VULKAN_RUN_TESTS + case VK_VENDOR_ID_AMD: + case VK_VENDOR_ID_INTEL: + device->mul_mat_l = false; + device->mul_mat_m = true; + device->mul_mat_s = true; + device->mul_mat_id_l = false; + device->mul_mat_id_m = true; + device->mul_mat_id_s = true; + break; + case VK_VENDOR_ID_APPLE: + device->mul_mat_l = false; + device->mul_mat_m = true; + device->mul_mat_s = false; + device->mul_mat_id_l = false; + device->mul_mat_id_m = true; + device->mul_mat_id_s = false; + break; +#endif + default: + device->mul_mat_l = true; + device->mul_mat_m = true; + device->mul_mat_s = true; + device->mul_mat_id_l = true; + device->mul_mat_id_m = true; + device->mul_mat_id_s = true; + break; + } + + ggml_vk_load_shaders(device); + + if (!device->single_queue) { + const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; + ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); + } else { + // TODO: Use pointer or reference to avoid copy + device->transfer_queue = device->compute_queue; + } + + device->buffer_type = { + /* .iface = */ ggml_backend_vk_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx), + /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device }, + }; + + device->fence = device->device.createFence({}); + + device->idx = idx; + + return device; + } + + return vk_instance.devices[idx]; +} + +static void ggml_vk_print_gpu_info(size_t idx) { + GGML_ASSERT(idx < vk_instance.device_indices.size()); + size_t dev_num = vk_instance.device_indices[idx]; + VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")"); + GGML_ASSERT(vk_instance_initialized); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + if (dev_num >= devices.size()) { + std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." 
<< std::endl; + throw std::runtime_error("Device not found"); + } + + vk::PhysicalDevice physical_device = devices[dev_num]; + std::vector ext_props = physical_device.enumerateDeviceExtensionProperties(); + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + physical_device.getProperties2(&props2); + + const size_t subgroup_size = subgroup_props.subgroupSize; + const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + + bool fp16_storage = false; + bool fp16_compute = false; + bool coopmat_support = false; + bool coopmat2_support = false; + + for (auto properties : ext_props) { + if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { + fp16_storage = true; + } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { + fp16_compute = true; + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + coopmat_support = true; +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; +#endif + } + } + + if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) { + coopmat_support = false; + } + + const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); + bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; + + bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + + vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures(); + + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = nullptr; + device_features2.features = (VkPhysicalDeviceFeatures)device_features; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + VkPhysicalDeviceVulkan12Features vk12_features; + vk12_features.pNext = nullptr; + vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; + vk11_features.pNext = &vk12_features; + + // Pointer to the last chain element + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; + + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } + + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); + + fp16 = fp16 && vk12_features.shaderFloat16; + + coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix; + + std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? 
"KHR_coopmat" : "none"; + + std::string device_name = props2.properties.deviceName.data(); + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str()); + + if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { + GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); + } +} + +static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); + +void ggml_vk_instance_init() { + if (vk_instance_initialized) { + return; + } + VK_LOG_DEBUG("ggml_vk_instance_init()"); + + vk_instance_initialized = true; + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; + + const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); + const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); +#ifdef __APPLE__ + const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); +#endif + + std::vector layers; + + if (validation_ext) { + layers.push_back("VK_LAYER_KHRONOS_validation"); + } + std::vector extensions; + if (validation_ext) { + extensions.push_back("VK_EXT_validation_features"); + } +#ifdef __APPLE__ + if (portability_enumeration_ext) { + extensions.push_back("VK_KHR_portability_enumeration"); + } +#endif + vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); +#ifdef __APPLE__ + if (portability_enumeration_ext) { + instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; + } +#endif + + std::vector features_enable; + vk::ValidationFeaturesEXT validation_features; + + if (validation_ext) { + features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; + validation_features = { + features_enable, + {}, + }; + validation_features.setPNext(nullptr); + instance_create_info.setPNext(&validation_features); + GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); + } + vk_instance.instance = vk::createInstance(instance_create_info); + + size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + + // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan + char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); + if (devices_env != nullptr) { + std::string devices(devices_env); + std::replace(devices.begin(), devices.end(), ',', ' '); + + std::stringstream ss(devices); + size_t tmp; + while (ss >> tmp) { + if(tmp >= num_available_devices) { + std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl; + throw std::runtime_error("Invalid Vulkan device index"); + } + vk_instance.device_indices.push_back(tmp); + } + } else { + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + // Make sure at least one device exists + if (devices.empty()) { + std::cerr << "ggml_vulkan: Error: No devices found." 
<< std::endl; + GGML_ABORT("fatal error"); + } + + // Default to using all dedicated GPUs + for (size_t i = 0; i < devices.size(); i++) { + vk::PhysicalDeviceProperties2 new_props; + vk::PhysicalDeviceDriverProperties new_driver; + vk::PhysicalDeviceIDProperties new_id; + new_props.pNext = &new_driver; + new_driver.pNext = &new_id; + devices[i].getProperties2(&new_props); + + if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) { + // Check if there are two physical devices corresponding to the same GPU + auto old_device = std::find_if( + vk_instance.device_indices.begin(), + vk_instance.device_indices.end(), + [&devices, &new_id](const size_t k){ + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceIDProperties old_id; + old_props.pNext = &old_id; + devices[k].getProperties2(&old_props); + return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); + } + ); + if (old_device == vk_instance.device_indices.end()) { + vk_instance.device_indices.push_back(i); + } else { + // There can be two physical devices corresponding to the same GPU if there are 2 different drivers + // This can cause error when splitting layers aross the devices, need to keep only 1 + VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID"); + + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceDriverProperties old_driver; + old_props.pNext = &old_driver; + devices[*old_device].getProperties2(&old_props); + + std::map driver_priorities {}; + int old_priority = std::numeric_limits::max(); + int new_priority = std::numeric_limits::max(); + + // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id + // Smaller number -> higher priority + switch (old_props.properties.vendorID) { + case VK_VENDOR_ID_AMD: + driver_priorities[vk::DriverId::eMesaRadv] = 1; + driver_priorities[vk::DriverId::eAmdOpenSource] = 2; + driver_priorities[vk::DriverId::eAmdProprietary] = 3; + break; + case VK_VENDOR_ID_INTEL: + driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1; + driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2; + break; + case VK_VENDOR_ID_NVIDIA: + driver_priorities[vk::DriverId::eNvidiaProprietary] = 1; +#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235 + driver_priorities[vk::DriverId::eMesaNvk] = 2; +#endif + break; + } + + if (driver_priorities.count(old_driver.driverID)) { + old_priority = driver_priorities[old_driver.driverID]; + } + if (driver_priorities.count(new_driver.driverID)) { + new_priority = driver_priorities[new_driver.driverID]; + } + + if (new_priority < old_priority) { + auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device); + vk_instance.device_indices.erase(r, vk_instance.device_indices.end()); + vk_instance.device_indices.push_back(i); + + VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName); + } + else { + VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl); + } + } + } + } + + // If no dedicated GPUs found, fall back to GPU 0 + if (vk_instance.device_indices.empty()) { + vk_instance.device_indices.push_back(0); + } + } + GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); + + for (size_t i = 
0; i < vk_instance.device_indices.size(); i++) { + ggml_vk_print_gpu_info(i); + } +} + +static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { + VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")"); + ggml_vk_instance_init(); + GGML_ASSERT(idx < vk_instance.device_indices.size()); + + ctx->name = GGML_VK_NAME + std::to_string(idx); + + ctx->device = ggml_vk_get_device(idx); + + ctx->semaphore_idx = 0; + ctx->event_idx = 0; + + ctx->prealloc_size_x = 0; + ctx->prealloc_size_y = 0; + ctx->prealloc_size_split_k = 0; + + ctx->fence = ctx->device->device.createFence({}); + +#ifdef GGML_VULKAN_CHECK_RESULTS + const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); + vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); + const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR"); + vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor)); +#endif +} + +static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) { + VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant[type]; +} + +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f32; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f32_f16; + } + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f32acc; + } + } + + if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { + return nullptr; + } + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + if (ctx->device->coopmat2) { + assert(src1_type == GGML_TYPE_F16); + return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc; + } + return ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); + GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); + + switch (a_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1]; +} + +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()"); + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f32; + } + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f32acc; + } + } + + GGML_ASSERT(src1_type == GGML_TYPE_F32); + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); + GGML_ASSERT(b_type == GGML_TYPE_F32); + + switch (a_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; +} + +static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { + VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")"); + VK_LOG_MEMORY("ggml_vk_pool_malloc"); + + int best_i = -1; + size_t best_size = std::numeric_limits::max(); //smallest unused buffer that fits our needs + int worst_i = -1; + size_t worst_size = 0; //largest unused buffer seen so far + for (int i = 0; i < MAX_VK_BUFFERS; ++i) { + vk_buffer &b = ctx->buffer_pool[i]; + if (b != nullptr && b->size >= size && b->size < best_size) { + best_i = i; + best_size = b->size; + } + if (b != nullptr && b->size > worst_size) { + worst_i = i; + worst_size = b->size; + } + } + if(best_i != -1) { + //found the smallest buffer that fits our needs + vk_buffer b = ctx->buffer_pool[best_i]; + ctx->buffer_pool[best_i].reset(); + return b; + } + if(worst_i != -1) { + //no buffer that fits our needs, resize largest one to save memory + vk_buffer& b = ctx->buffer_pool[worst_i]; + ggml_vk_destroy_buffer(b); + } + + return ggml_vk_create_buffer_device(ctx->device, size); +} + +static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) { + VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")"); + for (int i = 0; i < MAX_VK_BUFFERS; ++i) { + vk_buffer& b = ctx->buffer_pool[i]; + if (b == nullptr) { + b = buffer; + return; + } + } + std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl; + ggml_vk_destroy_buffer(buffer); +} + +// Returns an available temporary buffer that may only be used temporarily, it will be reused +static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) { + // Try to find existing temp buffer with enough capacity + for (auto& buffer : ctx->gc.temp_buffers) { + if (buffer->size >= size) { + return buffer; + } + } + + VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")"); + + // Otherwise create new buffer + vk_buffer buf = ggml_vk_pool_malloc(ctx, size); + ctx->gc.temp_buffers.push_back(buf); + + return buf; +} + +static void * ggml_vk_host_malloc(vk_device& device, size_t size) { + VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); + vk_buffer buf = ggml_vk_create_buffer(device, size, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + + if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) { + fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n", + size/1024.0/1024.0); + device->device.freeMemory(buf->device_memory); + device->device.destroyBuffer(buf->buffer); + return nullptr; + } + + 
device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); + + return buf->ptr; +} + +static void ggml_vk_host_free(vk_device& device, void* ptr) { + if (ptr == nullptr) { + return; + } + VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); + vk_buffer buf; + size_t index; + for (size_t i = 0; i < device->pinned_memory.size(); i++) { + const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); + const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); + if (ptr >= addr && ptr < endr) { + buf = std::get<2>(device->pinned_memory[i]); + index = i; + break; + } + } + if (buf == nullptr) { + fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n"); + return; + } + + ggml_vk_destroy_buffer(buf); + + device->pinned_memory.erase(device->pinned_memory.begin() + index); +} + +static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { + buf = nullptr; + buf_offset = 0; + for (size_t i = 0; i < device->pinned_memory.size(); i++) { + const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); + const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); + if (ptr >= addr && ptr < endr) { + buf = std::get<2>(device->pinned_memory[i]); + buf_offset = ((const uint8_t *)ptr) - addr; + break; + } + } +} + +static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) { + vk_submission s; + s.buffer = ggml_vk_create_cmd_buffer(device, q); + if (one_time) { + s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); + } else { + s.buffer.begin({ vk::CommandBufferUsageFlags{} }); + } + + return s; +} + + + +static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array elements) { + const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); + const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); + const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); + VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {"; + for (auto& buffer : descriptor_buffer_infos) { + std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; + } + std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); + GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size()); + GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count); + + vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++]; + vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; + ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); + + subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants); + subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); + subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, + pipeline->layout, + 0, + { descriptor_set }, + {}); + subctx->s->buffer.dispatch(wg0, wg1, wg2); +} + +static void ggml_vk_end_submission(vk_submission& s, std::vector wait_semaphores, std::vector signal_semaphores) { + s.buffer.end(); + + s.wait_semaphores = std::move(wait_semaphores); + s.signal_semaphores = 
std::move(signal_semaphores); +} + +static void ggml_vk_ctx_end(vk_context& ctx) { + VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")"); + if (ctx->s == nullptr) { + return; + } + + ctx->s->buffer.end(); + ctx->s = nullptr; +} + +static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { + VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")"); + if (subctx->s != nullptr) { + ggml_vk_ctx_end(subctx); + } + + subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) }); + subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); +} + +static size_t ggml_vk_align_size(size_t width, size_t align) { + VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")"); + return CEIL_DIV(width, align) * align; +} + +static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector* memcpys = nullptr) { + if (memcpys == nullptr) { + memcpy(dst, src, size); + } else { + memcpys->emplace_back(dst, src, size); + } +} + +static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) { + if (device->sync_staging == nullptr || device->sync_staging->size < size) { + VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")"); + ggml_vk_destroy_buffer(device->sync_staging); + device->sync_staging = ggml_vk_create_buffer_check(device, size, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + } +} + +static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")"); + GGML_ASSERT(!ggml_is_contiguous(tensor)); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." 
<< std::endl; + GGML_ABORT("fatal error"); + } + // Check if src is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset); + + const uint64_t ne0 = tensor->ne[0]; + const uint64_t ne1 = tensor->ne[1]; + const uint64_t ne2 = tensor->ne[2]; + const uint64_t ne3 = tensor->ne[3]; + const uint64_t nb0 = tensor->nb[0]; + const uint64_t nb1 = tensor->nb[1]; + const uint64_t nb2 = tensor->nb[2]; + const uint64_t nb3 = tensor->nb[3]; + const ggml_type type = tensor->type; + const uint64_t ts = ggml_type_size(type); + const uint64_t bs = ggml_blck_size(type); + + const uint64_t dstnb0 = ts; + const uint64_t dstnb1 = dstnb0*(ne0/bs); + const uint64_t dstnb2 = dstnb1*ne1; + const uint64_t dstnb3 = dstnb2*ne2; + + const uint64_t ne = ggml_nelements(tensor); + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + std::vector slices; + + for (uint64_t i3 = 0; i3 < ne3; i3++) { + for (uint64_t i2 = 0; i2 < ne2; i2++) { + // Find longest contiguous slice + if (ne1*nb1 == dstnb2) { + slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 }); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + if (ne0*nb0/bs == dstnb1) { + slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 }); + } else { + const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; + const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + for (uint64_t i0 = 0; i0 < ne0; i0++) { + slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); + } + } + } + } + } + } + + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + return; + } + + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + + // Staging buffer required + vk_buffer& staging = ctx->device->sync_staging; + const uint64_t copy_size = ts*ne/bs; + ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); + VkBufferCopy buf_copy{ 0, offset, copy_size }; + + ggml_vk_sync_buffers(subctx); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + + for (uint64_t i3 = 0; i3 < ne3; i3++) { + for (uint64_t i2 = 0; i2 < ne2; i2++) { + // Find longest contiguous slice + if (ne1*nb1 == dstnb2) { + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + if (ne0*nb0/bs == dstnb1) { + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys); + } else { + const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; + const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + for (uint64_t i0 = 0; i0 < ne0; i0++) { + deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys); + } + } + } + } + } + } +} + +static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + std::cerr << "ggml_vulkan: 
buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl; + GGML_ABORT("fatal error"); + } + // Check if src is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(dst->device, src, buf, buf_offset); + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + std::vector slices(1); + if (width == spitch) { + // Only do single write if stride is equal + slices[0].srcOffset = buf_offset; + slices[0].dstOffset = offset; + slices[0].size = width * height; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = buf_offset + i * spitch; + slices[i].dstOffset = offset + i * width; + slices[i].size = width; + } + } + + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + return; + } + VK_LOG_DEBUG("STAGING"); + + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + + // Staging buffer required + const size_t copy_size = width*height; + ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); + + vk_buffer& staging_buffer = dst->device->sync_staging; + + VkBufferCopy buf_copy = { + 0, + offset, + copy_size}; + + ggml_vk_sync_buffers(subctx); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + + if (width == spitch) { + deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); + } else { + for (size_t i = 0; i < height; i++) { + deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); + } + } +} + +static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); + return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); +} + +static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { + VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); + + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); + } + } else { + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + ggml_vk_ctx_begin(dst->device, subctx); + ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + ggml_vk_ctx_end(subctx); + + for (auto& cpy : subctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ggml_vk_submit(subctx, dst->device->fence); + VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); + dst->device->device.resetFences({ dst->device->fence }); + } +} + +static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); + ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); +} + +static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { + 
VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")"); + GGML_ASSERT(width > 0); + GGML_ASSERT(height > 0); + GGML_ASSERT(src != nullptr); + + // TODO: staging_offset is not used + + // Check if dst is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(src->device, dst, buf, buf_offset); + + std::vector slices(1); + if (width == spitch && width == dpitch) { + // Only do single write if stride is equal + slices[0].srcOffset = offset; + slices[0].dstOffset = buf_offset; + slices[0].size = width * height; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = offset + i * spitch; + slices[i].dstOffset = buf_offset + i * dpitch; + slices[i].size = width; + } + } + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); + + return; + } + VK_LOG_DEBUG("STAGING"); + + if (!sync_staging) { + GGML_ABORT("Asynchronous read from non-pinned memory not supported"); + } + + // Fall back to staging buffer + const size_t copy_size = dpitch * height; + ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); + + vk_buffer& staging_buffer = src->device->sync_staging; + + ggml_vk_sync_buffers(subctx); + subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); + + deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); +} + +static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) { + return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); +} + +static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); + + // If the device is not an UMA device the memory is host-accessible through rebar. While writing + // through PCIe is sufficient fast reading back data from PCIe is slower than going through + // the HW device to host copy path. 
+ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { + GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); + + memcpy(dst, (uint8_t *) src->ptr + offset, size); + } else { + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + ggml_vk_ctx_begin(src->device, subctx); + ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); + ggml_vk_ctx_end(subctx); + + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); + src->device->device.resetFences({ src->device->fence }); + + for (auto& cpy : subctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + } +} + +static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); + // Make sure both buffers are on same device + GGML_ASSERT(src->device == dst->device); + + VkBufferCopy bc{ src_offset, dst_offset, size }; + + vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); +} + +static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { + if (src->device == dst->device) { + VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); + // Copy within the device + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + ggml_vk_ctx_begin(src->device, subctx); + ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); + ggml_vk_ctx_end(subctx); + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); + src->device->device.resetFences({ src->device->fence }); + } else { + VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); + // Copy device to device + ggml_vk_ensure_sync_staging_buffer(src->device, size); + ggml_vk_ensure_sync_staging_buffer(dst->device, size); + + // Copy to src staging buffer + ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); + // memcpy to dst staging buffer + memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size); + // Copy to dst buffer + ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size); + } +} + +static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); + + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + ggml_vk_ctx_begin(dst->device, subctx); + subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); + ggml_vk_ctx_end(subctx); + + ggml_vk_submit(subctx, dst->device->fence); + VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); + dst->device->device.resetFences({ dst->device->fence }); +} + +static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); + + uint32_t split_k = 1; + if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) { + // If k is 'large' and the SMs will fill less than halfway, use split_k. 
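+        // m_tiles and n_tiles are the number of workgroup-sized tiles along m and n; split_k
+        // spreads the k dimension across shader cores that would otherwise sit idle.
+        // Illustrative numbers only: with 60 shader cores, m = n = 256 and 128x128 tiles,
+        // m_tiles * n_tiles = 4 < 30, so split_k = 60 / 4 = 15, which is clamped to 4 below.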
+ uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); + uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); + if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) { + split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); + // Clamp to 2 or 4 + split_k = std::min(split_k, 4u); + if (split_k == 3) { + split_k = 2; + } + } + } + + return split_k; +} + +static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); + + if (ctx->device->coopmat2) { + if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) { + return aligned ? mmp->a_l : mmp->l; + } + if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_l : mmp->l; +} + +static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align; +} + +static void ggml_vk_matmul( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) { + VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? 
split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")"); + ggml_vk_sync_buffers(subctx); + if (split_k == 1) { + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); + return; + } + + GGML_ASSERT(batch_stride_d == m * n); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 }; + // Make sure enough workgroups get assigned for split k to work + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_vk_sync_buffers(subctx); + const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); +} + +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); + + if (ctx->device->coopmat2) { + if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) { + return aligned ? mmp->a_l : mmp->l; + } + if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? 
mmp->a_l : mmp->l; +} + +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align; +} + +static void ggml_vk_matmul_id( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) { + VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << + "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << + "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << + "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); + ggml_vk_sync_buffers(subctx); + const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, + nei0, nei1, nbi1, ne11 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as }); +} + +static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { + + // Choose "contiguous copy" shader if src/dst are contiguous + bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); + + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_cpy_f32_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f16; + } else { + return ctx->device->pipeline_cpy_f32_f16; + } + } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f16; + } else { + return ctx->device->pipeline_cpy_f16_f16; + } + } + + std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; + GGML_ABORT("fatal error"); +} + +static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) { + VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << 
tensor->nb[3] << "), "; + std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); + const int tensor_type_size = ggml_type_size(tensor->type); + + const uint32_t ne = ggml_nelements(tensor); + std::array elements; + + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + vk_op_unary_push_constants pc = { + (uint32_t)ne, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]), + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + init_pushconst_fastdiv(pc); + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); +} + +static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src1); + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + + if (qx_needs_dequant) { + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const int x_ne = ne01 * ne00; + const int y_ne = ne11 * ne10; + const int d_ne = ne11 * ne01; + + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); + const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned); + + const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || + (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); + } + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + } + if (y_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, 
to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + // compute + ggml_vk_matmul( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + ne01, ne11, ne10, + ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, + split_k, ne12*ne13, ne02, ne12, r2, r3 + ); // NOLINT +} + +static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + // batch_n indicates that we need to compute a few vector results, and this assumes + // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides. 
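+    // When batch_n is set, the per-batch stride of A is forced to 0 further down, so every
+    // batch reuses the same matrix while B and D advance by one row per batch.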
+ GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1); + bool batch_n = ne11 > 1; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne11 * ne01; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride + uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01; + uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11); + uint32_t stride_batch_d = batch_n ? 
ne20 : (ne20*ne21); + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + uint32_t groups_x = ne01; + uint32_t groups_z = 1; + + if (ne01 > max_groups_x) { + groups_z = 64; + groups_x = CEIL_DIV(groups_x, groups_z); + } + + // compute + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + stride_batch_x, stride_batch_y, stride_batch_d, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, + sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); +} + +static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); + GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + // const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + // const uint64_t ne13 = src1->ne[3]; + + GGML_ASSERT(ne11 == 1); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src1_uma = d_Qy != nullptr; + } + + const uint64_t x_ne = ne00 * ne01 * ne02; + const uint64_t y_ne = ne10 * ne11 * ne12; + const uint64_t d_ne = ne01 * ne11 * ne12; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t d_sz = sizeof(float) * d_ne; + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_Qx = src0_buf_ctx->dev_buffer; + const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + + const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; + + const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; + + // compute + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); +} + +static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) 
{ + VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(!ggml_is_permuted(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + // const uint64_t ne03 = src0->ne[3]; + + const uint64_t nb01 = src0->nb[1]; + const uint64_t nb02 = src0->nb[2]; + + // const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + // const uint64_t ne13 = src1->ne[3]; + + GGML_ASSERT(ne11 == 1); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src1_uma = d_Qy != nullptr; + } + + const uint64_t d_ne = ne01 * ne11 * ne12; + + const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); + const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); + + const uint64_t qx_sz = ggml_nbytes(src0); + const uint64_t qy_sz = ggml_nbytes(src1); + const uint64_t d_sz = sizeof(float) * d_ne; + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_Qx = src0_buf_ctx->dev_buffer; + const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + + const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; + + const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t d_shader_offset = d_buf_offset - 
d_buffer_offset; + + // compute + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); +} + +static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); + if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && + // detect 0213 permutation, and batch size of 1 + src0->nb[0] <= src0->nb[2] && + src0->nb[2] <= src0->nb[1] && + src0->nb[1] <= src0->nb[3] && + src1->nb[0] <= src1->nb[2] && + src1->nb[2] <= src1->nb[1] && + src1->nb[1] <= src1->nb[3] && + src0->ne[3] == 1 && + src1->ne[3] == 1) { + ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun); + } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && + !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { + ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun); + // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) + // when ne12 and ne13 are one. + } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); + } else { + ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); + } +} + +static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); + 
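+    // ids holds the row/expert indices for the mul_mat_id path: src0 contains n_as (= ne02)
+    // expert matrices and ids provides nei0 x nei1 index entries (asserted to be at most 3072 below).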
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t nei0 = ids->ne[0]; + const uint64_t nei1 = ids->ne[1]; + GGML_ASSERT(nei0 * nei1 <= 3072); + + const uint32_t nbi1 = ids->nb[1]; + const uint32_t nbi2 = ids->nb[2]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t n_as = ne02; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + ids_uma = d_ids != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + + if (qx_needs_dequant) { + GGML_ABORT("fatal error"); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne21 * ne20; + + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1)); + const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned); + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + } + if (y_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / 
ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + // compute + ggml_vk_matmul_id( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, + ne01, ne21, ne10, ne10, ne10, ne01, + stride_batch_x, stride_batch_y, ne20*ne21, + n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11 + ); // NOLINT +} + +static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t nei0 = ids->ne[0]; + const uint64_t nei1 = ids->ne[1]; + + const uint64_t nbi2 = ids->nb[2]; + + GGML_ASSERT(nei1 == 1); + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + ids_uma = d_ids != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne21 * ne20; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || + (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if(!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + uint32_t groups_x = ne01; + uint32_t groups_z = 1; + + if (ne01 > max_groups_x) { + groups_z = 64; + groups_x = CEIL_DIV(groups_x, groups_z); + } + + // compute + const vk_mat_vec_id_push_constants pc = { + 
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), + (uint32_t)nei0, (uint32_t)ne11, + }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, + vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, + sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z }); +} + +static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")"); + if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + } else { + ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + } +} + +static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; + std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; + std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const uint32_t nem1 = mask ? mask->ne[1] : 0; + const uint32_t nbm1 = mask ? 
mask->nb[1] : 0; + + const uint32_t D = neq0; + const uint32_t N = neq1; + const uint32_t KV = nek1; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(nev1 == nek1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + assert(dst->type == GGML_TYPE_F32); + assert(q->type == GGML_TYPE_F32); + assert(k->type == v->type); + + vk_pipeline *pipelines; + // XXX TODO other backends may be changing accumulator precision to default to f32 soon + bool f32acc = dst->op_params[3] == GGML_PREC_F32; + bool small_rows = N <= flash_attention_num_small_rows; + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; + default: + assert(!"unsupported D value"); + return; + } + assert(pipelines); + + bool aligned = (KV % pipelines[1]->align) == 0; + vk_pipeline pipeline = pipelines[aligned]; + assert(pipeline); + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head_kv = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; + size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; + + bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); + ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset); + Q_uma = d_Q != nullptr; + K_uma = d_K != nullptr; + V_uma = d_V != nullptr; + D_uma = d_D != nullptr; + if (mask) { + ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset); + M_uma = d_M != nullptr; + } + } + + + ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * q_buf_ctx = 
(ggml_backend_vk_buffer_context *)q->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + + if (!Q_uma) { + d_Q = q_buf_ctx->dev_buffer; + q_buf_offset = vk_tensor_offset(q) + q->view_offs; + } + if (!K_uma) { + d_K = k_buf_ctx->dev_buffer; + k_buf_offset = vk_tensor_offset(k) + k->view_offs; + } + if (!V_uma) { + d_V = v_buf_ctx->dev_buffer; + v_buf_offset = vk_tensor_offset(v) + v->view_offs; + } + if (!D_uma) { + d_D = d_buf_ctx->dev_buffer; + d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + if (!M_uma) { + d_M = d_Q; + m_buf_offset = q_buf_offset; + if (mask) { + ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context; + d_M = m_buf_ctx->dev_buffer; + m_buf_offset = vk_tensor_offset(mask) + mask->view_offs; + } + } + + const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); +} + +static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { + switch (op) { + case GGML_OP_GET_ROWS: + GGML_ASSERT(src1->type == GGML_TYPE_I32); + if (dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_get_rows[src0->type]; + } + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_get_rows_f32[src0->type]; + } + return nullptr; + case GGML_OP_ACC: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_acc_f32; + } + return nullptr; + case GGML_OP_ADD: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; + } + return nullptr; + case GGML_OP_MUL: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; + } + return nullptr; + case GGML_OP_DIV: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32; + } + return nullptr; + case GGML_OP_CONCAT: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_concat_f32; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_concat_f16; + } + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_concat_i32; + } + return nullptr; + case GGML_OP_UPSCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_upscale_f32; + } + return nullptr; + case GGML_OP_SCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_scale_f32; + } + return nullptr; + case GGML_OP_SQR: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sqr_f32; + } + return nullptr; + case GGML_OP_SIN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sin_f32; + } + return nullptr; + case GGML_OP_COS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cos_f32; + } + return nullptr; + case GGML_OP_CLAMP: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_clamp_f32; + } + return nullptr; + case GGML_OP_PAD: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pad_f32; + } + return nullptr; + case GGML_OP_REPEAT: + if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { + return ctx->device->pipeline_repeat_f32; + } + return nullptr; + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); + case GGML_OP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_norm_f32; + } + return nullptr; + case GGML_OP_GROUP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_group_norm_f32; + } + return nullptr; + case GGML_OP_RMS_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rms_norm_f32; + } + return nullptr; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_SILU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_silu_f32; + } + break; + case GGML_UNARY_OP_GELU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_gelu_f32; + } + break; + case GGML_UNARY_OP_GELU_QUICK: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_gelu_quick_f32; + } + break; + case GGML_UNARY_OP_RELU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_relu_f32; + } + break; + case GGML_UNARY_OP_TANH: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_tanh_f32; + } + break; + default: + break; + } + return nullptr; + case GGML_OP_DIAG_MASK_INF: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_diag_mask_inf_f32; + } + return nullptr; + case GGML_OP_SOFT_MAX: + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + + if 
(src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; + } + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; + } + return nullptr; + case GGML_OP_ROPE: + { + const int mode = ((const int32_t *) dst->op_params)[2]; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + + if (is_neox) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_neox_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_neox_f16; + } + } else { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_norm_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_norm_f16; + } + } + return nullptr; + } + case GGML_OP_ARGSORT: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_argsort_f32; + } + return nullptr; + case GGML_OP_SUM_ROWS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sum_rows_f32; + } + return nullptr; + case GGML_OP_IM2COL: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_f32_f16; + } + return nullptr; + case GGML_OP_TIMESTEP_EMBEDDING: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_timestep_embedding_f32; + } + return nullptr; + case GGML_OP_POOL_2D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pool2d_f32; + } + return nullptr; + case GGML_OP_RWKV_WKV6: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv6_f32; + } + return nullptr; + case GGML_OP_LEAKY_RELU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_leaky_relu_f32; + } + return nullptr; + default: + return nullptr; + } + + GGML_UNUSED(src2); +} + +static bool ggml_vk_op_supports_incontiguous(ggml_op op) { + switch (op) { + case GGML_OP_CPY: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_REPEAT: + return true; + default: + return false; + } +} + +static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t) +{ + return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));; +} + +template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + GGML_UNUSED(p); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(dst); + static_assert(!std::is_const::value, "unexpected type"); + GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); + GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); + GGML_ASSERT(!src2 || 
get_misalign_bytes(ctx, src2) == 0); + GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0)); + + p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; + + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.a_offset = a_offset; + p.d_offset = d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template +static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + if (src1 != nullptr) { + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + } + if (src2 != nullptr) { + std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; + } + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? 
"dryrun" : "") << ")"); + GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT + GGML_ASSERT(dst->buffer != nullptr); + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + const uint64_t ne0 = ne00 * ne01; + + const bool use_src1 = src1 != nullptr; + const uint64_t ne10 = use_src1 ? src1->ne[0] : 0; + const uint64_t ne11 = use_src1 ? src1->ne[1] : 0; + const uint64_t ne12 = use_src1 ? src1->ne[2] : 0; + const uint64_t ne13 = use_src1 ? src1->ne[3] : 0; + const uint64_t ne1 = ne10 * ne11; + // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0; + + const bool use_src2 = src2 != nullptr; + const uint64_t ne20 = use_src2 ? src2->ne[0] : 0; + const uint64_t ne21 = use_src2 ? src2->ne[1] : 0; + const uint64_t ne22 = use_src2 ? src2->ne[2] : 0; + const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; + const uint64_t ne2 = ne20 * ne21; + + const uint64_t ned0 = dst->ne[0]; + const uint64_t ned1 = dst->ne[1]; + const uint64_t ned2 = dst->ne[2]; + const uint64_t ned3 = dst->ne[3]; + const uint64_t ned = ned0 * ned1; + + init_pushconst_fastdiv(pc); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); + + if (pipeline == nullptr) { + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type); + if (src1 != nullptr) { + std::cerr << " and " << ggml_type_name(src1->type); + } + std::cerr << " to " << ggml_type_name(dst->type) << std::endl; + GGML_ABORT("fatal error"); + } + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; + ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; + + vk_buffer d_X = nullptr; + size_t x_buf_offset = 0; + vk_buffer d_Y = nullptr; + size_t y_buf_offset = 0; + vk_buffer d_Z = nullptr; + size_t z_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool src2_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset); + src0_uma = d_X != nullptr; + if (use_src1) { + ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset); + src1_uma = d_Y != nullptr; + } + if (use_src2) { + ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset); + src2_uma = d_Z != nullptr; + } + } + + uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0; + uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0; + uint64_t z_sz = use_src2 ? 
ggml_type_size(src2->type) * ne2 : 0; + uint64_t d_sz = ggml_type_size(dst->type) * ned; + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + + // Workaround for tiny tensor inputs on ROPE + if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) { + y_sz = VK_WHOLE_SIZE; + } + + GGML_ASSERT(d_D != nullptr); + uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + if(!src0_uma) { + d_X = src0_buf_ctx->dev_buffer; + x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_X != nullptr); + } + if (use_src1 && !src1_uma) { + d_Y = src1_buf_ctx->dev_buffer; + y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Y != nullptr); + } + if (use_src2 && !src2_uma) { + d_Z = src2_buf_ctx->dev_buffer; + z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; + GGML_ASSERT(d_Z != nullptr); + } + // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets. + init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst); + x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + + if (op_supports_incontiguous) { + x_sz = ggml_nbytes(src0); + y_sz = use_src1 ? ggml_nbytes(src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) : 0; + d_sz = ggml_nbytes(dst); + + if (x_buf_offset + x_sz >= d_X->size) { + x_sz = VK_WHOLE_SIZE; + } + if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { + y_sz = VK_WHOLE_SIZE; + } + if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { + z_sz = VK_WHOLE_SIZE; + } + if (d_buf_offset + d_sz >= d_D->size) { + d_sz = VK_WHOLE_SIZE; + } + } + + std::array elements; + + // Single call if dimension 2 is contiguous + GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); + + switch (op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_SOFT_MAX: + case GGML_OP_SUM_ROWS: + { + const uint32_t nr = ggml_nrows(src0); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } break; + case GGML_OP_GROUP_NORM: + { + const uint32_t num_groups = dst->op_params[0]; + elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; + } break; + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_ROPE: + elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; + break; + case GGML_OP_GET_ROWS: + elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; + break; + case GGML_OP_ARGSORT: + elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; + break; + case GGML_OP_IM2COL: + { + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t batch = src1->ne[is_2D ? 
3 : 2]; + + elements = { OW * KW * KH, OH, batch * IC }; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + const uint32_t dim = dst->op_params[0]; + uint32_t half_ceil = (dim + 1) / 2; + elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; + } break; + case GGML_OP_POOL_2D: + { + const uint32_t N = dst->ne[3]; + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + elements = { N * OC * OH * OW, 1, 1}; + } break; + case GGML_OP_ADD: + case GGML_OP_DIV: + case GGML_OP_MUL: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_REPEAT: + case GGML_OP_CPY: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_UNARY: + { + const uint32_t ne = ggml_nelements(dst); + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + } break; + default: + elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; + break; + } + + if (!op_supports_incontiguous) { + if (x_sz != VK_WHOLE_SIZE) { + x_sz *= ne02 * ne03; + } + if (use_src1 && y_sz != VK_WHOLE_SIZE) { + y_sz *= ne12 * ne13; + } + if (use_src2 && z_sz != VK_WHOLE_SIZE) { + z_sz *= ne22 * ne23; + } + if (d_sz != VK_WHOLE_SIZE) { + d_sz *= ned2 * ned3; + } + } + + if (op == GGML_OP_SOFT_MAX) { + // Empty src1 is possible in soft_max, but the shader needs a buffer + vk_subbuffer subbuf_y; + if (use_src1) { + subbuf_y = { d_Y, y_buf_offset, y_sz }; + } else { + subbuf_y = { d_X, 0, x_sz }; + } + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (op == GGML_OP_ROPE) { + // Empty src2 is possible in rope, but the shader needs a buffer + vk_subbuffer subbuf_z; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; + } + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (op == GGML_OP_IM2COL) { + // im2col uses only src1 and dst buffers + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (use_src2) { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (use_src1) { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else { + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } +} + +static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const 
uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, offset, + }, dryrun); +} + +static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], 
(uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * k = dst->src[0]; + const ggml_tensor * v = dst->src[1]; + const ggml_tensor * r = dst->src[2]; + const ggml_tensor * tf = dst->src[3]; + const ggml_tensor * td = dst->src[4]; + const ggml_tensor * state = dst->src[5]; + + GGML_ASSERT(!ggml_is_quantized(k->type)); + GGML_ASSERT(!ggml_is_quantized(v->type)); + GGML_ASSERT(!ggml_is_quantized(r->type)); + GGML_ASSERT(!ggml_is_quantized(tf->type)); + GGML_ASSERT(!ggml_is_quantized(td->type)); + GGML_ASSERT(!ggml_is_quantized(state->type)); + GGML_ASSERT(dst->buffer != nullptr); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); + GGML_ASSERT(pipeline != 
nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; + ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; + ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; + ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr; + size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0; + bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, k->data, d_K, k_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_offset); + ggml_vk_host_get(ctx->device, r->data, d_R, r_offset); + ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset); + ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset); + ggml_vk_host_get(ctx->device, state->data, d_State, state_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); + + K_uma = d_K != nullptr; + V_uma = d_V != nullptr; + R_uma = d_R != nullptr; + TF_uma = d_TF != nullptr; + TD_uma = d_TD != nullptr; + STATE_uma = d_State != nullptr; + DST_uma = d_D != nullptr; + } + + if (!K_uma) { + d_K = k_buf_ctx->dev_buffer; + k_offset = vk_tensor_offset(k) + k->view_offs; + } + if (!V_uma) { + d_V = v_buf_ctx->dev_buffer; + v_offset = vk_tensor_offset(v) + v->view_offs; + } + if (!R_uma) { + d_R = r_buf_ctx->dev_buffer; + r_offset = vk_tensor_offset(r) + r->view_offs; + } + if (!TF_uma) { + d_TF = tf_buf_ctx->dev_buffer; + tf_offset = vk_tensor_offset(tf) + tf->view_offs; + } + if (!TD_uma) { + d_TD = td_buf_ctx->dev_buffer; + td_offset = vk_tensor_offset(td) + td->view_offs; + } + if (!STATE_uma) { + d_State = state_buf_ctx->dev_buffer; + state_offset = vk_tensor_offset(state) + state->view_offs; + } + if (!DST_uma) { + d_D = dst_buf_ctx->dev_buffer; + dst_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + const uint64_t k_size = ggml_nbytes(k); + const uint64_t v_size = ggml_nbytes(v); + const uint64_t r_size = ggml_nbytes(r); + const uint64_t tf_size = ggml_nbytes(tf); + const uint64_t td_size = ggml_nbytes(td); + const uint64_t state_size = ggml_nbytes(state); + const uint64_t dst_size = ggml_nbytes(dst); + + std::array elements = { + (uint32_t)(pc.B * pc.H), + 1, + 1 + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_K, k_offset, k_size }, + vk_subbuffer{ d_V, v_offset, v_size }, + vk_subbuffer{ d_R, r_offset, r_size }, + vk_subbuffer{ d_TF, tf_offset, tf_size }, + vk_subbuffer{ d_TD, td_offset, td_size }, + vk_subbuffer{ d_State, state_offset, state_size }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); +} + +static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = 
false) { + const size_t seq_length = dst->src[0]->ne[3]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[2]; + const size_t n_seqs = dst->src[5]->ne[1]; + + ggml_vk_op_f32_rwkv6( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + dryrun + ); +} + +static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + int * op_params = (int *)dst->op_params; + + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, op_params[0], + }, dryrun); +} + +static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + + const float sf0 = (float)dst->ne[0] / src0->ne[0]; + const float sf1 = (float)dst->ne[1] / src0->ne[1]; + const float sf2 = (float)dst->ne[2] / src0->ne[2]; + const float sf3 = (float)dst->ne[3] / src0->ne[3]; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { + (uint32_t)ggml_nelements(dst), 0, 0, + (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], + sf0, sf1, sf2, sf3, + }, dryrun); +} + +static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void 
ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) 
dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], op_params[1], + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + +static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + 
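+// The element-wise and normalization wrappers in this hunk all follow the same pattern: pack the op-specific push constants (tensor ne/nb values, with byte strides divided by ggml_type_size, plus any dst->op_params fields) + // and delegate to ggml_vk_op_f32(), which resolves the shader through ggml_vk_op_get_pipeline() and, on a dryrun pass, only requests descriptor sets instead of recording a dispatch.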
+static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const int * int_op_params = (const int *)dst->op_params; + const float * float_op_params = (const float *)dst->op_params; + + const uint32_t num_groups = int_op_params[0]; + const float eps = float_op_params[1]; + const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); +} + +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); +} + +static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + + float scale = op_params[0]; + float max_bias = op_params[1]; + + const uint32_t ncols = (uint32_t)src0->ne[0]; + const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); + const uint32_t nrows_y = (uint32_t)src0->ne[1]; + + const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + ncols, + src1 != nullptr ? 
nrows_y : (uint32_t)0, + scale, max_bias, + m0, m1, + n_head_log2, + nrows_x, + }, dryrun); +} + +static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const int n_dims = ((int32_t *) dst->op_params)[1]; + // const int mode = ((int32_t *) dst->op_params)[2]; + // const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + const float freq_base = ((float *) dst->op_params)[5]; + const float freq_scale = ((float *) dst->op_params)[6]; + const float ext_factor = ((float *) dst->op_params)[7]; + const float attn_factor = ((float *) dst->op_params)[8]; + const float beta_fast = ((float *) dst->op_params)[9]; + const float beta_slow = ((float *) dst->op_params)[10]; + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { + (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, + src2 != nullptr, + }, dryrun); +} + +static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; + + uint32_t ncols = src0->ne[0]; + + uint32_t ncols_pad = 1; + while (ncols_pad < ncols) { + ncols_pad *= 2; + } + + GGML_ASSERT(ncols_pad <= 1024); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { + ncols, + ncols_pad, + op_params[0], + }, dryrun); +} + +static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const int32_t s0 = dst->op_params[0]; + const int32_t s1 = dst->op_params[1]; + const int32_t p0 = dst->op_params[2]; + const int32_t p1 = dst->op_params[3]; + const int32_t d0 = dst->op_params[4]; + const int32_t d1 = dst->op_params[5]; + + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t IW = src1->ne[0]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const uint32_t batch_offset = src1->nb[is_2D ? 
3 : 2] / 4; // nb is byte offset, src is type float32 + + const uint32_t pelements = OW * KW * KH; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { + batch_offset, offset_delta, + IC, IW, IH, OW, OH, KW, KH, + pelements, + IC * KH * KW, + s0, s1, p0, p1, d0, d1, + }, dryrun); +} + +static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t dim = dst->op_params[0]; + const uint32_t max_period = dst->op_params[1]; + const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { + nb1, dim, max_period, + }, dryrun); +} + +static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + uint32_t op = static_cast(dst->op_params[0]); + const int32_t k1 = dst->op_params[1]; + const int32_t k0 = dst->op_params[2]; + const int32_t s1 = dst->op_params[3]; + const int32_t s0 = dst->op_params[4]; + const int32_t p1 = dst->op_params[5]; + const int32_t p0 = dst->op_params[6]; + + const uint32_t IH = src0->ne[1]; + const uint32_t IW = src0->ne[0]; + + const uint32_t N = dst->ne[3]; + + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + + const uint32_t parallel_elements = N * OC * OH * OW; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, { + IW, IH, OW, OH, OC, + parallel_elements, + op, + k0, k1, s0, s1, p0, p1, + }, dryrun); +} + +static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const float * op_params = (const float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); +} + +#ifdef GGML_VULKAN_RUN_TESTS +static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) { + if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) { + float val; + if (type == GGML_TYPE_F32) { + val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0); + } else if (type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0)); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +template +static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) { + VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")"); + const size_t x_ne = m * k * batch; + const size_t y_ne = k * n * batch; + const size_t d_ne = m * n * batch; + + vk_pipeline p; + std::string shname; + if (shader_size == 0) { + if (std::is_same() && std::is_same()) { + p 
= ctx->device->pipeline_matmul_f32->a_s; + shname = "F32_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_s; + shname = "F32_F16_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s; + shname = "F16_F32_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_s; + shname = "F16_ALIGNED_S"; + } else { + GGML_ABORT("fatal error"); + } + } else if (shader_size == 1) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_m; + shname = "F32_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_m; + shname = "F32_F16_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m; + shname = "F16_F32_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_m; + shname = "F16_ALIGNED_M"; + } else { + GGML_ABORT("fatal error"); + } + } else if (shader_size == 2) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_l; + shname = "F32_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_l; + shname = "F32_F16_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l; + shname = "F16_F32_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_l; + shname = "F16_ALIGNED_L"; + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ASSERT(0); + } + + const size_t kpad = ggml_vk_align_size(k, p->align); + + if (k != kpad) { + if (shader_size == 0) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->s; + shname = "F32_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->s; + shname = "F32_F16_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->s; + shname = "F16_F32_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->s; + shname = "F16_S"; + } + } else if (shader_size == 1) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->m; + shname = "F32_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->m; + shname = "F32_F16_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->m; + shname = "F16_F32_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->m; + shname = "F16_M"; + } + } else if (shader_size == 2) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->l; + shname = "F32_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->l; + shname = "F32_F16_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->l; + shname = "F16_F32_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->l; + shname = "F16_L"; + } + } + } + + ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + + if (ctx->prealloc_split_k == 
nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); + } + } + + ggml_pipeline_allocate_descriptor_sets(ctx->device); + + vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + + X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne); + Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne); + float* d = (float *) malloc(sizeof(float) * d_ne); + + for (size_t i = 0; i < x_ne; i++) { + if (std::is_same()) { + x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = 1.0f; + // x[i] = i + 1; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + } else if (std::is_same()) { + x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // x[i] = ggml_fp32_to_fp16(1.0f); + // x[i] = ggml_fp32_to_fp16(i + 1); + // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + } else { + GGML_ABORT("fatal error"); + } + } + for (size_t i = 0; i < y_ne; i++) { + if (std::is_same()) { + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 1.0f : 0.0f; + // y[i] = i + 1; + } else if (std::is_same()) { + y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + // y[i] = ggml_fp32_to_fp16(i + 1); + } else { + GGML_ABORT("fatal error"); + } + } + + ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); + ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ggml_vk_ctx_begin(ctx->device, subctx); + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1 + ); + } + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + + auto end = std::chrono::high_resolution_clock::now(); + double time = std::chrono::duration_cast(end-begin).count() / 1000.0; + + // copy dst to host + ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne); + + float * d_chk = (float *) malloc(sizeof(float) * d_ne); + + ggml_init_params iparams = { + /*.mem_size =*/ 1024*1024*1024, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ggml_ctx = ggml_init(iparams); + + ggml_type src0_type; + ggml_type src1_type; + + if (std::is_same()) { + src0_type = GGML_TYPE_F32; + } else if (std::is_same()) { + src0_type = GGML_TYPE_F16; + } else { + GGML_ABORT("fatal error"); + } + if (std::is_same()) { + src1_type = GGML_TYPE_F32; + } else if (std::is_same()) { + src1_type = GGML_TYPE_F16; + } else { + GGML_ABORT("fatal error"); + } + + ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, 
m, batch); + ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch); + ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); + + src0_ggml->data = x; + src1_ggml->data = y; + tensor_ggml->data = d_chk; + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_ggml); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); + + ggml_free(ggml_ctx); + + double avg_err = 0.0; + int first_err_n = -1; + int first_err_m = -1; + int first_err_b = -1; + + for (size_t i = 0; i < m*n*batch; i++) { + double err = std::fabs(d[i] - d_chk[i]); + avg_err += err; + + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { + first_err_b = i / (m * n); + first_err_n = (i % (m * n)) / m; + first_err_m = (i % (m * n)) % m; + } + } + + avg_err /= m * n; + + double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1 || std::isnan(avg_err)) { + std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << "Expected result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + if (split_k > 1) { + float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); + ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); + + std::cerr << "d_buf0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf2: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf3: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + free(split_k_buf); + } + } + + free(d_chk); + + ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); + ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); + + ggml_vk_destroy_buffer(d_X); + ggml_vk_destroy_buffer(d_Y); + ggml_vk_destroy_buffer(d_D); + + ggml_pipeline_cleanup(p); + ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce); + + free(x); + free(y); + free(d); +} + +static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + i3 = std::max(i3, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < 
tensor->ne[3]) { + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) { + ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr); +} + +static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) { + if (quant == GGML_TYPE_F32) { + memcpy(to, from, sizeof(float) * ne); + return; + } + + const auto * tt = ggml_get_type_traits(quant); + + ggml_to_float_t dequant_fn = tt->to_float; + + dequant_fn(from, to, ne); +} + +static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { + VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")"); + const size_t x_sz = sizeof(float) * ne; + const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne; + const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); + float * x = (float *) malloc(x_sz); + void * qx = malloc(qx_sz); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal); + float * x_ref = (float *) malloc(x_sz); + ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16); + + for (size_t i = 0; i < ne; i++) { + x[i] = rand() / (float)RAND_MAX; + } + + vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant); + + ggml_vk_quantize_data(x, qx, ne, quant); + ggml_vk_dequantize_data(qx, x_ref, ne, quant); + + ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); + + ggml_pipeline_allocate_descriptor_sets(ctx->device); + + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ggml_vk_ctx_begin(ctx->device, subctx); + const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; + ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + + auto end = std::chrono::high_resolution_clock::now(); + + double ms_dequant = std::chrono::duration_cast(end-begin).count() / 1000.0; + ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16); + + int first_err = -1; + + double avg_err = 0.0; + for (size_t i = 0; i < ne; i++) { + double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i])); + avg_err += error; + + if (first_err < 0 && error > 0.05) { + first_err = i; + } + } + + avg_err /= ne; + + std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1) { + std::cerr << "first_error = " << first_err << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + for (int i = std::max(0, 
first_err - 5); i < std::min((int)ne, first_err + 5); i++) { + std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", "; + } + std::cerr << std::endl << "Expected result: " << std::endl << std::endl; + for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { + std::cerr << x_ref[i] << ", "; + } + std::cerr << std::endl; + } + + ggml_vk_destroy_buffer(x_buf); + ggml_vk_destroy_buffer(qx_buf); + + free(x); + free(qx); + free(x_ref); + free(x_chk); +} + +static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) { + VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); + const size_t x_ne = m * k * batch; + const size_t y_ne = k * n * batch; + const size_t d_ne = m * n * batch; + + vk_pipeline p; + std::string shname; + if (shader_size == 0) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; + } else if (shader_size == 1) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; + } else if (shader_size == 2) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; + } else { + GGML_ASSERT(0); + } + + const size_t kpad = ggml_vk_align_size(k, p->align); + + if (k != kpad) { + if (shader_size == 0) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s; + shname = std::string(ggml_type_name(quant)) + "_S"; + } else if (shader_size == 1) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m; + shname = std::string(ggml_type_name(quant)) + "_M"; + } else if (shader_size == 2) { + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l; + shname = std::string(ggml_type_name(quant)) + "_L"; + } else { + GGML_ASSERT(0); + } + } + + const size_t x_sz = sizeof(float) * x_ne; + const size_t y_sz = sizeof(float) * y_ne; + const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); + const size_t d_sz = sizeof(float) * d_ne; + float * x = (float *) malloc(x_sz); + float * y = (float *) malloc(y_sz); + void * qx = malloc(qx_sz); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + float * d = (float *) malloc(d_sz); + float * d_chk = (float *) malloc(d_sz); + + for (size_t i = 0; i < x_ne; i++) { + x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + } + + ggml_vk_quantize_data(x, qx, x_ne, quant); + + for (size_t i = 0; i < y_ne; i++) { + // y[i] = rand() / (float)RAND_MAX; + y[i] = (i % k == i / k) ? 
1.0f : 0.0f; + } + + ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + + if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); + } + } + + ggml_pipeline_allocate_descriptor_sets(ctx->device); + + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); + ggml_vk_buffer_write(y_buf, 0, y, y_sz); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ggml_vk_ctx_begin(ctx->device, subctx); + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1 + ); + } + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + + auto end = std::chrono::high_resolution_clock::now(); + + double time_ms = std::chrono::duration_cast(end-begin).count() / 1000.0; + ggml_vk_buffer_read(d_buf, 0, d, d_sz); + + ggml_init_params iparams = { + /*.mem_size =*/ 1024*1024*1024, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ggml_ctx = ggml_init(iparams); + + ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch); + ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch); + ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); + + src0_ggml->data = qx; + src1_ggml->data = y; + tensor_ggml->data = d_chk; + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_ggml); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); + + ggml_free(ggml_ctx); + + double avg_err = 0.0; + int first_err_n = -1; + int first_err_m = -1; + int first_err_b = -1; + + for (size_t i = 0; i < m*n*batch; i++) { + double err = std::fabs(d[i] - d_chk[i]); + avg_err += err; + + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { + first_err_b = i / (m * n); + first_err_n = (i % (m * n)) / m; + first_err_m = (i % (m * n)) % m; + } + } + + avg_err /= m * n; + + double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.01 || std::isnan(avg_err)) { + std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "Expected result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + if (split_k > 1) { + float * split_k_buf = (float *) 
malloc(sizeof(float) * d_ne * split_k); + ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); + + std::cerr << "d_buf0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf2: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf3: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + free(split_k_buf); + } + } + + ggml_vk_destroy_buffer(qx_buf); + ggml_vk_destroy_buffer(y_buf); + ggml_vk_destroy_buffer(d_buf); + + free(x); + free(qx); + free(y); + free(d); + free(d_chk); +} +#endif + +static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { +#if defined(GGML_VULKAN_RUN_TESTS) + const std::vector vals { + 512, 512, 128, + 128, 512, 512, + 4096, 512, 4096, + 11008, 512, 4096, + 4096, 512, 11008, + 32000, 512, 4096, + 8, 8, 8, + 100, 46, 576, + 623, 111, 128, + 100, 46, 558, + 512, 1, 256, + 128, 110, 622, + 511, 511, 127, + 511, 511, 7, + 511, 511, 17, + 49, 49, 128, + 128, 49, 49, + 4096, 49, 4096, + }; + const size_t num_it = 100; + + for (size_t i = 0; i < vals.size(); i += 3) { + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2); + std::cerr << '\n' << std::endl; + + if (vals[i + 2] % 32 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0); + std::cerr << '\n' << std::endl; + } + + if (vals[i + 2] % 256 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i 
+ 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K); + std::cerr << '\n' << std::endl; + } + } + + GGML_ABORT("fatal error"); +#endif + + if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")"); + // Resize buffer + if (ctx->prealloc_x != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_x); + } + ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x); + } + if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")"); + // Resize buffer + if (ctx->prealloc_y != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_y); + } + ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + } + if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); + } +} + +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence); + +// Returns true if node has enqueued work into the queue, false otherwise +// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. 
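+// The graph is walked twice: a first pass with dryrun=true only tallies the pipelines and
+// descriptor sets each node will need so buffers can be preallocated, and a second pass with
+// dryrun=false records and submits the actual compute work (see ggml_backend_vk_graph_compute below).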
+static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){ + if (ggml_is_empty(node) || !node->buffer) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); + ctx->semaphore_idx = 0; + + const ggml_tensor * src0 = node->src[0]; + const ggml_tensor * src1 = node->src[1]; + const ggml_tensor * src2 = node->src[2]; + const ggml_tensor * src3 = node->src[3]; + + switch (node->op) { + // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + return false; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + break; + default: + return false; + } + break; + case GGML_OP_REPEAT: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_ACC: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ARGSORT: + case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: + case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: + break; + default: + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; + GGML_ABORT("fatal error"); + return false; + } + + vk_context compute_ctx; + + if (!dryrun) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + } else { + switch (node->op) { + case GGML_OP_REPEAT: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_UNARY: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_ARGSORT: + case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_LEAKY_RELU: + { + // These operations all go through ggml_vk_op_f32, so short-circuit and + // do the only thing needed for the dryrun. 
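+                // In the dry run this means: look up the pipeline the op will use and request one
+                // descriptor set for it; no compute context is created and nothing is recorded.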
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return false; + } + default: + break; + } + } + + switch (node->op) { + case GGML_OP_REPEAT: + ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ACC: + ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_GET_ROWS: + ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ADD: + ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_MUL: + ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_DIV: + ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_CONCAT: + ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_UPSCALE: + ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SCALE: + ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SQR: + ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SIN: + ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COS: + ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CLAMP: + ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_PAD: + ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_NORM: + ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_GROUP_NORM: + ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_RMS_NORM: + ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); + break; + default: + return false; + } + break; + case GGML_OP_DIAG_MASK_INF: + ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SOFT_MAX: + ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ROPE: + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + case GGML_OP_ARGSORT: + ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SUM_ROWS: + ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_IM2COL: + ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_TIMESTEP_EMBEDDING: + ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_POOL_2D: + ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_LEAKY_RELU: + ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_MUL_MAT: + ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_MUL_MAT_ID: + ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + + case GGML_OP_FLASH_ATTN_EXT: + ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); + + break; + + case GGML_OP_RWKV_WKV6: + ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + + break; + default: + return false; + } + + if (dryrun) { + return 
false; + } + + ctx->tensor_ctxs[node_idx] = compute_ctx; + +#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF) + // Force context reset on each node so that each tensor ends up in its own context + // and can be run and compared to its CPU equivalent separately + last_node = true; +#endif + + if (submit || last_node) { + ggml_vk_ctx_end(compute_ctx); + + // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward + if (last_node) { + compute_ctx->exit_tensor_idx = node_idx_begin; + } + else { + compute_ctx->exit_tensor_idx = -1; + } + + ctx->compute_ctx.reset(); + + bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false); + if (!ok) { + if (node->op == GGML_OP_UNARY) { + std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } + else { + std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; + } + } + + } + return true; +} + +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){ + ggml_backend_buffer * buf = nullptr; + + switch (tensor->op) { + case GGML_OP_ADD: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + case GGML_OP_ARGSORT: + case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: + case GGML_OP_LEAKY_RELU: + case GGML_OP_REPEAT: + buf = tensor->buffer; + + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + buf = tensor->buffer; + break; + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_FLASH_ATTN_EXT: + buf = tensor->buffer; + + break; + default: + return false; + } + + if (buf == nullptr) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); + + vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); + + // always wait for the GPU work to be done for the last submit + if (tensor_idx == subctx->exit_tensor_idx) { + use_fence = true; + } + + // Only run if ctx hasn't been submitted yet + if (!subctx->seqs.empty()) { +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_0(tensor); + use_fence = true; +#endif + + // Do staging buffer copies + for (auto& cpy : subctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + 
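+        // Staged host-side copies above are complete; submit the recorded command buffers,
+        // signalling the fence only when the caller needs to wait synchronously.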
ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); + + if (use_fence) { + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); + + ctx->device->device.resetFences({ ctx->fence }); + } +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_1(tensor); +#endif + } + + if (tensor_idx == subctx->exit_tensor_idx) { + // Do staging buffer copies + for (auto& cpy : subctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + subctx->in_memcpys.clear(); + subctx->out_memcpys.clear(); + } + + return true; +} + +// Clean up after graph processing is done +static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); + for (auto& buffer : ctx->gc.temp_buffers) { + ggml_vk_pool_free(ctx, buffer); + } + ctx->gc.temp_buffers.clear(); + + for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) { + vk_pipeline_ref plr = ctx->device->pipelines[dsr.first]; + + if (plr.expired()) { + continue; + } + + vk_pipeline pl = plr.lock(); + ggml_pipeline_cleanup(pl); + } + + ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); + ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); + + for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { + ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); + } + ctx->gc.semaphores.clear(); + + for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) { + ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s }); + } + ctx->gc.tl_semaphores.clear(); + ctx->semaphore_idx = 0; + + ctx->event_idx = 0; + + for (auto& event : ctx->gc.events) { + ctx->device->device.resetEvent(event); + } + + ctx->tensor_ctxs.clear(); + ctx->gc.contexts.clear(); + ctx->device->pipeline_descriptor_set_requirements.clear(); +} + +// Clean up on backend free +static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); + ggml_vk_graph_cleanup(ctx); + + ggml_vk_destroy_buffer(ctx->prealloc_x); + ggml_vk_destroy_buffer(ctx->prealloc_y); + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + + for (auto& buffer : ctx->buffer_pool) { + ggml_vk_destroy_buffer(buffer); + } + + ctx->prealloc_size_x = 0; + ctx->prealloc_size_y = 0; + ctx->prealloc_size_split_k = 0; + + for (auto& event : ctx->gc.events) { + ctx->device->device.destroyEvent(event); + } + ctx->gc.events.clear(); + + ctx->device->device.destroyFence(ctx->fence); +} + +static int ggml_vk_get_device_count() { + ggml_vk_instance_init(); + + return vk_instance.device_indices.size(); +} + +static void ggml_vk_get_device_description(int device, char * description, size_t description_size) { + ggml_vk_instance_init(); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + vk::PhysicalDeviceProperties props; + devices[device].getProperties(&props); + + snprintf(description, description_size, "%s", props.deviceName.data()); +} + +// backend interface + +#define UNUSED GGML_UNUSED + +// device backend + +static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name; +} + +static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { + VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()"); + ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + ggml_vk_destroy_buffer(ctx->dev_buffer); + delete ctx; +} + +static void * 
ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { + return vk_ptr_base; + + UNUSED(buffer); +} + +static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + } +} + +static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + if (ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + + return true; + } + return false; + + UNUSED(buffer); +} + +static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size); +} + +static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { + /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, + /* .get_base = */ ggml_backend_vk_buffer_get_base, + /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, + /* .clear = */ ggml_backend_vk_buffer_clear, + /* .reset = */ NULL, +}; + +// vk buffer type +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")"); + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + + vk_buffer dev_buffer = nullptr; + try { + dev_buffer = 
ggml_vk_create_buffer_device(ctx->device, size); + } catch (const vk::SystemError& e) { + return nullptr; + } + + ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name); + + return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size); +} + +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + return ctx->device->properties.limits.minStorageBufferOffsetAlignment; +} + +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + return ctx->device->max_memory_allocation_size; +} + +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_nbytes(tensor); + + UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { + ggml_vk_instance_init(); + + VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")"); + + vk_device dev = ggml_vk_get_device(dev_num); + + return &dev->buffer_type; +} + +// host buffer type + +static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) { + return GGML_VK_NAME "_Host"; + + UNUSED(buft); +} + +static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { + return GGML_VK_NAME "_Host"; + + UNUSED(buffer); +} + +static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); + ggml_vk_host_free(vk_instance.devices[0], buffer->context); +} + +static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")"); + + size += 32; // Behave like the CPU buffer type + void * ptr = nullptr; + try { + ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); + } catch (vk::SystemError& e) { + std::cerr << "ggml_vulkan: Failed to allocate pinned memory." 
<< std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + // fallback to cpu buffer + return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + } + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer; + + return buffer; + + UNUSED(buft); +} + +static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment; + + UNUSED(buft); +} + +// Should be changed to return device-specific host buffer type +// but that probably requires changes in llama.cpp +ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = { + /* .iface = */ { + /* .get_name = */ ggml_backend_vk_host_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0), + /* .context = */ nullptr, + }; + + // Make sure device 0 is initialized + ggml_vk_instance_init(); + ggml_vk_get_device(0); + + return &ggml_backend_vk_buffer_type_host; +} + + +// backend + +static const char * ggml_backend_vk_name(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return ctx->name.c_str(); +} + +static void ggml_backend_vk_free(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")"); + + ggml_vk_cleanup(ctx); + + delete ctx; + delete backend; +} + +static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return &ctx->device->buffer_type; +} + +static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context 
*)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + return true; + } + + return false; +} + +static void ggml_backend_vk_synchronize(ggml_backend_t backend) { + VK_LOG_DEBUG("ggml_backend_vk_synchronize()"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if(ctx->transfer_ctx.expired()) { + return; + } + + vk_context transfer_ctx = ctx->transfer_ctx.lock(); + + ggml_vk_ctx_end(transfer_ctx); + + for (auto& cpy : transfer_ctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ggml_vk_submit(transfer_ctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + + for (auto& cpy : transfer_ctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ctx->transfer_ctx.reset(); +} + +static bool ggml_vk_is_empty(ggml_tensor * node) { + return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; +} + +static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); + } + ggml_vk_preallocate_buffers(ctx); + ggml_pipeline_allocate_descriptor_sets(ctx->device); + + 
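+    // The dryrun pass above sized the preallocated buffers and descriptor pools;
+    // the loop below records the real work and submits it in growing batches.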
int last_node = cgraph->n_nodes - 1; + + // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly + while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { + last_node -= 1; + } + + // Reserve tensor context space for all nodes + ctx->tensor_ctxs.resize(cgraph->n_nodes); + + bool first_node_in_batch = true; // true if next node will be first node in a batch + int submit_node_idx = 0; // index to first node in a batch + + // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution. + // Start with a smaller count to get work submitted right away, and increase it after each submit. + int nodes_per_submit = 20; + int submitted_nodes = 0; + int submit_count = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + if (first_node_in_batch) { + submit_node_idx = i; + } + + bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node); + + bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); + + if (enqueued) { + ++submitted_nodes; + +#ifndef GGML_VULKAN_CHECK_RESULTS + if (first_node_in_batch) { + first_node_in_batch = false; + } +#endif + } + + if (submit) { + first_node_in_batch = true; + submitted_nodes = 0; + switch (submit_count) { + case 0: + nodes_per_submit = 50; + break; + default: + nodes_per_submit = 100; + break; + } + submit_count++; + } + } + +#ifdef GGML_VULKAN_PERF + ctx->device->perf_logger->print_timings(); +#endif + + ggml_vk_graph_cleanup(ctx); + + return GGML_STATUS_SUCCESS; + + UNUSED(backend); +} + +// TODO: enable async and synchronize +static ggml_backend_i ggml_backend_vk_interface = { + /* .get_name = */ ggml_backend_vk_name, + /* .free = */ ggml_backend_vk_free, + /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, + /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async, + /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, + /* .synchronize = */ NULL, // ggml_backend_vk_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_vk_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +static ggml_guid_t ggml_backend_vk_guid() { + static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; + return &guid; +} + +ggml_backend_t ggml_backend_vk_init(size_t dev_num) { + VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")"); + + ggml_backend_vk_context * ctx = new ggml_backend_vk_context; + ggml_vk_init(ctx, dev_num); + + ggml_backend_t vk_backend = new ggml_backend { + /* .guid = */ ggml_backend_vk_guid(), + /* .interface = */ ggml_backend_vk_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), + /* .context = */ ctx, + }; + + return vk_backend; +} + +bool ggml_backend_is_vk(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid()); +} + +int ggml_backend_vk_get_device_count() { + return ggml_vk_get_device_count(); +} + +void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + int dev_idx = vk_instance.device_indices[device]; + ggml_vk_get_device_description(dev_idx, description, description_size); +} + +void 
ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; + + vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + + for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { + if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total = heap.size; + *free = heap.size; + break; + } + } +} + +////////////////////////// + +struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; + ggml_backend_vk_get_device_memory(ctx->device, free, total); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) { + UNUSED(dev); + return ggml_backend_vk_host_buffer_type(); +} + +static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { + UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_vk_device_get_name(dev); + props->description = ggml_backend_vk_device_get_description(dev); + props->type = ggml_backend_vk_device_get_type(dev); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ true, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { + UNUSED(params); + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_init(ctx->device); +} + +static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + return ggml_is_contiguous(op->src[0]); + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) { + // If there's not enough shared memory for row_ids and the result tile, fallback to CPU + return false; + } + switch (op->src[0]->type) { + case 
GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return false; + } + struct ggml_tensor * a; + struct ggml_tensor * b; + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != b->ne[3]) { + return false; + } + if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) || + !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { + return false; + } + + return true; + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + if (!ggml_vk_get_device(ctx->device)->coopmat2) { + return false; + } + switch (op->src[0]->ne[0]) { + case 64: + case 80: + case 96: + case 112: + case 128: + case 256: + break; + default: + return false; + } + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + if (op->type != GGML_TYPE_F32) { + return false; + } + if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { + return false; + } + // It's straightforward to support different K/V dequant, but would + // significantly increase the number of pipelines + if (op->src[1]->type != op->src[2]->type) { + return false; + } + switch (op->src[1]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently + //case GGML_TYPE_Q2_K: + //case GGML_TYPE_Q3_K: + //case GGML_TYPE_Q4_K: + //case GGML_TYPE_Q5_K: + //case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + break; + default: + return false; + } + return true; + } + case GGML_OP_GET_ROWS: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + } + } break; + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_DUP: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1] != nullptr ? 
op->src[1]->type : src0_type; + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + return true; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return true; + } + return false; + } break; + case GGML_OP_REPEAT: + return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + case GGML_OP_ROPE: + { + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & GGML_ROPE_TYPE_VISION) { + return false; + } + return ggml_is_contiguous(op->src[0]); + } + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_ADD: + case GGML_OP_ACC: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ARGSORT: + case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: + case GGML_OP_LEAKY_RELU: + return true; + default: + return false; + } + + UNUSED(dev); +} + +static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) { + return false; + } + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return buft_ctx->device->idx == ctx->device; +} + +static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + const int min_batch_size = 32; + + return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + + UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_vk_device_i = { + /* .get_name = */ ggml_backend_vk_device_get_name, + /* .get_description = */ ggml_backend_vk_device_get_description, + /* .get_memory = */ ggml_backend_vk_device_get_memory, + /* .get_type = */ ggml_backend_vk_device_get_type, + /* .get_props = */ ggml_backend_vk_device_get_props, + /* .init_backend = */ ggml_backend_vk_device_init, + /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_vk_device_supports_op, + /* .supports_buft = */ ggml_backend_vk_device_supports_buft, + /* .offload_op = */ ggml_backend_vk_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) { + UNUSED(reg); + return GGML_VK_NAME; +} + +static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) { + UNUSED(reg); + return ggml_backend_vk_get_device_count(); +} + +static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) { + static std::vector<ggml_backend_dev_t> devices; + + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + if (!initialized) { + for 
(int i = 0; i < ggml_backend_vk_get_device_count(); i++) { + ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; + char desc[256]; + ggml_backend_vk_get_device_description(i, desc, sizeof(desc)); + ctx->device = i; + ctx->name = GGML_VK_NAME + std::to_string(i); + ctx->description = desc; + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, + /* .context = */ ctx, + }); + } + initialized = true; + } + } + + GGML_ASSERT(device < devices.size()); + return devices[device]; +} + +static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = { + /* .get_name = */ ggml_backend_vk_reg_get_name, + /* .get_device_count = */ ggml_backend_vk_reg_get_device_count, + /* .get_device = */ ggml_backend_vk_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_vk_reg() { + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_vk_reg_i, + /* .context = */ nullptr, + }; + + return &reg; +} + +// Extension availability +static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) { +#ifdef GGML_VULKAN_VALIDATE + bool portability_enumeration_ext = false; + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + return true; + } + } + if (!portability_enumeration_ext) { + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; + } +#endif + return false; + + UNUSED(instance_extensions); +} +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) { +#ifdef __APPLE__ + bool portability_enumeration_ext = false; + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + return true; + } + } + if (!portability_enumeration_ext) { + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." 
<< std::endl; + } +#endif + return false; + + UNUSED(instance_extensions); +} + +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) { + switch (props.vendorID) { + case VK_VENDOR_ID_INTEL: + // Intel drivers don't support coopmat properly yet + return false; + case VK_VENDOR_ID_AMD: + if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { + // Workaround for AMD proprietary driver reporting support on all GPUs + const std::string name = props.deviceName; + return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs + name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs + name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs + } + return true; + default: + return true; + } +} + +// checks + +#ifdef GGML_VULKAN_CHECK_RESULTS +static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector& done, int level = 0) { + if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) { + return; + } + for (int j = 0; j < level; j++) { + std::cerr << " "; + } + std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl; + + done.push_back(tensor); + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] != nullptr) { + ggml_vk_print_graph_origin(tensor->src[i], done, level + 1); + } + } +} + +static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) { + if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + i3 = std::max(i3, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_I32) { + val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) { + void * tensor_data = tensor->data; + + const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer); + + if (is_gpu) { + const size_t tensor_size = ggml_nbytes(tensor); + tensor_data = malloc(tensor_size); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_buffer buffer_gpu = buf_ctx->dev_buffer; + 
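+        // The tensor lives in device memory, so stage it back into the host buffer allocated
+        // above before printing; that copy is freed again at the end of this function.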
ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size); + } + + std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; + std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl; + if (tensor->src[0] != nullptr) { + std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl; + } + if (tensor->src[1] != nullptr) { + std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl; + } + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + + if (is_gpu) { + free(tensor_data); + } +} + +void * comp_result; +size_t comp_size; +size_t comp_nb[GGML_MAX_DIMS]; +size_t check_counter = 0; +static void ggml_vk_check_results_0(ggml_tensor * tensor) { + if (tensor->op == GGML_OP_TRANSPOSE) { + return; + } + + check_counter++; + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { + return; + } + + VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")"); + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + ggml_tensor * src2 = tensor->src[2]; + ggml_tensor * src3 = tensor->src[3]; + + struct ggml_init_params iparams = { + /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ggml_ctx = ggml_init(iparams); + + struct ggml_tensor * src0_clone = nullptr; + struct ggml_tensor * src1_clone = nullptr; + struct ggml_tensor * src2_clone = nullptr; + struct ggml_tensor * src3_clone = nullptr; + struct ggml_tensor * tensor_clone = nullptr; + + size_t src0_size; + size_t src1_size; + size_t src2_size; + size_t src3_size; + + void * src0_buffer = nullptr; + void * src1_buffer = nullptr; + void * src2_buffer = nullptr; + void * src3_buffer = nullptr; + + if (src0 != nullptr) { + src0_clone = ggml_dup_tensor(ggml_ctx, src0); + + src0_size = ggml_nbytes(src0); + + src0_buffer = malloc(src0_size); + src0_clone->data = src0_buffer; + if (ggml_backend_buffer_is_host(src0->buffer)) { + memcpy(src0_clone->data, src0->data, src0_size); + memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(src0->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = 
(ggml_backend_vk_buffer_context *)src0->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src0) + src0->view_offs; + if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { + for (int i3 = 0; i3 < src0->ne[3]; i3++) { + for (int i2 = 0; i2 < src0->ne[2]; i2++) { + const int idx = i3*src0->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]); + } + } + + src0_clone->nb[0] = src0->nb[0]; + src0_clone->nb[1] = src0->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1]; + } + } else { + if (offset + src0_size >= buffer_gpu->size) { + src0_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size); + memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(src0, "src0"); + } + } + if (src1 != nullptr) { + src1_clone = ggml_dup_tensor(ggml_ctx, src1); + + src1_size = ggml_nbytes(src1); + + src1_buffer = malloc(src1_size); + src1_clone->data = src1_buffer; + if (ggml_backend_buffer_is_host(src1->buffer)) { + memcpy(src1_clone->data, src1->data, src1_size); + memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(src1->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src1) + src1->view_offs; + if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { + for (int i3 = 0; i3 < src1->ne[3]; i3++) { + for (int i2 = 0; i2 < src1->ne[2]; i2++) { + const int idx = i3*src1->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]); + } + } + + src1_clone->nb[0] = src1->nb[0]; + src1_clone->nb[1] = src1->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1]; + } + } else { + if (offset + src1_size >= buffer_gpu->size) { + src1_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size); + memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(src1, "src1"); + } + } + if (src2 != nullptr) { + src2_clone = ggml_dup_tensor(ggml_ctx, src2); + + src2_size = ggml_nbytes(src2); + + src2_buffer = malloc(src2_size); + src2_clone->data = src2_buffer; + if (ggml_backend_buffer_is_host(src2->buffer)) { + memcpy(src2_clone->data, src2->data, src2_size); + memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(src2->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src2) + src2->view_offs; + if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) { + for (int i3 = 0; i3 < src2->ne[3]; i3++) { + for (int i2 = 0; i2 < src2->ne[2]; i2++) { + const int idx = i3*src2->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char 
*)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]); + } + } + + src2_clone->nb[0] = src2->nb[0]; + src2_clone->nb[1] = src2->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1]; + } + } else { + if (offset + src2_size >= buffer_gpu->size) { + src2_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size); + memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(src2, "src2"); + } + } + if (src3 != nullptr) { + src3_clone = ggml_dup_tensor(ggml_ctx, src3); + + src3_size = ggml_nbytes(src3); + + src3_buffer = malloc(src3_size); + src3_clone->data = src3_buffer; + if (ggml_backend_buffer_is_host(src3->buffer)) { + memcpy(src3_clone->data, src3->data, src3_size); + memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(src3->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src3) + src3->view_offs; + if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) { + for (int i3 = 0; i3 < src3->ne[3]; i3++) { + for (int i2 = 0; i2 < src3->ne[2]; i2++) { + const int idx = i3*src3->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]); + } + } + + src3_clone->nb[0] = src3->nb[0]; + src3_clone->nb[1] = src3->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1]; + } + } else { + if (offset + src3_size >= buffer_gpu->size) { + src3_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size); + memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(src3, "src3"); + } + } + + if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { + const float *params = (const float *)tensor->op_params; + tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]); + } else if (tensor->op == GGML_OP_MUL_MAT) { + tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_MUL_MAT_ID) { + tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone); + } else if (tensor->op == GGML_OP_MUL) { + tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_DIV) { + tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_CONCAT) { + tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_UPSCALE) { + tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_SCALE) { + tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]); + } else if (tensor->op == GGML_OP_SQR) { + tensor_clone = ggml_sqr(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_SIN) { + tensor_clone = ggml_sin(ggml_ctx, src0_clone); + } else if (tensor->op == 
GGML_OP_COS) { + tensor_clone = ggml_cos(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_CLAMP) { + tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + } else if (tensor->op == GGML_OP_PAD) { + tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]); + } else if (tensor->op == GGML_OP_REPEAT) { + tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor); + } else if (tensor->op == GGML_OP_ADD) { + tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_ACC) { + tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + } else if (tensor->op == GGML_OP_NORM) { + tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_GROUP_NORM) { + tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); + } else if (tensor->op == GGML_OP_RMS_NORM) { + tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_SOFT_MAX) { + if (src1 != nullptr) { + tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + } else { + tensor_clone = ggml_soft_max(ggml_ctx, src0_clone); + } + } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_ROPE) { + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; + const float freq_base = ((float *) tensor->op_params)[5]; + const float freq_scale = ((float *) tensor->op_params)[6]; + const float ext_factor = ((float *) tensor->op_params)[7]; + const float attn_factor = ((float *) tensor->op_params)[8]; + const float beta_fast = ((float *) tensor->op_params)[9]; + const float beta_slow = ((float *) tensor->op_params)[10]; + tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else if (tensor->op == GGML_OP_UNARY) { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_SILU: + tensor_clone = ggml_silu(ggml_ctx, src0_clone); + break; + case GGML_UNARY_OP_GELU: + tensor_clone = ggml_gelu(ggml_ctx, src0_clone); + break; + case GGML_UNARY_OP_GELU_QUICK: + tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone); + break; + case GGML_UNARY_OP_RELU: + tensor_clone = ggml_relu(ggml_ctx, src0_clone); + break; + case GGML_UNARY_OP_TANH: + tensor_clone = ggml_tanh(ggml_ctx, src0_clone); + break; + default: + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { + if (src1 == nullptr) { + tensor_clone = ggml_dup(ggml_ctx, src0_clone); + tensor_clone->type = tensor->type; + } else { + tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone); + } + } else if (tensor->op == GGML_OP_CONT) { + tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], 
tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_RESHAPE) { + tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_VIEW) { + tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); + } else if (tensor->op == GGML_OP_PERMUTE) { + int32_t * params = (int32_t *)tensor->op_params; + tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]); + } else if (tensor->op == GGML_OP_TRANSPOSE) { + tensor_clone = ggml_transpose(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_GET_ROWS) { + tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_ARGSORT) { + tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_SUM_ROWS) { + tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_IM2COL) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + + const bool is_2D = tensor->op_params[6] == 1; + tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { + const int32_t dim = tensor->op_params[0]; + const int32_t max_period = tensor->op_params[1]; + tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); + } else if (tensor->op == GGML_OP_POOL_2D) { + enum ggml_op_pool op = static_cast(tensor->op_params[0]); + const int32_t k0 = tensor->op_params[1]; + const int32_t k1 = tensor->op_params[2]; + const int32_t s0 = tensor->op_params[3]; + const int32_t s1 = tensor->op_params[4]; + const int32_t p0 = tensor->op_params[5]; + const int32_t p1 = tensor->op_params[6]; + + tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1); + } else if (tensor->op == GGML_OP_LEAKY_RELU) { + const float * op_params = (const float *)tensor->op_params; + tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); + } else if (tensor->op == GGML_OP_RWKV_WKV6) { + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], + tensor->src[4], tensor->src[5]); + } + else { + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_clone); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8); + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(tensor_clone, "tensor_clone"); + } + + comp_size = ggml_nbytes(tensor_clone); + + comp_result = malloc(comp_size); + memcpy(comp_result, tensor_clone->data, comp_size); + memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); + + if (src0 != nullptr) { + free(src0_buffer); + } + if (src1 != nullptr) { + free(src1_buffer); + } + + ggml_free(ggml_ctx); + + VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); +} + +static void ggml_vk_check_results_1(ggml_tensor * tensor) { + if (tensor->op == GGML_OP_TRANSPOSE) { + 
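+        // Transpose only swaps strides and produces no new data, so there is nothing to compare.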
return; + } + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { + return; + } + + VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")"); + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + ggml_tensor * src2 = tensor->src[2]; + + void * tensor_data = tensor->data; + + if (ggml_backend_buffer_is_vk(tensor->buffer)) { + size_t tensor_size = ggml_nbytes(tensor); + tensor_data = malloc(tensor_size); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs; + if (offset + tensor_size >= buffer_gpu->size) { + tensor_size = buffer_gpu->size - offset; + } + + ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size); + } + + float first_error_result = -1.0f; + float first_error_correct = -1.0f; + std::array first_error = { -1, -1, -1, -1 }; + double avg_err = 0.0; + size_t counter = 0; + + for (int i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size; + float correct = 0.0f; + float result = 0.0f; + + if (buffer_size_fit) { + if (tensor->type == GGML_TYPE_F32) { + correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); + result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_I32) { + correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else { + std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; + } + } else { + std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl; + GGML_ABORT("fatal error"); + } + + if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) { + std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << 
src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } + if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) { + first_error[0] = i0; + first_error[1] = i1; + first_error[2] = i2; + first_error[3] = i3; + first_error_result = result; + first_error_correct = correct; + } + + // Special case, value is infinite, avoid NaN result in avg_err + // NaN also appears in results, if both are nan error is 0 + if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) { + avg_err += std::fabs(correct - result); + } + counter++; + } + } + } + } + + avg_err /= counter; + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << 
src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + } + + if (avg_err > 0.05 || std::isnan(avg_err)) { + std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } else { + std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl; + } 
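+    // The check either aborted above or passed: release the CPU reference result produced by
+    // ggml_vk_check_results_0 and, below, the staged host copy of the GPU tensor.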
+ + free(comp_result); + comp_result = nullptr; + comp_size = 0; + + if (ggml_backend_buffer_is_vk(tensor->buffer)) { + free(tensor_data); + } + + VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")"); +} +#endif + +GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt new file mode 100644 index 000000000..bd0c74cb1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -0,0 +1,9 @@ +find_package (Threads REQUIRED) +find_package(Vulkan COMPONENTS glslc REQUIRED) + +set(TARGET vulkan-shaders-gen) +add_executable(${TARGET} vulkan-shaders-gen.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_compile_features(${TARGET} PRIVATE cxx_std_17) +target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) +target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp new file mode 100644 index 000000000..d896f1ef0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + const uint offset = p.param3; + const uint src1_i = idx - offset; + const uint oz = src1_i / p.nb02; + const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; + const uint ox = src1_i % p.nb01; + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + } else { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); + } +} + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp new file mode 100644 index 000000000..2b4085c4f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -0,0 +1,29 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp new file mode 100644 index 000000000..d4fa45b1e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -0,0 +1,69 @@ +#version 450 + +#include 
"types.comp" + +#define BLOCK_SIZE 1024 +#define ASC 0 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint ncols_pad; + uint order; +} p; + +shared int dst_row[BLOCK_SIZE]; + +void swap(uint idx0, uint idx1) { + int tmp = dst_row[idx0]; + dst_row[idx0] = dst_row[idx1]; + dst_row[idx1] = tmp; +} + +void main() { + // bitonic sort + const int col = int(gl_LocalInvocationID.x); + const uint row = gl_WorkGroupID.y; + + const uint row_offset = row * p.ncols; + + // initialize indices + if (col < p.ncols_pad) { + dst_row[col] = col; + } + barrier(); + + for (uint k = 2; k <= p.ncols_pad; k *= 2) { + for (uint j = k / 2; j > 0; j /= 2) { + const uint ixj = col ^ j; + if (col < p.ncols_pad && ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= p.ncols || + (dst_row[ixj] < p.ncols && (p.order == ASC ? + data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] : + data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]])) + ) { + swap(col, ixj); + } + } else { + if (dst_row[ixj] >= p.ncols || + (dst_row[col] < p.ncols && (p.order == ASC ? + data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] : + data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]])) + ) { + swap(col, ixj); + } + } + } + barrier(); + } + } + + if (col < p.ncols) { + data_d[row_offset + col] = dst_row[col]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp new file mode 100644 index 000000000..1e5cb8dae --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp new file mode 100644 index 000000000..9ee2f1fae --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + const int dim = p.param3; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne22*p.ne21*p.ne20); + const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20; + const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20); + const uint i2_offset = i2*p.ne21*p.ne20; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne20; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20; + + uint o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? 
p.ne02 : p.ne03)); + + const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10; + const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20; + + const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]); +#else + if (is_src0) { + data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx]; + } else { + data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx]; + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp new file mode 100644 index 000000000..dd828c232 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +#extension GL_EXT_control_flow_attributes : require + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + // fast path for when all four iterations are in-bounds + if (idx + (num_iter-1)*num_threads < p.ne) { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } else { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp new file mode 100644 index 000000000..29c906494 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); +#else + data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp new file mode 100644 index 000000000..0b8d02f58 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); +} diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp new file mode 100644 index 000000000..a4d3fca55 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x * 16; + + if (i >= p.nel) { + return; + } + + [[unroll]] for (uint l = 0; l < 16; l++) { + data_b[i + l] = D_TYPE(data_a[i + l]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp new file mode 100644 index 000000000..91bb8f8db --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -0,0 +1,118 @@ +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#endif + +#include "types.comp" + +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + +#if defined(DATA_A_F32) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_F16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_Q4_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2(vui & 0xF, vui >> 4) - 8.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); +} +#endif + +#if defined(DATA_A_Q4_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(vui & 0xF, vui >> 4); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); +} +#endif + +#if defined(DATA_A_Q5_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0]; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f); +} +#endif + +#if defined(DATA_A_Q5_1) +vec2 dequantize(uint ib, uint iqs, uint 
a_offset) { + const uint uint_qh = data_a[a_offset + ib].qh; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = data_a_packed16[a_offset + ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y); +} +#endif + +#if defined(DATA_A_Q8_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2]; + uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1]; + return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8)); +} +#endif + +#if defined(DATA_A_IQ4_NL) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]); +} +#endif + +#if defined(DATA_A_F32) || defined(DATA_A_F16) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(0, 0); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), 0); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp new file mode 100644 index 000000000..94b78598e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -0,0 +1,325 @@ + +#include "types.comp" + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { + block_q4_0_packed16 block; +}; + +float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]); + qs >>= shift; + qs &= 0x0F0F; + qs = unpack8(qs)[idx & 1]; + float16_t ret = (float16_t(qs) - float16_t(8)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { + block_q4_1 block; +}; + +float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = 
float16_t(qs) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { + block_q5_0 block; +}; + +float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0]; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { + block_q5_1 block; +}; + +float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = bl.block.qh; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = float16_t(qs | qh) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { + block_q8_0_packed16 block; +}; + +float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + // Load 16b and select the byte for this element + int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1]; + float16_t ret = float16_t(qs) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { + block_q2_K block; +}; + +float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const f16vec2 d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31 + const uint scalesi = iqs / 16; // 0..15 + const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6 + + uint32_t qs = bl.block.qs[qsi]; + const uint scales = bl.block.scales[scalesi]; + float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4); + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { + block_q3_K block; +}; + +float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + const uint n = iqs / 128; // 0,1 + const uint qsi = n * 32 + (iqs % 32); // 0..63 + const uint hmi = (iqs % 32); // 0..31 + const uint j = (iqs % 128) / 8; // 0..15 + const uint is = iqs / 16; // 0..15 + const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + uint32_t scaleidx0 = (is < 8) ? is : (is-8); + uint32_t scaleidx0shift = (is < 8) ? 
0 : 4; + uint32_t scaleidx1 = is + 8 - (is/4)*4; + uint32_t scaleidx1shift = (is/4)*2; + + const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); + + const float16_t dl = bl.block.d * float16_t(us - 32); + + float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4)); + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { + block_q4_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 { + block_q4_K_packed16 block; +}; + +float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + + const f16vec2 loadd = bl.block.d; + + uint32_t sc; + uint32_t mbyte; + + uint32_t scidx0 = (is < 4) ? is : (is + 4); + uint32_t scidx1 = (is < 4) ? is : (is - 4); + uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint32_t scidxshift1 = (is < 4) ? 0 : 2; + uint32_t mbidx0 = is + 4; + uint32_t mbidx1 = (is < 4) ? is + 4 : is; + uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint32_t mbidxshift0 = (is < 4) ? 0 : 4; + uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint32_t mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4)) & 0x0F0F; + qs = unpack8(qs)[idx & 1]; + + float16_t ret = d * float16_t(qs) - m; + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { + block_q5_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 { + block_q5_K_packed16 block; +}; + +float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + + const uint32_t hm = 0x0101 << is; + + const f16vec2 loadd = bl.block.d; + + uint32_t sc; + uint32_t mbyte; + + uint32_t scidx0 = (is < 4) ? is : (is + 4); + uint32_t scidx1 = (is < 4) ? is : (is - 4); + uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint32_t scidxshift1 = (is < 4) ? 0 : 2; + uint32_t mbidx0 = is + 4; + uint32_t mbidx1 = (is < 4) ? is + 4 : is; + uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint32_t mbidxshift0 = (is < 4) ? 0 : 4; + uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint32_t mbidxshift1 = (is < 4) ? 
0 : 2; + + sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); + + uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); + qh = qh & hm; + qh = unpack8(qh)[idx & 1]; + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4)) & 0x0F0F; + qs = unpack8(qs)[idx & 1]; + + float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m; + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { + block_q6_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 { + block_q6_K_packed16 block; +}; + +float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x40) >> 6; // 0,1 + const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 + const uint is = (idx & 0xF0) >> 4; // 0..15 + + const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); + + uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]); + ql = (ql >> (b * 4)) & 0x0F0F; + + uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qh = ((qh >> qhshift) & 0x0303) << 4; + + int q = unpack8(ql | qh)[idx & 1]; + + float16_t ret = dscale * float16_t(q - 32); + + return ret; +} + +#if defined(DATA_A_IQ4_NL) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { + block_iq4_nl block; +}; + +float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; + return ret; +} +#endif + +#if defined(DATA_A_Q4_0) +#define dequantFuncA dequantFuncQ4_0 +#elif defined(DATA_A_Q4_1) +#define dequantFuncA dequantFuncQ4_1 +#elif defined(DATA_A_Q5_0) +#define dequantFuncA dequantFuncQ5_0 +#elif defined(DATA_A_Q5_1) +#define dequantFuncA dequantFuncQ5_1 +#elif defined(DATA_A_Q8_0) +#define dequantFuncA dequantFuncQ8_0 +#elif defined(DATA_A_Q2_K) +#define dequantFuncA dequantFuncQ2_K +#elif defined(DATA_A_Q3_K) +#define dequantFuncA dequantFuncQ3_K +#elif defined(DATA_A_Q4_K) +#define dequantFuncA dequantFuncQ4_K +#elif defined(DATA_A_Q5_K) +#define dequantFuncA dequantFuncQ5_K +#elif defined(DATA_A_Q6_K) +#define dequantFuncA dequantFuncQ6_K +#elif defined(DATA_A_IQ4_NL) +#define dequantFuncA dequantFuncIQ4_NL +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp new file mode 100644 index 000000000..8d806435b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp @@ -0,0 +1,13 @@ +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint M; + uint K; + uint stride_a; + uint stride_b; + uint nel; +} p; + +#include 
"types.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp new file mode 100644 index 000000000..8de14fc03 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq4nl_shmem(); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = float(data_a[ib].d); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp new file mode 100644 index 000000000..157154af3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; + if (i >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; + + const uint y_idx = i * QUANT_K + 128 * ip + il; + + const uint ql_idx = 32 * ip + il; + const uint8_t qs = data_a[i].qs[32 * ip + il]; + + FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); + FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); + data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); + data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); + data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); + data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp new file mode 100644 index 000000000..c17dd0d99 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = uint(gl_WorkGroupID.x * 256 + wgy); 
+ if (i >= p.M * p.K / QUANT_K) { + return; + } + + const uint r = gl_LocalInvocationID.x / 4; + const uint tid = r / 2; + const uint is0 = r % 2; + const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4); + const uint n = tid / 4; + const uint j = tid - 4*n; + + const uint8_t m = uint8_t(1 << (4*n + j)); + const uint is = 8*n + 2*j + is0; + const uint shift = 2*j; + + const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) : + (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4)); + const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); + const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); + + const uint y_idx = i * QUANT_K + 128 * n + 32 * j; + const uint qs_idx = 32*n; + + for (uint l = l0; l < l0 + 4; ++l) { + data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4))); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp new file mode 100644 index 000000000..408185327 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp @@ -0,0 +1,30 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = float(data_a[ib].d); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp new file mode 100644 index 000000000..2f27eee68 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp new file mode 100644 index 000000000..987f113a3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -0,0 +1,68 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 8; + const uint ir = tid % 8; + const uint is = 2 * il; + const uint n = 4; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + n * ir; + const uint qs_idx = 32*il + n * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 
0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + [[unroll]] for (uint l = 0; l < n; ++l) { + data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1); + data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp new file mode 100644 index 000000000..b20b80529 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp new file mode 100644 index 000000000..dc59fe3b7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint qh = data_a[ib].qh; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp new file mode 100644 index 000000000..6db5403b6 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -0,0 +1,70 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 
64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 16; + const uint ir = tid % 16; + const uint is = 2 * il; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; + const uint qs_idx = 32*il + 2 * ir; + const uint qh_idx = 2 * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + const uint8_t hm1 = uint8_t(1 << (2 * il )); + const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); + data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); + data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 
16 : 0)) - m2); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp new file mode 100644 index 000000000..0b9131755 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -0,0 +1,33 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; + if (i >= p.M * p.K / QUANT_K) { + return; + } + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; + + const uint y_idx = i * QUANT_K + 128 * ip + il; + + const uint ql_idx = 64 * ip + il; + const uint8_t qh = data_a[i].qh[32 * ip + il]; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); + + data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); + data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); + data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); + data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp new file mode 100644 index 000000000..bd1344a88 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp @@ -0,0 +1,31 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q8_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 16*il; + + const float d = float(data_a[ib].d); + + const uint q_idx = 16*il; + + [[unroll]] for (uint l = 0; l < 16; l += 2) { + data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]); + data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp new file mode 100644 index 000000000..4e68742b5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -0,0 +1,34 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint ncols; + uint rows_per_channel; + uint n_past; +} p; + +#include "types.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + 
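// Causal mask: entries with col > p.n_past + (row % p.rows_per_channel) are set to -infinity
+ // (0xFF800000 is the IEEE bit pattern for -inf); all other entries are copied through unchanged. +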
const uint col = gl_GlobalInvocationID.y; + const uint row = gl_GlobalInvocationID.x; + + if (col >= p.ncols) { + return; + } + + const uint i = row*p.ncols + col; + if (col > p.n_past + row % p.rows_per_channel) { + data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000)); + } else { + data_d[i] = D_TYPE(data_a[i]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp new file mode 100644 index 000000000..9fb69c6c1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp new file mode 100644 index 000000000..c5be8131b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -0,0 +1,289 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable +#extension GL_EXT_null_initializer : enable + +#include "types.comp" +#include "dequant_funcs_cm2.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint32_t Br = 32; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; +layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb02; + uint32_t nb03; + uint32_t nb12; + uint32_t nb13; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; +} p; + +layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; +layout (binding = 1) readonly buffer K {uint8_t data_k[];}; +layout (binding = 2) readonly buffer V {uint8_t data_v[];}; +layout (binding = 3) readonly buffer M {uint8_t data_m[];}; +layout (binding = 4) writeonly buffer O {D_TYPE 
data_o[];}; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return max(x, y); +} + +ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return x; +} + +// Replace matrix elements >= numRows or numCols with 'replace' +ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) { + if (row >= numRows || col >= numCols) { + return replace; + } + return elem; +} + +ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) +{ + return exp(elem); +} + +ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) +{ + return max(elem0, elem1); +} + +#if defined(BLOCK_SIZE) +#define DECODEFUNC , DEQUANTFUNC +#else +#define DECODEFUNC +#endif + +void main() { +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); +#endif + + const uint32_t N = p.N; + const uint32_t KV = p.KV; + + const uint32_t Tr = CEIL_DIV(N, Br); + const uint32_t Tc = CEIL_DIV(KV, Bc); + + const uint32_t i = gl_WorkGroupID.x; + + const uint32_t iq2 = gl_WorkGroupID.y; + const uint32_t iq3 = gl_WorkGroupID.z; + + // broadcast factors + const uint32_t rk2 = p.neq2/p.nek2; + const uint32_t rk3 = p.neq3/p.nek3; + + const uint32_t rv2 = p.neq2/p.nev2; + const uint32_t rv3 = p.neq3/p.nev3; + + // k indices + const uint32_t ik3 = iq3 / rk3; + const uint32_t ik2 = iq2 / rk2; + + // v indices + const uint32_t iv3 = iq3 / rv3; + const uint32_t iv2 = iq2 / rv2; + + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); + tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if defined(BLOCK_SIZE) + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); +#endif + + tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); + tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); + tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); + + coopmat Q; + coopmat Qf16; + + uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); + + Qf16 = coopmat(Q); + Qf16 *= float16_t(p.scale); + + coopmat O = coopmat(0); + + coopmat L, M; + + L = coopmat(0); + M = coopmat(-1.0/0.0); + + ACC_TYPE slope = ACC_TYPE(1.0); + + // ALiBi + if (p.max_bias > 0.0f) { + const uint32_t h = iq2; + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); + + slope = pow(base, ACC_TYPE(exph)); + } + + [[dont_unroll]] + for (uint32_t j = 0; j < Tc; ++j) { + + coopmat S = coopmat(0); + + coopmat K_T; + + uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); + S = coopMatMulAdd(Qf16, K_T, S); + + if (p.logit_softcap != 0.0f) { + [[unroll]] + for (int k = 0; k < S.length(); ++k) { + S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); + } + } + + if (p.mask != 0) { + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + + coopmat mv; + + coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slope*coopmat(mv); + } + + // Clear padding elements to -inf, so they don't contribute to rowmax + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); + } + + coopmat rowmax, P, rowsum, eM; + + coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); + + coopmat Mold = M; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + coopMatPerElementNV(M, rowmax, Max, Mold); + coopMatPerElementNV(P, S - M, Exp); + coopMatPerElementNV(eM, Mold - M, Exp); + + // Clear padding elements to 0, so they don't contribute to rowsum + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); + } + + coopmat P_A = coopmat(P); + + // compute rowsum by multiplying by matrix of all ones. 
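+ // P_A holds the exp(S - M) tile, so multiplying it by the all-ones matrix leaves every element of a
+ // row equal to that row's sum; the result feeds the running softmax denominator L = eM*L + rowsum below.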
+ coopmat One = coopmat(1.0); + + rowsum = coopmat(0.0); + rowsum = coopMatMulAdd(P_A, One, rowsum); + + coopmat V; + uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); + + L = eM*L + rowsum; + + // This is the "diagonal" matrix in the paper, but since we do componentwise + // multiply rather than matrix multiply it has the diagonal element smeared + // across the row + coopmat eMdiag; + + // resize eM by using smear/reduce + coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); + + O = eMdiag * O; + + O = coopMatMulAdd(P_A, V, O); + } + + coopmat Ldiag; + + // resize L by using smear/reduce + coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); + + [[unroll]] + for (int k = 0; k < Ldiag.length(); ++k) { + Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; + } + + O = Ldiag*O; + + tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + + // permute dimensions + tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); + uint32_t o_offset = iq3*p.ne2*p.ne1; + + coopmat O_D = coopmat(O); + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp new file mode 100644 index 000000000..4cc7a68ca --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp @@ -0,0 +1,25 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float xi = float(data_a[i]); + const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); + data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp new file mode 100644 index 000000000..e6e6fcfd2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_QUICK_COEF = -1.702f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x)))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp new file mode 100644 index 000000000..062e2a4cd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp @@ -0,0 +1,64 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; + uint misalign_offsets; + float param1; float param2; int param3; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +// true if src0/src1 are the same shape and the indices can be reused without additional modulus +layout(constant_id = 0) const bool norepeat = false; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; } +uint get_doffset() { return p.misalign_offsets & 0xFF; } + +// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1 +uint fastmod(uint a, uint b) { + if ((b & (b-1)) == 0) { + return a & (b-1); + } + return a % b; +} + +uint fastdiv(uint a, uint b) { + return (a < b) ? 0 : (a / b); +} + +void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) { + i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00)); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00)); + const uint i02_offset = i02*p.ne01*p.ne00; + i01 = (idx - i03_offset - i02_offset) / p.ne00; + i00 = idx - i03_offset - i02_offset - i01*p.ne00; +} + +uint src0_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint src1_idx(uint i00, uint i01, uint i02, uint i03) { + if (norepeat) { + return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10; + } else { + return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10; + } +} + +uint dst_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp new file mode 100644 index 000000000..66e46ae67 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp @@ -0,0 +1,9 @@ +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + float param1; + float param2; +} p; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp new file mode 100644 index 000000000..68d1bc9f1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp @@ -0,0 +1,56 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint 
ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + float param1; float param2; + + uint ne0_012mp; uint ne0_012L; + uint ne0_01mp; uint ne0_01L; + uint ne0_0mp; uint ne0_0L; + uint ne1_012mp; uint ne1_012L; + uint ne1_01mp; uint ne1_01L; + uint ne1_0mp; uint ne1_0L; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +uint src0_idx(uint idx) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint dst_idx(uint idx) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp new file mode 100644 index 000000000..e877ed779 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -0,0 +1,28 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = gl_GlobalInvocationID.x; + const uint i10 = gl_GlobalInvocationID.y; + const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; + const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; + + if (i00 >= p.ne00) { + return; + } + + const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]); +#else + data_d[d_offset + i00] = data_a[a_offset + i00]; +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp new file mode 100644 index 000000000..1426fde65 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" +#include "dequant_funcs.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = (gl_GlobalInvocationID.x)*2; + const 
uint i10 = gl_GlobalInvocationID.y; + const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; + const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; + +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); +#endif + + if (i00 >= p.ne00) { + return; + } + + const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + + const uint ib = a_offset + i00/QUANT_K; // block index + const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index + const uint iybs = i00 - i00%QUANT_K; // dst block start index + const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + + vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + v = v * dm.x + dm.y; + + data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); + data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp new file mode 100644 index 000000000..b6a0d5645 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp @@ -0,0 +1,66 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared float tmp[BLOCK_SIZE]; + +void main() { + const uint group_size = p.KX; + const float eps = p.param1; + + const uint tid = gl_LocalInvocationID.x; + const uint start = gl_WorkGroupID.x * group_size + tid; + const uint end = (gl_WorkGroupID.x + 1) * group_size; + + tmp[tid] = 0.0f; + + // Calculate mean + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + tmp[tid] += float(data_a[col]); + } + + // tmp up partial tmps and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float mean = tmp[0] / group_size; + barrier(); + tmp[tid] = 0.0f; + + // Calculate variance + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + const float xi = float(data_a[col]) - mean; + data_d[col] = D_TYPE(xi); + tmp[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float variance = tmp[0] / group_size; + const float scale = inversesqrt(variance + eps); + + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + data_d[col] *= D_TYPE(scale); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp new file mode 100644 index 000000000..122b1e93f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -0,0 +1,87 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_spirv_intrinsics: enable +#extension GL_EXT_control_flow_attributes : require + +#if RTE16 +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif + +layout (push_constant) uniform parameter +{ + uint batch_offset; uint offset_delta; + uint IC; + uint IW; uint IH; + uint OW; uint OH; + uint KW; uint KH; + uint pelements; + uint CHW; + int 
s0; int s1; + int p0; int p1; + int d0; int d1; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +const uint NUM_ITER = 512 / BLOCK_SIZE; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint gidx = gl_GlobalInvocationID.x; + + const uint oh = gl_GlobalInvocationID.y; + const uint batch = gl_GlobalInvocationID.z / p.IC; + const uint ic = gl_GlobalInvocationID.z % p.IC; + + A_TYPE values[NUM_ITER]; + uint offset_dst[NUM_ITER]; + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + values[idx] = A_TYPE(0); + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint i = gidx * NUM_ITER + idx; + + const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); + const uint kx = i / ksize; + const uint kd = kx * ksize; + const uint ky = (i - kd) / p.OW; + const uint ix = i % p.OW; + + const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; + const uint iih = oh * p.s1 + ky * p.d1 - p.p1; + + offset_dst[idx] = + ((batch * p.OH + oh) * p.OW + ix) * p.CHW + + (ic * (p.KW * p.KH) + ky * p.KW + kx); + + if (i >= p.pelements) { + continue; + } + + if (iih < p.IH && iiw < p.IW) { + const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; + values[idx] = data_a[offset_src + iih * p.IW + iiw]; + } + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint i = gidx * NUM_ITER + idx; + + if (i >= p.pelements) { + continue; + } + + data_d[offset_dst[idx]] = D_TYPE(values[idx]); + } + +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp new file mode 100644 index 000000000..d90a99aea --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float val = float(data_a[i]); + data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp new file mode 100644 index 000000000..43de19df8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp new file mode 100644 index 000000000..4c64fd47a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp @@ -0,0 +1,48 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 0) readonly buffer A4 {vec4 data_a4[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; +layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];}; + +layout (push_constant) uniform parameter { + uint ne; + uint k_num; +} p; + +void main() { + // Each invocation handles four consecutive components + const uint idx = gl_GlobalInvocationID.x * 4; + + if (idx >= p.ne) { + return; + } + + // Check if all four components are in bounds and aligned, + // then use vector loads + if (idx + 3 < p.ne && (p.ne % 4) == 0) { + vec4 result = vec4(0.0f); + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a4[(i * p.ne + idx) / 4]; + } + + data_d4[idx / 4] = result; + } else { + [[unroll]] for (uint j = 0; j < 4; ++j) { + if (idx + j < p.ne) { + float result = 0.0f; + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a[i * p.ne + idx + j]; + } + + data_d[idx + j] = result; + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp new file mode 100644 index 000000000..24875cdcf --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -0,0 +1,152 @@ +#version 450 + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#define K_PER_ITER 8 +#else +#define K_PER_ITER 2 +#endif + + +uint a_offset, b_offset, d_offset, y_offset; + +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = (col%QUANT_K)/QUANT_R; // quant index + const uint iybs = col - col%QUANT_K; // y block start index + +#if K_PER_ITER == 8 +#if QUANT_R == 2 + const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; + const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]; + const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); + const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); +#else + const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); +#endif +#else + // Check if the second of the pair of elements is OOB, and don't fetch B or + // accumulate it. We still fetch a pair of elements for A, which is fine for + // quantized formats since they'll be within the same block. We should + // probably skip fetching the second element for F16/F32, but as of now we + // still do. 
+ const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); + + FLOAT_TYPE b0 = 0, b1 = 0; + b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); + if (!OOB) { + b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); + } +#endif + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + +#if K_PER_ITER == 8 + vec4 v = dequantize4(ib, iqs, a_offset); + vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); + + const vec2 dm = get_dm(ib, a_offset); + if (dm.y != 0) { // quant has min component + v = v * dm.x + dm.y; + v2 = v2 * dm.x + dm.y; + } + + // matrix multiplication + FLOAT_TYPE rowtmp = dot(bv0, v); + rowtmp += dot(bv1, v2); + + if (dm.y == 0) + rowtmp *= dm.x; + + temp[j][n] += rowtmp; +#else + const vec2 v = dequantize(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); + if (!OOB) { + temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + } +#endif + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; + + get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; + + y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); + if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + num_iters++; + } + int unroll_count = 4; + uint unrolled_iters = num_iters & ~(unroll_count - 1); + + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + unroll_count = 2; + unrolled_iters = num_iters & ~(unroll_count - 1); + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); + i++; + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); +#endif + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp new file mode 100644 index 000000000..903753c7e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -0,0 +1,118 @@ +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_8bit_storage : require + +#ifdef MUL_MAT_ID +#define EXPERT_COUNT 8 +#endif + +#include "types.comp" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 
data_b_v2[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +#include "dequant_funcs.comp" + +layout (push_constant) uniform parameter +{ + uint ncols; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint ne11; +#else + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.y; +#else + const uint batch_idx = gl_GlobalInvocationID.y; +#endif + +#ifndef MUL_MAT_ID + uint batch_idx_a = 0; + if (batch_idx != 0) { + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + batch_idx_a = i03 * p.ne02 + i02; + } +#else + const uint expert_id = data_ids[expert_idx]; +#endif + + a_offset = +#ifdef MUL_MAT_ID + expert_id * p.batch_stride_a; +#else + batch_idx_a * p.batch_stride_a; +#endif + b_offset = +#ifdef MUL_MAT_ID + (expert_idx % p.ne11) * p.stride_b; +#else + batch_idx * p.batch_stride_b; +#endif + d_offset = +#ifdef MUL_MAT_ID + expert_idx * p.stride_d; +#else + batch_idx * p.batch_stride_d; +#endif +} + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (constant_id = 1) const uint NUM_ROWS = 1; +layout (constant_id = 2) const uint NUM_COLS = 1; + +shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; + +void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + // sum up partial sums and write back result + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] = temp[j][n]; + } + } + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] += tmpsh[j][n][tid + s]; + } + } + } + barrier(); + } + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp new file mode 100644 index 000000000..1cc4996d3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -0,0 +1,71 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#define BLOCK_SIZE 32 +#define FLOAT_TYPE float + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; + +layout (push_constant) uniform parameter +{ + uint ncols_x; + uint nrows_x; + uint row_stride_x; + uint channel_stride_x; + uint channel_x_divisor; + uint b_offset; + uint d_offset; +} p; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint tid = 
gl_LocalInvocationID.x; + const uint row_x = gl_GlobalInvocationID.y; + const uint channel = gl_GlobalInvocationID.z; + const uint channel_x = channel / p.channel_x_divisor; + + const uint nrows_y = p.ncols_x; + const uint nrows_dst = p.nrows_x; + const uint row_dst = row_x; + + const uint idst = channel*nrows_dst + row_dst; + + tmp[tid] = 0.0f; + + for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel*nrows_y + row_y; + + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + if (tid == 0) { + dst[idst] = tmp[0]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp new file mode 100644 index 000000000..9b443807d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp @@ -0,0 +1,73 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#define BLOCK_SIZE 32 +#define FLOAT_TYPE float + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; + +layout (push_constant) uniform parameter +{ + uint ncols_x; + uint nrows_x; + uint nchannels_x; + uint nchannels_y; + uint b_offset; + uint d_offset; +} p; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint row_x = gl_GlobalInvocationID.y; + const uint channel = gl_GlobalInvocationID.z; + const uint channel_x = channel / (p.nchannels_y / p.nchannels_x); + + const uint nrows_y = p.ncols_x; + const uint nrows_dst = p.nrows_x; + const uint row_dst = row_x; + + tmp[tid] = FLOAT_TYPE(0.0f); + + for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + const uint row_y = col_x; + + // y is not transposed but permuted + const uint iy = channel*nrows_y + row_y; + + tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); + } + + // dst is not transposed and not permuted + const uint idst = channel*nrows_dst + row_dst; + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + if (tid == 0) { + dst[idst] = tmp[0]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp new file mode 100644 index 000000000..934213446 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -0,0 +1,115 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" 
+ +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; + + const uint step = 8; + + const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - step*v_im; // 0...15 or 0...7 + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint s_offset = 8*v_im; + const uint y_offset = 128*v_im + l0; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + f16vec2 d = data_a[ib0 + i].d; + const FLOAT_TYPE dall = d.x; + const FLOAT_TYPE dmin = d.y; + + uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0]; + uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; + + uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; + uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F; + uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F; + uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F; + + uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32)); + uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32)); + uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); + uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); + + uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0]; + uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; + uvec2 qs0 = uvec2(unpack8(qs0_u16)); + uvec2 qs16 = uvec2(unpack8(qs16_u16)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; + B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; + B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; + B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; + B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; + B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; + B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; + B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), + fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), + fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), + fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), + fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), + fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), + fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3), + fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); + 
sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), + fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), + fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), + fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), + fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), + fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), + fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), + fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp new file mode 100644 index 000000000..86b0159d9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -0,0 +1,103 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; + + const uint step = 8; + + const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
+ const uint v_in = itid - step*v_im; // 0...15 or 0...7 + + const uint8_t m = uint8_t(1 << (4 * v_im)); + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint s_shift = 4 * v_im; + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0]; + uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1]; + uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2]; + uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3]; + uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4]; + uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5]; + u8vec2 s0 = unpack8(s0_16); + u8vec2 s2 = unpack8(s2_16); + u8vec2 s4 = unpack8(s4_16); + u8vec2 s6 = unpack8(s6_16); + u8vec2 s8 = unpack8(s8_16); + u8vec2 s10 = unpack8(s10_16); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; + B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; + B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; + B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; + B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; + B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; + B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; + B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 
0 : 4)), + fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp new file mode 100644 index 000000000..cd1dd8e89 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -0,0 +1,133 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; + + const uint step = 4; + + const uint il = itid/step; // 0...3 + const uint ir = itid - step*il; // 0...7 or 0...3 + const uint n = 4; + + const uint v_im = il / 2; // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = n * (2 * ir + v_in); // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + f16vec2 d = data_a[ib0 + i].d; + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + uvec4 scale0 = uvec4(unpack8(scale0_u32)); + uvec4 scale4 = uvec4(unpack8(scale4_u32)); + uvec4 scale8 = uvec4(unpack8(scale8_u32)); + + const uint32_t sc0 = ( scale0.x & 0x3f); + const uint32_t sc1 = ( scale0.y & 0x3f); + const uint32_t sc2 = ( scale4.x & 0x3f); + const uint32_t sc3 = ( scale4.y & 0x3f); + const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); + const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); + const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); + const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); + + uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; + + uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4)); + uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4)); + uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4)); + uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4)); + + const uint32_t q4_0 = qs0_lo4.x; + const uint32_t q4_1 = qs0_lo4.y; + const uint32_t q4_2 = qs0_lo4.z; + const uint32_t q4_3 = qs0_lo4.w; + const uint32_t q4_4 = qs0_hi4.x; + const uint32_t q4_5 = qs0_hi4.y; + const uint32_t q4_6 = qs0_hi4.z; + const uint32_t q4_7 = qs0_hi4.w; + const uint32_t q4_8 = qs64_lo4.x; + const uint32_t q4_9 = qs64_lo4.y; + const uint32_t q4_10 = qs64_lo4.z; + const uint32_t q4_11 = qs64_lo4.w; + const uint32_t q4_12 = qs64_hi4.x; + const uint32_t q4_13 = qs64_hi4.y; + const uint32_t q4_14 = qs64_hi4.z; + const uint32_t q4_15 = qs64_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4]; + B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]; + B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4]; + B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]; + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = 
fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp new file mode 100644 index 000000000..0a68891c3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -0,0 +1,162 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; + + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...7 or 0...3 + + const uint v_im = il / 2; // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = 4*ir + 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + f16vec2 d = data_a[ib0 + i].d; + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + uvec4 scale0 = uvec4(unpack8(scale0_u32)); + uvec4 scale4 = uvec4(unpack8(scale4_u32)); + uvec4 scale8 = uvec4(unpack8(scale8_u32)); + + const uint32_t sc0 = ( scale0.x & 0x3f); + const uint32_t sc1 = ( scale0.y & 0x3f); + const uint32_t sc2 = ( scale4.x & 0x3f); + const uint32_t sc3 = ( scale4.y & 0x3f); + const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); + const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); + const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); + const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); + + uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + + uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; + uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; + uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + + uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0; + uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + + uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); + uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); + uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); + uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4)); + + const uint32_t q4_0 = qs0_16_lo4.x; + const uint32_t q4_1 = qs0_16_lo4.y; + const uint32_t q4_2 = qs0_16_lo4.z; + const uint32_t q4_3 = qs0_16_lo4.w; + const uint32_t q4_4 = qs0_16_hi4.x; + const uint32_t q4_5 = qs0_16_hi4.y; + const uint32_t q4_6 = qs0_16_hi4.z; + const uint32_t q4_7 = qs0_16_hi4.w; + const uint32_t q4_8 = qs64_80_lo4.x; + const uint32_t q4_9 = qs64_80_lo4.y; + const uint32_t q4_10 = qs64_80_lo4.z; + const uint32_t q4_11 = qs64_80_lo4.w; + const uint32_t q4_12 = qs64_80_hi4.x; + const uint32_t q4_13 = qs64_80_hi4.y; + const uint32_t q4_14 = qs64_80_hi4.z; + const uint32_t q4_15 = qs64_80_hi4.w; + + [[unroll]] for 
(uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2]; + B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]; + B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]; + B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]; + B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2]; + B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]; + B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]; + B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]; + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp new file mode 100644 index 000000000..70e13a56b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -0,0 +1,112 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; + + const uint step = 8; + + const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
+ const uint v_in = itid - step*v_im; // 0...15 or 0...7 + + const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 + const uint is = v_in / 4; + + const uint ql_offset = 64*v_im + l0; + const uint qh_offset = 32*v_im + l0; + const uint s_offset = 8*v_im + is; + const uint y_offset = 128*v_im + l0; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + FLOAT_TYPE scales[4]; + scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); + scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); + scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); + scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); + + uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; + uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + uvec4 q0 = uvec4(unpack8(q0_u32)); + uvec4 q1 = uvec4(unpack8(q1_u32)); + uvec4 q2 = uvec4(unpack8(q2_u32)); + uvec4 q3 = uvec4(unpack8(q3_u32)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4]; + B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]; + B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]; + B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]; + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 4; ++l) { + sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), + fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), + fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), + fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); + } + temp[j][n] += sum * d; + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp new file mode 100644 index 000000000..48122cbef --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -0,0 +1,631 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.comp" + +#ifndef LOAD_VEC_A +#define LOAD_VEC_A 1 +#endif +#ifndef LOAD_VEC_B +#define LOAD_VEC_B 1 +#endif + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK + 8) +#else +#define SHMEM_STRIDE (BK + 1) +#endif + +shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[3072]; +#endif // MUL_MAT_ID + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +void main() { +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = 
gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; + +#ifdef MUL_MAT_ID + uint _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + row_ids[_ne1] = u16vec2(ii0, ii1); + _ne1++; + } + } + } + + barrier(); + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; +#ifdef MUL_MAT_ID + uint pos_b = 0; +#else + uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; +#endif + +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + FLOAT_TYPE cache_a[WMITER * TM]; + FLOAT_TYPE cache_b[WNITER * TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { + +#if defined(DATA_A_F32) || defined(DATA_A_F16) +#if LOAD_VEC_A == 8 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); + buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); + buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); + buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); + buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); + buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); + buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); + buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); +#elif LOAD_VEC_A == 4 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); + buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); + buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); + buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); +#else + if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); + } else { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); + } +#endif +#elif defined(DATA_A_Q4_0) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + 
const float d = float(data_a[ib].d); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q4_1) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q5_0) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q5_1) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint uint_qh = data_a[ib].qh; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q8_0) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 16; + const uint iqs = (idx & 0xF) * 2; + + const float d = float(data_a[ib].d); + const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q2_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const uint scales = data_a[ib].scales[scalesi]; + const vec2 d = vec2(data_a[ib].d); + + const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_Q3_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 
0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) : + (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4)); + const float dl = float(data_a[ib].d) * float(us - 32); + + buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); + buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); +#elif defined(DATA_A_Q4_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); + buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +#elif defined(DATA_A_Q5_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 
0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); + buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); +#elif defined(DATA_A_Q6_K) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); + + buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); + buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ4_NL) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + + const uint ib = idx / 16; + const uint iqs = idx & 0xF; + + const float d = float(data_a[ib].d); + const uint vui = uint(data_a[ib].qs[iqs]); + const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); +#endif + } + [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { +#if LOAD_VEC_B == 8 +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +#else + const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; +#endif + const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; + buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); + buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); + buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); + buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); + buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); + buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); + buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); + buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); +#elif LOAD_VEC_B == 4 +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +#else + const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; +#endif + const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; + buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); + buf_b[buf_idx + 1] = 
FLOAT_TYPE(data_b[idx].y); + buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); + buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); +#elif !MUL_MAT_ID + if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); + } else { + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); + } +#else + const uint row_i = ic * BN + loadc_b + l; + if (row_i < _ne1) { + const u16vec2 row_idx = row_ids[row_i]; + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); + } else { + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); + } +#endif + } + + barrier(); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + +#ifdef COOPMAT + [[unroll]] for (uint i = 0; i < BK; i += TK) { + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + // Load from shared into cache + coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } +#else + [[unroll]] for (uint i = 0; i < BK; i++) { + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint j = 0; j < TM; j++) { + cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; + } + } + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint j = 0; j < TN; j++) { + cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; + } + } + + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]); + } + } + } + } + } +#endif + + barrier(); + } + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < BN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; + + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool 
is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp new file mode 100644 index 000000000..cbfa5dce1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -0,0 +1,328 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id 
= 3) const uint BK = 16; // Assumed to be 32 if working with a quant + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#if QUANT_K > 1 +#define DECODEFUNCA , dequantFuncA +#define MAT_A_TYPE float16_t + +#include "dequant_funcs_cm2.comp" + +#else +#define DECODEFUNCA +#define MAT_A_TYPE A_TYPE +#endif + +#define MAT_B_TYPE B_TYPE + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; + +shared u16vec4 row_ids[3072]; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { + B_TYPE b[]; +}; + +uint _ne1; +shared uint _ne1_sh; + +B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint row_i = blockCoords[0]; + + if (row_i >= _ne1) { + return B_TYPE(0.0); + } + + const u16vec4 row_idx = row_ids[row_i]; + B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; + + return ret; +} + +D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) +{ + uint dr = ir * BM + r; + uint dc = ic * BN + c; + + if (dr < p.M && dc < _ne1) { + uint row_i = dc; + const u16vec4 row_idx = row_ids[row_i]; + data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; + } + return elem; +} + +#endif + +void main() { +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + +#ifdef MUL_MAT_ID + // Spread the search across all elements in the first subgroup + if (gl_SubgroupID == 0) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + + for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { + bool in_range = i < num_elements; + uint ii0 = i % p.nei0; + uint ii1 = i / p.nei0; + uint id = in_range ? 
data_ids[ii1*p.nbi1 + ii0] : 0; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + uint idx = subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx) { + row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); + } + _ne1 += subgroupBallotBitCount(ballot); + } + _ne1_sh = _ne1; + } + + barrier(); + + _ne1 = _ne1_sh; + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + uint start_k = 0; + const uint end_k = p.K; +#else + uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + coopmat sum; + sum = coopmat(0.0); + +#ifdef MUL_MAT_ID + uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; + uint pos_b = 0; +#else + uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; + uint pos_b = batch_idx * p.batch_stride_b; +#endif + + uint stride_a = p.stride_a / QUANT_K; + uint stride_b = p.stride_b; + + // Hint to the compiler that values are aligned (want 16B alignment). + // Quants are always block-aligned, no alignment needed. +#if ALIGNED +#if QUANT_K == 1 + stride_a &= ~7; +#endif + stride_b &= ~7; +#endif + + // Create layouts for both clamped and unclamped accesses + tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + +#if QUANT_K > 1 + tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); + tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); +#endif + + // Use end_k rather than p.K as the dimension because that's what + // we need to bound check against when using split_k + tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); + tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); + tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); + tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if !defined(MUL_MAT_ID) + // Detect a fast path where all loads are entirely in bounds and no clamping is required + if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 && +#if QUANT_K == 1 + (stride_a % 8) == 0 && +#endif + (stride_b % 8) == 0 && (start_k % 8) == 0) { + // Hint to the compiler that values are aligned (want 16B alignment) + start_k &= ~7; + stride_b &= ~7; +#if QUANT_K == 1 + stride_a &= ~7; +#endif + + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + uint k_iters = (end_k - start_k + BK - 1) / BK; + + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopmat mat_a_ft = coopmat(mat_a); + + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * 
BN, BN, block_k, BK), tensorViewTranspose); + coopmat mat_b_ft = coopmat(mat_b); + + sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + } + } else +#endif // !defined(MUL_MAT_ID) + { + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + + tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1); + + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); + + [[dont_unroll]] + for (uint block_k = start_k; block_k < end_k; block_k += BK) { + + coopmat mat_a; + coopmat mat_b; + coopmat mat_a_ft; + coopmat mat_b_ft; + + // Clamping is expensive, so detect different code paths for each combination + // of A and B needing clamping. + bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0; +#ifdef MUL_MAT_ID + bool unclampedB = true; +#else + bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0; +#endif + if (unclampedA && unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); +#endif + mat_a_ft = coopmat(mat_a); + mat_b_ft = coopmat(mat_b); + sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + } else if (unclampedA && !unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); + + mat_a_ft = coopmat(mat_a); + mat_b_ft = coopmat(mat_b); + sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + } else if (!unclampedA && unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); +#endif + mat_a_ft = coopmat(mat_a); + mat_b_ft = coopmat(mat_b); + sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + } else if (!unclampedA && !unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); + + mat_a_ft = coopmat(mat_a); + mat_b_ft = coopmat(mat_b); + sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + } + } + } + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + +#ifdef MUL_MAT_ID + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); +#else + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); + + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); +#endif +} diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp new file mode 100644 index 000000000..6627a50bd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared vec2 sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = vec2(0.0f, 0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const float xi = float(data_a[row*p.KX + col]); + sum[tid].x += xi; + sum[tid].y += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const float mean = sum[0].x / p.KX; + const float var = sum[0].y / p.KX - mean * mean; + const float inv_std = inversesqrt(var + p.param1); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp new file mode 100644 index 000000000..450b67fc5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -0,0 +1,28 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne12*p.ne11*p.ne10); + const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; + const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10); + const uint i2_offset = i2*p.ne11*p.ne10; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; + + const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; + + const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? 
data_a[get_aoffset() + src0_idx] : 0.0f); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp new file mode 100644 index 000000000..b6124411a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp @@ -0,0 +1,74 @@ +#version 450 + +#include "types.comp" + +#extension GL_EXT_shader_16bit_storage : require + +layout(push_constant) uniform parameter { + uint IW; uint IH; + uint OW; uint OH; + uint OC; + uint pelements; + uint op; + int k0; int k1; + int s0; int s1; + int p0; int p1; +} p; + +#define BLOCK_SIZE 512 +#define FLT_MAX 3.402823466e+38F +#define OP_POOL_MAX 0u +#define OP_POOL_AVG 1u + +layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout(binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.pelements) { + return; + } + + const uint O_HW = p.OW * p.OH; + + const uint nc = idx / O_HW; + const uint cur_oh = (idx % O_HW) / p.OW; + const uint cur_ow = (idx % O_HW) % p.OW; + + const int start_h = int(cur_oh) * p.s0 - p.p0; + const uint bh = max(start_h, 0); + const uint eh = min(start_h + p.k0, p.IH); + + const int start_w = int(cur_ow) * p.s1 - p.p1; + const uint bw = max(start_w, 0); + const uint ew = min(start_w + p.k1, p.IW); + + const float scale = 1.0 / float(p.k0 * p.k1); + float res; + + if (p.op == OP_POOL_AVG) { + res = 0.0; + } else if (p.op == OP_POOL_MAX) { + res = -FLT_MAX; + } else { + return; + } + + #pragma unroll + for (uint i = bh; i < eh; i++) { + #pragma unroll + for (uint j = bw; j < ew; j++) { + const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]); + + if (p.op == OP_POOL_AVG) { + res += cur * scale; + } else if (p.op == OP_POOL_MAX) { + res = max(res, cur); + } + } + } + + data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp new file mode 100644 index 000000000..52a19b62a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + data_d[i] = max(float(data_a[i]), 0); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp new file mode 100644 index 000000000..1568b141d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +uint src0_idx_mod(uint idx) { + const uint i13 = idx / (p.ne12*p.ne11*p.ne10); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; + const uint i10 = idx - 
i13_offset - i12_offset - i11*p.ne10; + return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00; +} + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp new file mode 100644 index 000000000..b554400ba --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp new file mode 100644 index 000000000..574b51ca5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -0,0 +1,49 @@ +#include "types.comp" + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_spirv_intrinsics: enable + +#if RTE16 +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif + +layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {int data_pos[];}; +layout (binding = 2) readonly buffer Z {float data_ff[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint n_dims; + float freq_scale; + uint p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint has_ff; +} p; + +float rope_yarn_ramp(const float low, const float high, const uint i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { + float mscale = p.attn_factor; + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = p.freq_scale * theta_extrap; + float theta = theta_interp; + if (p.ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap 
* ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); + } + cos_theta = cos(theta) * mscale; + sin_theta = sin(theta) * mscale; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp new file mode 100644 index 000000000..83b46b69b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint col = gl_GlobalInvocationID.y * 2; + const uint row = gl_GlobalInvocationID.x; + + if (col >= p.ncols) { + return; + } + + if (col >= p.n_dims) { + const uint i = row*p.ncols + col; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + + const uint i = row*p.ncols + col/2; + const uint i2 = row/p.p_delta_rows; + + const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); + + const float x0 = float(data_a[i + 0]); + const float x1 = float(data_a[i + p.n_dims/2]); + + data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp new file mode 100644 index 000000000..e416ad938 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint col = gl_GlobalInvocationID.y * 2; + const uint row = gl_GlobalInvocationID.x; + + if (col >= p.ncols) { + return; + } + + if (col >= p.n_dims) { + const uint i = row*p.ncols + col; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + + const uint i = row*p.ncols + col; + const uint i2 = row/p.p_delta_rows; + + const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + + const float freq_factor = p.has_ff != 0 ? 
data_ff[col/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); + + const float x0 = float(data_a[i + 0]); + const float x1 = float(data_a[i + 1]); + + data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp new file mode 100644 index 000000000..4663428de --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -0,0 +1,24 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp new file mode 100644 index 000000000..4d36f88e0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float xi = float(data_a[i]); + data_d[i] = D_TYPE(xi / (1.0f + exp(-xi))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp new file mode 100644 index 000000000..d7c15a169 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp new file mode 100644 index 000000000..a25808e16 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -0,0 +1,174 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + uint nrows_x; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE 
data_b[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate +// over all the columns. The main function tries to pass a constant here, +// as if it were a template function, to allow unrolling. +void soft_max(uint num_iters) { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; + + if (rowx >= p.nrows_x) { + return; + } + + float slope = 1.0f; + + // ALiBi + if (p.max_bias > 0.0f) { + const uint h = rowx/p.KY; // head index + + const float base = h < p.n_head_log2 ? p.m0 : p.m1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + + slope = pow(base, exp); + } + + // Find max + FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); + + // Cache values while we compute the max, so we don't need to read them + // again when we're ready to compute exp(x-max). + const uint DATA_CACHE_SIZE = 16; + FLOAT_TYPE data_cache[DATA_CACHE_SIZE]; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + FLOAT_TYPE a = FLOAT_TYPE(0); + if (col < p.KX) { + a = data_a[rowx * p.KX + col]; + } + + FLOAT_TYPE b = FLOAT_TYPE(0); + if (p.KY > 0 && col < p.KX) { + b = data_b[rowy * p.KX + col]; + } + + FLOAT_TYPE v = a * p.scale + slope * b; + + if (col < p.KX) { + max_val = max(max_val, v); + } + + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = v; + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + max_val = vals[0]; + barrier(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + // compute exp(a*scale+b*slope), add it to sum, and cache the new value + // in data_cache if possible. + const uint i = rowx * p.KX + col; + FLOAT_TYPE val; + if (idx < DATA_CACHE_SIZE) { + val = exp(data_cache[idx] - max_val); + } else { + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? 
slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + } + sum += val; + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = val; + } else { + data_d[i] = D_TYPE(val); + } + } + + // reduce across the workgroup + vals[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + sum = vals[0]; + + FLOAT_TYPE rcpdivisor = 1.0/sum; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + continue; + } + + if (idx < DATA_CACHE_SIZE) { + data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor); + } else { + data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); + } + } +} + +void main() { + // instantiate the soft_max function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + soft_max(num_blocks); + } else if (num_blocks > 16) { + soft_max(32); + } else if (num_blocks > 8) { + soft_max(16); + } else if (num_blocks > 4) { + soft_max(8); + } else if (num_blocks == 4) { + soft_max(4); + } else if (num_blocks == 3) { + soft_max(3); + } else if (num_blocks == 2) { + soft_max(2); + } else if (num_blocks == 1) { + soft_max(1); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp new file mode 100644 index 000000000..ef43598ba --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp new file mode 100644 index 000000000..961e5ffa1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + tmp[col] = FLOAT_TYPE(0.0f); + + for (uint i = col; i < p.KX; i += BLOCK_SIZE) { + tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]); + } + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s) { + tmp[col] += tmp[col + s]; + } + barrier(); + } + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp new file mode 100644 index 000000000..495f966bd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.comp" +#include 
"types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp new file mode 100644 index 000000000..28eb24e11 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_NV_cooperative_matrix2 : require + +void main() +{ +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp new file mode 100644 index 000000000..79e065a93 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -0,0 +1,41 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint nb1; + uint dim; + uint max_period; +} p; + +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 256 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_WorkGroupID.y; + const uint j = gl_GlobalInvocationID.x; + const uint d_offset = i * p.nb1; + + if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) { + data_d[d_offset + p.dim] = 0.f; + } + + const uint half_dim = p.dim / 2; + if (j >= half_dim) { + return; + } + + const float timestep = float(data_a[i]); + const float freq = float(exp(-log(p.max_period) * j / half_dim)); + const float arg = timestep * freq; + data_d[d_offset + j] = D_TYPE(cos(arg)); + data_d[d_offset + j + half_dim] = D_TYPE(sin(arg)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp new file mode 100644 index 000000000..eecc47f3a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -0,0 +1,323 @@ + +#if !defined(GGML_TYPES_COMP) +#define GGML_TYPES_COMP + +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#if defined(DATA_A_F32) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE float +#elif LOAD_VEC_A == 4 +#define A_TYPE vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE mat2x4 +#endif +#endif + +#if defined(DATA_A_F16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE float16_t +#elif LOAD_VEC_A == 4 +#define A_TYPE f16vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE f16mat2x4 +#endif +#endif + +#define QUANT_K_Q4_0 32 +#define QUANT_R_Q4_0 2 + +struct block_q4_0 +{ + float16_t d; + uint8_t qs[16]; +}; +struct block_q4_0_packed16 +{ + float16_t d; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q4_0) +#define QUANT_K QUANT_K_Q4_0 +#define QUANT_R QUANT_R_Q4_0 +#define A_TYPE block_q4_0 +#define A_TYPE_PACKED16 block_q4_0_packed16 +#endif + +#define QUANT_K_Q4_1 32 +#define QUANT_R_Q4_1 2 + +struct 
block_q4_1 +{ + float16_t d; + float16_t m; + uint8_t qs[16]; +}; + +struct block_q4_1_packed16 +{ + float16_t d; + float16_t m; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q4_1) +#define QUANT_K QUANT_K_Q4_1 +#define QUANT_R QUANT_R_Q4_1 +#define A_TYPE block_q4_1 +#define A_TYPE_PACKED16 block_q4_1_packed16 +#endif + +#define QUANT_K_Q5_0 32 +#define QUANT_R_Q5_0 2 + +struct block_q5_0 +{ + float16_t d; + uint16_t qh[2]; + uint8_t qs[16]; +}; + +struct block_q5_0_packed16 +{ + float16_t d; + uint16_t qh[2]; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q5_0) +#define QUANT_K QUANT_K_Q5_0 +#define QUANT_R QUANT_R_Q5_0 +#define A_TYPE block_q5_0 +#define A_TYPE_PACKED16 block_q5_0_packed16 +#endif + +#define QUANT_K_Q5_1 32 +#define QUANT_R_Q5_1 2 + +struct block_q5_1 +{ + float16_t d; + float16_t m; + uint qh; + uint8_t qs[16]; +}; + +struct block_q5_1_packed16 +{ + float16_t d; + float16_t m; + uint qh; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q5_1) +#define QUANT_K QUANT_K_Q5_1 +#define QUANT_R QUANT_R_Q5_1 +#define A_TYPE block_q5_1 +#define A_TYPE_PACKED16 block_q5_1_packed16 +#endif + +#define QUANT_K_Q8_0 32 +#define QUANT_R_Q8_0 1 + +struct block_q8_0 +{ + float16_t d; + int8_t qs[32]; +}; +struct block_q8_0_packed16 +{ + float16_t d; + uint16_t qs[32/2]; +}; + +#if defined(DATA_A_Q8_0) +#define QUANT_K QUANT_K_Q8_0 +#define QUANT_R QUANT_R_Q8_0 +#define A_TYPE block_q8_0 +#define A_TYPE_PACKED16 block_q8_0_packed16 +#endif + +// K-quants +#define QUANT_K_Q2_K 256 + +struct block_q2_K +{ + uint8_t scales[QUANT_K_Q2_K/16]; + uint8_t qs[QUANT_K_Q2_K/4]; + f16vec2 d; +}; + +struct block_q2_K_packed16 +{ + uint16_t scales[QUANT_K_Q2_K/16/2]; + uint16_t qs[QUANT_K_Q2_K/4/2]; + f16vec2 d; +}; + +struct block_q2_K_packed32 +{ + uint32_t scales[QUANT_K_Q2_K/16/4]; + uint32_t qs[QUANT_K_Q2_K/4/4]; + f16vec2 d; +}; + +#if defined(DATA_A_Q2_K) +#define QUANT_K QUANT_K_Q2_K +#define A_TYPE block_q2_K +#define A_TYPE_PACKED16 block_q2_K_packed16 +#define A_TYPE_PACKED32 block_q2_K_packed32 +#endif + +#define QUANT_K_Q3_K 256 + +struct block_q3_K +{ + uint8_t hmask[QUANT_K_Q3_K/8]; + uint8_t qs[QUANT_K_Q3_K/4]; + uint8_t scales[12]; + float16_t d; +}; + +struct block_q3_K_packed16 +{ + uint16_t hmask[QUANT_K_Q3_K/8/2]; + uint16_t qs[QUANT_K_Q3_K/4/2]; + uint16_t scales[12/2]; + float16_t d; +}; + +#if defined(DATA_A_Q3_K) +#define QUANT_K QUANT_K_Q3_K +#define A_TYPE block_q3_K +#define A_TYPE_PACKED16 block_q3_K_packed16 +#endif + +#define QUANT_K_Q4_K 256 + +struct block_q4_K +{ + f16vec2 d; + uint8_t scales[3*QUANT_K_Q4_K/64]; + uint8_t qs[QUANT_K_Q4_K/2]; +}; + +struct block_q4_K_packed16 +{ + f16vec2 d; + uint16_t scales[3*QUANT_K_Q4_K/64/2]; + uint16_t qs[QUANT_K_Q4_K/2/2]; +}; + +struct block_q4_K_packed32 +{ + f16vec2 d; + uint32_t scales[3*QUANT_K_Q4_K/64/4]; + uint32_t qs[QUANT_K_Q4_K/2/4]; +}; + +#if defined(DATA_A_Q4_K) +#define QUANT_K QUANT_K_Q4_K +#define A_TYPE block_q4_K +#define A_TYPE_PACKED16 block_q4_K_packed16 +#define A_TYPE_PACKED32 block_q4_K_packed32 +#endif + +#define QUANT_K_Q5_K 256 + +struct block_q5_K +{ + f16vec2 d; + uint8_t scales[12]; + uint8_t qh[QUANT_K_Q5_K/8]; + uint8_t qs[QUANT_K_Q5_K/2]; +}; + +struct block_q5_K_packed16 +{ + f16vec2 d; + uint16_t scales[12/2]; + uint16_t qh[QUANT_K_Q5_K/8/2]; + uint16_t qs[QUANT_K_Q5_K/2/2]; +}; + +#if defined(DATA_A_Q5_K) +#define QUANT_K QUANT_K_Q5_K +#define A_TYPE block_q5_K +#define A_TYPE_PACKED16 block_q5_K_packed16 +#endif + +#define QUANT_K_Q6_K 256 + +struct block_q6_K +{ + uint8_t ql[QUANT_K_Q6_K/2]; 
+ uint8_t qh[QUANT_K_Q6_K/4]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +struct block_q6_K_packed16 +{ + uint16_t ql[QUANT_K_Q6_K/2/2]; + uint16_t qh[QUANT_K_Q6_K/4/2]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +#if defined(DATA_A_Q6_K) +#define QUANT_K QUANT_K_Q6_K +#define A_TYPE block_q6_K +#define A_TYPE_PACKED16 block_q6_K_packed16 +#endif + +// IQuants + +#define QUANT_K_IQ4_NL 32 +#define QUANT_R_IQ4_NL 2 + +struct block_iq4_nl +{ + float16_t d; + uint8_t qs[QUANT_K_IQ4_NL/2]; +}; + +struct block_iq4_nl_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ4_NL/2/2]; +}; + +#if defined(DATA_A_IQ4_NL) + +const int8_t kvalues_iq4nl_const[16] = { + int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), + int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) +}; + +shared FLOAT_TYPE kvalues_iq4nl[16]; + +void init_iq4nl_shmem() +{ + // copy the table into shared memory and sync + if (gl_LocalInvocationIndex.x < 16) { + kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]); + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ4_NL +#define QUANT_R QUANT_R_IQ4_NL +#define A_TYPE block_iq4_nl +#define A_TYPE_PACKED16 block_iq4_nl_packed16 +#endif + +#endif // !defined(GGML_TYPES_COMP) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp new file mode 100644 index 000000000..6f607380d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -0,0 +1,36 @@ +#version 450 + +layout (push_constant) uniform parameter +{ + uint ne; uint a_offset; uint d_offset; + uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; + float sf0; float sf1; float sf2; float sf3; +} p; + +#include "types.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i10 = idx % p.ne10; + const uint i11 = (idx / p.ne10) % p.ne11; + const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; + const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; + + const uint i00 = uint(i10 / p.sf0); + const uint i01 = uint(i11 / p.sf1); + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + + data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp new file mode 100644 index 000000000..8111c0638 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -0,0 +1,594 @@ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 + #include + #include // For _mkdir on Windows + #include // For std::replace on w64devkit +#else + #include + #include + #include +#endif + +#include + +#define ASYNCIO_CONCURRENCY 64 + +std::mutex lock; +std::vector> shader_fnames; + +std::string GLSLC = 
"glslc"; +std::string input_dir = "vulkan-shaders"; +std::string output_dir = "/tmp"; +std::string target_hpp = "ggml-vulkan-shaders.hpp"; +std::string target_cpp = "ggml-vulkan-shaders.cpp"; +bool no_clean = false; + +const std::vector type_names = { + "f32", + "f16", + "q4_0", + "q4_1", + "q5_0", + "q5_1", + "q8_0", + "q2_k", + "q3_k", + "q4_k", + "q5_k", + "q6_k", + "iq4_nl" +}; + +namespace { +void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { +#ifdef _WIN32 + HANDLE stdout_read, stdout_write; + HANDLE stderr_read, stderr_write; + SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; + + if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) || + !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stdout pipe"); + } + + if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) || + !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stderr pipe"); + } + + PROCESS_INFORMATION pi; + STARTUPINFOA si = {}; + si.cb = sizeof(STARTUPINFOA); + si.dwFlags = STARTF_USESTDHANDLES; + si.hStdOutput = stdout_write; + si.hStdError = stderr_write; + + std::vector cmd(command.begin(), command.end()); + cmd.push_back('\0'); + + if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) { + throw std::runtime_error("Failed to create process"); + } + + CloseHandle(stdout_write); + CloseHandle(stderr_write); + + std::array buffer; + DWORD bytes_read; + + while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + CloseHandle(stdout_read); + CloseHandle(stderr_read); + WaitForSingleObject(pi.hProcess, INFINITE); + CloseHandle(pi.hProcess); + CloseHandle(pi.hThread); +#else +int stdout_pipe[2]; + int stderr_pipe[2]; + + if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { + throw std::runtime_error("Failed to create pipes"); + } + + pid_t pid = fork(); + if (pid < 0) { + throw std::runtime_error("Failed to fork process"); + } + + if (pid == 0) { + close(stdout_pipe[0]); + close(stderr_pipe[0]); + dup2(stdout_pipe[1], STDOUT_FILENO); + dup2(stderr_pipe[1], STDERR_FILENO); + close(stdout_pipe[1]); + close(stderr_pipe[1]); + execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr); + _exit(EXIT_FAILURE); + } else { + close(stdout_pipe[1]); + close(stderr_pipe[1]); + + std::array buffer; + ssize_t bytes_read; + + while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + close(stdout_pipe[0]); + close(stderr_pipe[0]); + waitpid(pid, nullptr, 0); + } +#endif +} + +bool directory_exists(const std::string& path) { + struct stat info; + if (stat(path.c_str(), &info) != 0) { + return false; // Path doesn't exist or can't be accessed + } + return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory +} + +bool create_directory(const std::string& path) { +#ifdef _WIN32 + return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists +#else + return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the 
directory permissions
+#endif
+}
+
+std::string to_uppercase(const std::string& input) {
+    std::string result = input;
+    for (char& c : result) {
+        c = std::toupper(c);
+    }
+    return result;
+}
+
+bool string_ends_with(const std::string& str, const std::string& suffix) {
+    if (suffix.size() > str.size()) {
+        return false;
+    }
+    return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
+}
+
+static const char path_separator = '/';
+
+std::string join_paths(const std::string& path1, const std::string& path2) {
+    return path1 + path_separator + path2;
+}
+
+std::string basename(const std::string &path) {
+    return path.substr(path.find_last_of("/\\") + 1);
+}
+
+// variables to track number of compiles in progress
+static uint32_t compile_count = 0;
+static std::mutex compile_count_mutex;
+static std::condition_variable compile_count_cond;
+
+void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
+    std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
+    std::string out_fname = join_paths(output_dir, name + ".spv");
+    std::string in_path = join_paths(input_dir, in_fname);
+
+    std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
+
+    // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
+    std::string opt_level = coopmat ? "" : "-O";
+
+    #ifdef _WIN32
+        std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
+    #else
+        std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname};
+    #endif
+
+    #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
+        cmd.push_back("-g");
+    #endif
+
+    for (const auto& define : defines) {
+        cmd.push_back("-D" + define.first + "=" + define.second);
+    }
+
+    std::string command;
+    for (const auto& part : cmd) {
+        command += part + " ";
+    }
+
+    std::string stdout_str, stderr_str;
+    try {
+        // std::cout << "Executing command: ";
+        // for (const auto& part : cmd) {
+        //     std::cout << part << " ";
+        // }
+        // std::cout << std::endl;
+
+        execute_command(command, stdout_str, stderr_str);
+        if (!stderr_str.empty()) {
+            std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
+            return;
+        }
+
+        std::lock_guard<std::mutex> guard(lock);
+        shader_fnames.push_back(std::make_pair(name, out_fname));
+    } catch (const std::exception& e) {
+        std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
+    }
+    {
+        std::lock_guard<std::mutex> guard(compile_count_mutex);
+        assert(compile_count > 0);
+        compile_count--;
+    }
+    compile_count_cond.notify_all();
+}
+
+std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
+    std::map<std::string, std::string> result = a;
+    result.insert(b.begin(), b.end());
+    return result;
+}
+
+static std::vector<std::future<void>> compiles;
+void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
+    {
+        // wait until fewer than N compiles are in progress.
+        // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
+        uint32_t N = 16;
+        std::unique_lock<std::mutex> guard(compile_count_mutex);
+        while (compile_count >= N) {
+            compile_count_cond.wait(guard);
+        }
+        compile_count++;
+    }
+    compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
+}
+
+void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
+    std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
+    std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
+    std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
+
+    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
+    std::string shader_name = "matmul";
+
+    if (matmul_id) {
+        base_dict["MUL_MAT_ID"] = "1";
+        shader_name = "matmul_id";
+    }
+
+    if (fp16) {
+        base_dict["FLOAT16"] = "1";
+    }
+
+    base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+
+    if (coopmat) {
+        base_dict["COOPMAT"] = "1";
+    }
+
+    base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+
+    std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
+
+    // Shaders with f16 B_TYPE
+    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+
+    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+
+    for (const auto& tname : type_names) {
+        std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+        // For unaligned, load one at a time for f32/f16, or two at a time for quants
+        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
+        // For aligned matmul loads
+        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ?
load_vec : "2"; + + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + + if (tname != "f16" && tname != "f32") { + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + } +} + +void process_shaders() { + std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; + std::map base_dict = {{"FLOAT_TYPE", "float"}}; + + // matmul + for (const auto& matmul_id : {false, true}) { + // No coopmats + // fp32 + matmul_shaders(false, matmul_id, false, false, false); + + // fp16, fp32acc and fp16acc + matmul_shaders(true, matmul_id, false, false, false); + matmul_shaders(true, matmul_id, false, false, true); + + // Coopmat, fp32acc and fp16acc + matmul_shaders(true, matmul_id, true, false, false); + matmul_shaders(true, matmul_id, true, false, true); + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + // Coopmat2, fp32acc and fp16acc + matmul_shaders(true, matmul_id, false, true, false); + matmul_shaders(true, matmul_id, false, true, true); +#endif + } + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + // flash attention + for (const auto& f16acc : {false, true}) { + std::string acctype = f16acc ? "float16_t" : "float"; + + for (const auto& tname : type_names) { + if (tname == "f32") { + continue; + } + + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); + } else { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + } + } + } +#endif + + for (const auto& tname : type_names) { + // mul mat vec + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + std::string shader = (string_ends_with(tname, "_k")) ? 
"mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + + string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + + // Dequant shaders + if (tname != "f16") { + string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); + } + + if (!string_ends_with(tname, "_k")) { + shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp"; + + if (tname == "f16") { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); + } else { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); + } + string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); + } + } + + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + // Norms + string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + + string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); + + string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", 
"float"}}); + + string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); + + string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + + string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); + string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + + string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + 
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + for (auto &c : compiles) { + c.wait(); + } +} + +void write_output_files() { + FILE* hdr = fopen(target_hpp.c_str(), "w"); + FILE* src = fopen(target_cpp.c_str(), "w"); + + fprintf(hdr, "#include \n\n"); + fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + + for (const auto& pair : shader_fnames) { + const std::string& name = pair.first; + #ifdef _WIN32 + std::string path = pair.second; + std::replace(path.begin(), path.end(), '/', '\\' ); + #else + const std::string& path = pair.second; + #endif + + FILE* spv = fopen(path.c_str(), "rb"); + if (!spv) { + std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; + continue; + } + + fseek(spv, 0, SEEK_END); + size_t size = ftell(spv); + fseek(spv, 0, SEEK_SET); + + std::vector data(size); + size_t read_size = fread(data.data(), 1, size, spv); + fclose(spv); + if (read_size != size) { + std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; + continue; + } + + fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size); + fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size); + + fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size); + for (size_t i = 0; i < size; ++i) { + fprintf(src, "0x%02x,", data[i]); + if ((i + 1) % 12 == 0) fprintf(src, "\n"); + } + fprintf(src, "\n};\n\n"); + + if (!no_clean) { + std::remove(path.c_str()); + } + } + + fclose(hdr); + fclose(src); +} +} + +int main(int argc, char** argv) { + std::map args; + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg.rfind("--", 0) == 0) { + if (i + 1 < argc && argv[i + 1][0] != '-') { + args[arg] = argv[i + 1]; + ++i; + } else { + args[arg] = ""; + } + } + } + + if (args.find("--glslc") != args.end()) { + GLSLC = args["--glslc"]; // Path to glslc + } + if (args.find("--input-dir") != args.end()) { + input_dir = args["--input-dir"]; // Directory containing shader sources + } + if (args.find("--output-dir") != args.end()) { + output_dir = args["--output-dir"]; // Directory for containing SPIR-V output + } + if (args.find("--target-hpp") != args.end()) { + target_hpp = args["--target-hpp"]; // Path to generated header file + } + if (args.find("--target-cpp") != args.end()) { + target_cpp = args["--target-cpp"]; // Path to generated cpp file + } + if (args.find("--no-clean") != args.end()) { + no_clean = true; // Keep temporary SPIR-V files in output-dir after build + } + + if (!directory_exists(input_dir)) { + std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl; + return EXIT_FAILURE; + } + + if (!directory_exists(output_dir)) { + if (!create_directory(output_dir)) { + std::cerr << "Error creating output directory: " << output_dir << "\n"; + return EXIT_FAILURE; + } + } + + process_shaders(); + + write_output_files(); + + return EXIT_SUCCESS; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp new file mode 100644 index 000000000..35cc6c45f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp @@ -0,0 +1,87 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, 
local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; }; +layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; }; +layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 6) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + barrier(); + _tf[tid] = tf[head_id * head_size + tid]; + barrier(); + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + barrier(); + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + + vec4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + s_vec = s_vec * td_vec + kv; + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} From b14dd68feeaebafc71479f8df2ef033ccb0bac3d Mon Sep 17 00:00:00 2001 From: Vadim Grinco Date: Mon, 10 Mar 2025 12:15:43 +0100 Subject: [PATCH 026/172] Fixed the "detached head" issues Signed-off-by: Vadim Grinco --- Dockerfile | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 9e2928108..a31706243 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,13 +37,12 @@ COPY patches/ /tmp/patches/ RUN \ git clone https://github.com/pufferffish/ollama-vulkan.git "/tmp/ollama-vulkan-git" && \ cd "/tmp/ollama-vulkan-git" && \ - git checkout 2d443b3dd660a1fd2760d64538512df93648b4bb && git checkout -b ollama_vulkan_stable && \ + git checkout 2d443b3dd660a1fd2760d64538512df93648b4bb -b ollama_vulkan_stable && \ git config user.name "Builder" && git config user.email "builder@local" && \ git remote add ollama_vanilla https://github.com/ollama/ollama.git && \ - git fetch ollama_vanilla --tags && git checkout v0.5.14-rc0 && git checkout -b ollama_vanilla_stable && \ + git fetch ollama_vanilla --tags && git checkout v0.5.13 -b ollama_vanilla_stable && \ git checkout ollama_vulkan_stable && git merge ollama_vanilla_stable --allow-unrelated-histories --no-edit && \ - for p in 
/tmp/patches/00-fix-vulkan-building.patch; do patch -p1 < $p; done - + for p in /tmp/patches/*.patch; do patch -p1 < $p; done RUN \ cd "/tmp/ollama-vulkan-git" && \ make -f Makefile.sync clean sync From 31606b2febe9b5680cf4fb8e20f26a78dc39ee6e Mon Sep 17 00:00:00 2001 From: Vadim Grinco Date: Mon, 10 Mar 2025 12:49:47 +0100 Subject: [PATCH 027/172] Merged in the right direction Signed-off-by: Vadim Grinco --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index a31706243..cabb984fc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -41,7 +41,7 @@ RUN \ git config user.name "Builder" && git config user.email "builder@local" && \ git remote add ollama_vanilla https://github.com/ollama/ollama.git && \ git fetch ollama_vanilla --tags && git checkout v0.5.13 -b ollama_vanilla_stable && \ - git checkout ollama_vulkan_stable && git merge ollama_vanilla_stable --allow-unrelated-histories --no-edit && \ + git checkout ollama_vanilla_stable && git merge ollama_vulkan_stable --allow-unrelated-histories --no-edit && \ for p in /tmp/patches/*.patch; do patch -p1 < $p; done RUN \ cd "/tmp/ollama-vulkan-git" && \ From 6b1f84e171ea5fbf88f0dfca8d00856d8e4e4555 Mon Sep 17 00:00:00 2001 From: Vadim Grinco Date: Tue, 11 Mar 2025 14:09:47 +0100 Subject: [PATCH 028/172] Merging the latest stable (#2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Applied 00-fix-vulkan-building.patch * Implemented vulkan backend based on the work done by whyvl, Dts0, McBane87 and others Tested on AMD Ryzen 7 8845HS w/ Radeon 780M Graphics with ROCm disabled ``` [GIN-debug] POST /v1/chat/completions --> github.com/ollama/ollama/server.(*Server).ChatHandler-fm (6 handlers) [GIN-debug] POST /v1/completions --> github.com/ollama/ollama/server.(*Server).GenerateHandler-fm (6 handlers) [GIN-debug] POST /v1/embeddings --> github.com/ollama/ollama/server.(*Server).EmbedHandler-fm (6 handlers) [GIN-debug] GET /v1/models --> github.com/ollama/ollama/server.(*Server).ListHandler-fm (6 handlers) [GIN-debug] GET /v1/models/:model --> github.com/ollama/ollama/server.(*Server).ShowHandler-fm (6 handlers) time=2025-03-11T13:00:40.793Z level=INFO source=gpu.go:199 msg="vulkan: load libvulkan and libcap ok" time=2025-03-11T13:00:40.877Z level=INFO source=gpu.go:421 msg="error looking up vulkan GPU memory" error="device is a CPU" time=2025-03-11T13:00:40.878Z level=WARN source=amd_linux.go:443 msg="amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install" time=2025-03-11T13:00:40.878Z level=WARN source=amd_linux.go:348 msg="unable to verify rocm library: no suitable rocm found, falling back to CPU" time=2025-03-11T13:00:40.879Z level=INFO source=types.go:137 msg="inference compute" id=0 library=vulkan variant="" compute=1.3 driver=1.3 name="AMD Radeon Graphics (RADV GFX1103_R1)" total="15.6 GiB" available="15.6 GiB" ``` ``` # ollama run phi4:14b >>> /set verbose Set 'verbose' mode. >>> how's it going? Hello! I'm here to help you with any questions or tasks you have. How can I assist you today? 
😊 total duration: 3.341959745s load duration: 18.165612ms prompt eval count: 15 token(s) prompt eval duration: 475ms prompt eval rate: 31.58 tokens/s eval count: 26 token(s) eval duration: 2.846s eval rate: 9.14 tokens/s >>> ``` --- Dockerfile | 217 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 135 insertions(+), 82 deletions(-) diff --git a/Dockerfile b/Dockerfile index cabb984fc..870adfd72 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,98 +1,151 @@ -FROM --platform=linux/amd64 library/ubuntu:noble as builder +# vim: filetype=dockerfile -ENV DEBIAN_FRONTEND="noninteractive" +ARG FLAVOR=${TARGETARCH} -ENV VULKAN_VER_BASE="1.3.296" -ENV VULKAN_VER="${VULKAN_VER_BASE}.0" -ENV UBUNTU_VERSION="noble" +ARG ROCMVERSION=6.3.3 +ARG JETPACK5VERSION=r35.4.1 +ARG JETPACK6VERSION=r36.4.0 +ARG CMAKEVERSION=3.31.2 +ARG VULKANVERSION=1.4.304.1 -ENV GOLANG_VERSION="1.22.8" -ENV GOARCH="amd64" -ENV CGO_ENABLED=1 +# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version +FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 +RUN yum install -y yum-utils \ + && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \ + && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \ + && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \ + && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo +ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH +ARG VULKANVERSION +RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ + && tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ + && dnf -y install ninja-build libcap-devel \ + && ln -s /usr/bin/python3 /usr/bin/python \ + && /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \ + && /${VULKANVERSION}/vulkansdk -j 8 shaderc +RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \ + && cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib +ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH + +FROM --platform=linux/arm64 almalinux:8 AS base-arm64 +# install epel-release for ccache +RUN yum install -y yum-utils epel-release \ + && dnf install -y clang ccache \ + && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo +ENV CC=clang CXX=clang++ + +FROM base-${TARGETARCH} AS base +ARG CMAKEVERSION +RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 +COPY CMakeLists.txt CMakePresets.json . 
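+# Only the CMake project files and the vendored ggml sources are needed here:
+# each backend stage below (cpu, cuda-*, rocm-6, vulkan) builds just its ggml
+# component; the full repository is copied only in the later `build` stage that
+# produces the Go binary.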
+COPY ml/backend/ggml/ggml ml/backend/ggml/ggml ENV LDFLAGS=-s -# Default mirror was very slow -RUN \ - sed -i 's/archive.ubuntu.com/gb.archive.ubuntu.com/g' /etc/apt/sources.list.d/ubuntu.sources +FROM base AS cpu +RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ +ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'CPU' \ + && cmake --build --parallel --preset 'CPU' \ + && cmake --install build --component CPU --strip --parallel 8 -RUN \ - apt-get update && \ - apt-get install -y ca-certificates build-essential ccache cmake wget git curl rsync xz-utils libcap-dev +FROM base AS cuda-11 +ARG CUDA11VERSION=11.3 +RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} +ENV PATH=/usr/local/cuda-11/bin:$PATH +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'CUDA 11' \ + && cmake --build --parallel --preset 'CUDA 11' \ + && cmake --install build --component CUDA --strip --parallel 8 -RUN \ - mkdir -p /usr/local 2>/dev/null || true && \ - curl -s -L https://dl.google.com/go/go${GOLANG_VERSION}.linux-${GOARCH}.tar.gz | tar -xz -C /usr/local && \ - ln -s /usr/local/go/bin/go /usr/local/bin/go && \ - ln -s /usr/local/go/bin/gofmt /usr/local/bin/gofmt +FROM base AS cuda-12 +ARG CUDA12VERSION=12.8 +RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} +ENV PATH=/usr/local/cuda-12/bin:$PATH +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'CUDA 12' \ + && cmake --build --parallel --preset 'CUDA 12' \ + && cmake --install build --component CUDA --strip --parallel 8 + +FROM base AS rocm-6 +ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'ROCm 6' \ + && cmake --build --parallel --preset 'ROCm 6' \ + && cmake --install build --component HIP --strip --parallel 8 + +FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5 +ARG CMAKEVERSION +RUN apt-get update && apt-get install -y curl ccache \ + && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 +COPY CMakeLists.txt CMakePresets.json . +COPY ml/backend/ggml/ggml ml/backend/ggml/ggml +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'JetPack 5' \ + && cmake --build --parallel --preset 'JetPack 5' \ + && cmake --install build --component CUDA --strip --parallel 8 + +FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6 +ARG CMAKEVERSION +RUN apt-get update && apt-get install -y curl ccache \ + && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 +COPY CMakeLists.txt CMakePresets.json . 
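+# Same layout as jetpack-5: only the CUDA ggml component is built in this stage,
+# against NVIDIA's L4T JetPack base image for Jetson devices.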
+COPY ml/backend/ggml/ggml ml/backend/ggml/ggml +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'JetPack 6' \ + && cmake --build --parallel --preset 'JetPack 6' \ + && cmake --install build --component CUDA --strip --parallel 8 + +FROM base AS vulkan +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'Vulkan' \ + && cmake --build --parallel --preset 'Vulkan' \ + && cmake --install build --component Vulkan --strip --parallel 8 -RUN \ - wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | gpg --dearmor -o /etc/apt/trusted.gpg.d/lunarg-signing-key-pub.gpg && \ - wget -qO /etc/apt/sources.list.d/lunarg-vulkan-${UBUNTU_VERSION}.list https://packages.lunarg.com/vulkan/${VULKAN_VER_BASE}/lunarg-vulkan-${VULKAN_VER_BASE}-${UBUNTU_VERSION}.list && \ - apt update && apt install -y vulkan-sdk +FROM base AS build +WORKDIR /go/src/github.com/ollama/ollama +COPY go.mod go.sum . +RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local +ENV PATH=/usr/local/go/bin:$PATH +RUN go mod download +COPY . . +ARG GOFLAGS="'-ldflags=-w -s'" +ENV CGO_ENABLED=1 +RUN --mount=type=cache,target=/root/.cache/go-build \ + go build -trimpath -buildmode=pie -o /bin/ollama . -# Last testet ollama-vulkan commit: -# 2d443b3dd660a1fd2760d64538512df93648b4bb -COPY patches/ /tmp/patches/ -RUN \ - git clone https://github.com/pufferffish/ollama-vulkan.git "/tmp/ollama-vulkan-git" && \ - cd "/tmp/ollama-vulkan-git" && \ - git checkout 2d443b3dd660a1fd2760d64538512df93648b4bb -b ollama_vulkan_stable && \ - git config user.name "Builder" && git config user.email "builder@local" && \ - git remote add ollama_vanilla https://github.com/ollama/ollama.git && \ - git fetch ollama_vanilla --tags && git checkout v0.5.13 -b ollama_vanilla_stable && \ - git checkout ollama_vanilla_stable && git merge ollama_vulkan_stable --allow-unrelated-histories --no-edit && \ - for p in /tmp/patches/*.patch; do patch -p1 < $p; done -RUN \ - cd "/tmp/ollama-vulkan-git" && \ - make -f Makefile.sync clean sync +FROM --platform=linux/amd64 scratch AS amd64 +COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11 +COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12 +COPY --from=vulkan dist/lib/ollama/vulkan /lib/ollama/vulkan +FROM --platform=linux/arm64 scratch AS arm64 +COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11 +COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12 +COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5 +COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6 -FROM builder AS cpu-build -RUN \ - cd "/tmp/ollama-vulkan-git" && \ - cmake --preset CPU && cmake --build --parallel --preset CPU && \ - cmake --install build --component CPU --strip +FROM scratch AS rocm +COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm -FROM builder AS vulkan-build -RUN \ - cd "/tmp/ollama-vulkan-git" && \ - cmake --preset Vulkan && \ - cmake --build --parallel --preset Vulkan && \ - cmake --install build --component Vulkan --strip - -FROM builder AS binary-build -RUN \ - cd "/tmp/ollama-vulkan-git" && \ - . scripts/env.sh && \ - mkdir -p dist/bin && \ - go build -trimpath -buildmode=pie -o dist/bin/ollama . 
- - -FROM --platform=linux/amd64 library/ubuntu:noble -RUN \ - apt-get update && apt -y dist-upgrade && \ - apt-get install -y ca-certificates libcap2 libvulkan1 && \ - apt-get clean && rm -rf /var/lib/apt/lists/* - -# Install ROCm -RUN \ - apt update && \ - apt install -y wget python3-setuptools python3-wheel && \ - wget https://repo.radeon.com/amdgpu-install/6.3.3/ubuntu/noble/amdgpu-install_6.3.60303-1_all.deb -O /tmp/amdgpu-install_6.3.60303-1_all.deb && \ - apt install -y /tmp/amdgpu-install_6.3.60303-1_all.deb && \ - apt update && apt install -y rocm && \ - apt-get clean && rm -rf /var/lib/apt/lists/* - - -COPY --from=cpu-build /tmp/ollama-vulkan-git/dist/lib/ollama/ /lib/ollama/ -COPY --from=vulkan-build /tmp/ollama-vulkan-git/dist/lib/ollama/vulkan/ /lib/ollama/vulkan/ -COPY --from=binary-build /tmp/ollama-vulkan-git/dist/bin/ /bin/ - -RUN find /lib/ollama && find /bin/ollama +FROM ${FLAVOR} AS archive +ARG VULKANVERSION +COPY --from=cpu dist/lib/ollama /lib/ollama +COPY --from=build /bin/ollama /bin/ollama +FROM ubuntu:24.04 +RUN apt-get update \ + && apt-get install -y ca-certificates libcap2 libvulkan1 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* +COPY --from=archive /bin /usr/bin +COPY --from=archive /lib/ollama /usr/lib/ollama +ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility +ENV NVIDIA_VISIBLE_DEVICES=all +ENV OLLAMA_HOST=0.0.0.0:11434 EXPOSE 11434 -ENV OLLAMA_HOST 0.0.0.0 - ENTRYPOINT ["/bin/ollama"] -CMD ["serve"] +CMD ["serve"] \ No newline at end of file From 9cb4ad02e243ef56ed1cf2ce8f4fd2285433cdf7 Mon Sep 17 00:00:00 2001 From: Vadim Grinco Date: Tue, 11 Mar 2025 14:34:17 +0100 Subject: [PATCH 029/172] This is no longer needed Signed-off-by: Vadim Grinco --- patches/00-fix-vulkan-building.patch | 15297 ------------------------- 1 file changed, 15297 deletions(-) delete mode 100644 patches/00-fix-vulkan-building.patch diff --git a/patches/00-fix-vulkan-building.patch b/patches/00-fix-vulkan-building.patch deleted file mode 100644 index 52e498ee2..000000000 --- a/patches/00-fix-vulkan-building.patch +++ /dev/null @@ -1,15297 +0,0 @@ -From 7c5f98c4cbfaf472a0d05baa3cc61afdcaeee7de Mon Sep 17 00:00:00 2001 -From: dream -Date: Thu, 13 Feb 2025 18:58:59 +0800 -Subject: [PATCH 2/2] fix: fix vulkan building - -1. Add preset for vulkan. -2. Add backend ggml-vulkan. -3. Add some log info. 
---- - CMakePresets.json | 13 +- - discover/gpu.go | 7 +- - .../ggml/ggml/src/ggml-vulkan/CMakeLists.txt | 92 + - .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8745 +++++++++++++++++ - .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 9 + - .../src/ggml-vulkan/vulkan-shaders/acc.comp | 29 + - .../src/ggml-vulkan/vulkan-shaders/add.comp | 29 + - .../ggml-vulkan/vulkan-shaders/argsort.comp | 69 + - .../src/ggml-vulkan/vulkan-shaders/clamp.comp | 17 + - .../ggml-vulkan/vulkan-shaders/concat.comp | 41 + - .../vulkan-shaders/contig_copy.comp | 42 + - .../src/ggml-vulkan/vulkan-shaders/copy.comp | 20 + - .../src/ggml-vulkan/vulkan-shaders/cos.comp | 17 + - .../vulkan-shaders/dequant_f32.comp | 20 + - .../vulkan-shaders/dequant_funcs.comp | 118 + - .../vulkan-shaders/dequant_funcs_cm2.comp | 325 + - .../vulkan-shaders/dequant_head.comp | 13 + - .../vulkan-shaders/dequant_iq4_nl.comp | 32 + - .../vulkan-shaders/dequant_q2_k.comp | 34 + - .../vulkan-shaders/dequant_q3_k.comp | 42 + - .../vulkan-shaders/dequant_q4_0.comp | 30 + - .../vulkan-shaders/dequant_q4_1.comp | 32 + - .../vulkan-shaders/dequant_q4_k.comp | 68 + - .../vulkan-shaders/dequant_q5_0.comp | 34 + - .../vulkan-shaders/dequant_q5_1.comp | 35 + - .../vulkan-shaders/dequant_q5_k.comp | 70 + - .../vulkan-shaders/dequant_q6_k.comp | 33 + - .../vulkan-shaders/dequant_q8_0.comp | 31 + - .../vulkan-shaders/diag_mask_inf.comp | 34 + - .../src/ggml-vulkan/vulkan-shaders/div.comp | 27 + - .../vulkan-shaders/flash_attn_cm2.comp | 289 + - .../src/ggml-vulkan/vulkan-shaders/gelu.comp | 25 + - .../vulkan-shaders/gelu_quick.comp | 23 + - .../vulkan-shaders/generic_binary_head.comp | 64 + - .../vulkan-shaders/generic_head.comp | 9 + - .../vulkan-shaders/generic_unary_head.comp | 56 + - .../ggml-vulkan/vulkan-shaders/get_rows.comp | 28 + - .../vulkan-shaders/get_rows_quant.comp | 39 + - .../vulkan-shaders/group_norm.comp | 66 + - .../ggml-vulkan/vulkan-shaders/im2col.comp | 87 + - .../vulkan-shaders/leaky_relu.comp | 22 + - .../src/ggml-vulkan/vulkan-shaders/mul.comp | 27 + - .../mul_mat_split_k_reduce.comp | 48 + - .../vulkan-shaders/mul_mat_vec.comp | 152 + - .../vulkan-shaders/mul_mat_vec_base.comp | 118 + - .../vulkan-shaders/mul_mat_vec_nc.comp | 71 + - .../vulkan-shaders/mul_mat_vec_p021.comp | 73 + - .../vulkan-shaders/mul_mat_vec_q2_k.comp | 115 + - .../vulkan-shaders/mul_mat_vec_q3_k.comp | 103 + - .../vulkan-shaders/mul_mat_vec_q4_k.comp | 133 + - .../vulkan-shaders/mul_mat_vec_q5_k.comp | 162 + - .../vulkan-shaders/mul_mat_vec_q6_k.comp | 112 + - .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 631 ++ - .../vulkan-shaders/mul_mm_cm2.comp | 328 + - .../src/ggml-vulkan/vulkan-shaders/norm.comp | 44 + - .../src/ggml-vulkan/vulkan-shaders/pad.comp | 28 + - .../ggml-vulkan/vulkan-shaders/pool2d.comp | 74 + - .../src/ggml-vulkan/vulkan-shaders/relu.comp | 21 + - .../ggml-vulkan/vulkan-shaders/repeat.comp | 26 + - .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 42 + - .../ggml-vulkan/vulkan-shaders/rope_head.comp | 49 + - .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 37 + - .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 37 + - .../src/ggml-vulkan/vulkan-shaders/scale.comp | 24 + - .../src/ggml-vulkan/vulkan-shaders/silu.comp | 22 + - .../src/ggml-vulkan/vulkan-shaders/sin.comp | 17 + - .../ggml-vulkan/vulkan-shaders/soft_max.comp | 174 + - .../ggml-vulkan/vulkan-shaders/square.comp | 17 + - .../ggml-vulkan/vulkan-shaders/sum_rows.comp | 37 + - .../src/ggml-vulkan/vulkan-shaders/tanh.comp | 20 + - .../vulkan-shaders/test_coopmat2_support.comp | 7 + - 
.../vulkan-shaders/timestep_embedding.comp | 41 + - .../src/ggml-vulkan/vulkan-shaders/types.comp | 323 + - .../ggml-vulkan/vulkan-shaders/upscale.comp | 36 + - .../vulkan-shaders/vulkan-shaders-gen.cpp | 594 ++ - .../src/ggml-vulkan/vulkan-shaders/wkv6.comp | 87 + - 76 files changed, 14642 insertions(+), 4 deletions(-) - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp - create mode 100644 
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp - create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp - -diff --git a/CMakePresets.json b/CMakePresets.json -index 3ecb0a8f..a77f15ba 100644 ---- a/CMakePresets.json -+++ b/CMakePresets.json -@@ -58,7 +58,11 @@ - "cacheVariables": { - "AMDGPU_TARGETS": 
"gfx803;gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" - } -- } -+ }, -+ { -+ "name": "Vulkan", -+ "inherits": [ "Default" ] -+ } - ], - "buildPresets": [ - { -@@ -105,6 +109,11 @@ - "name": "ROCm 6", - "inherits": [ "ROCm" ], - "configurePreset": "ROCm 6" -- } -+ }, -+ { -+ "name": "Vulkan", -+ "targets": [ "ggml-vulkan" ], -+ "configurePreset": "Vulkan" -+ } - ] - } -diff --git a/discover/gpu.go b/discover/gpu.go -index ec96f5d4..8079be99 100644 ---- a/discover/gpu.go -+++ b/discover/gpu.go -@@ -197,7 +197,10 @@ func initVulkanHandles() *vulkanHandles { - libcapPaths := FindLibCapLibs() - - if len(vulkanPaths) > 0 && len(libcapPaths) > 0 { -+ slog.Info("vulkan: load libvulkan and libcap ok") - vHandles.deviceCount, vHandles.vulkan, vulkanLibPath, libcapLibPath = LoadVulkanMgmt(vulkanPaths, libcapPaths) -+ } else { -+ slog.Info("vulkan: failed to load libvulkan or libcap") - } - - return vHandles -@@ -426,7 +429,7 @@ func GetGPUInfo() GpuInfoList { - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) - gpuInfo.MinimumMemory = 0 -- gpuInfo.DependencyPath = depPaths -+ gpuInfo.DependencyPath = []string{LibOllamaPath} - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.DriverMajor = int(memInfo.major) - gpuInfo.DriverMinor = int(memInfo.minor) -@@ -768,7 +771,7 @@ func LoadVulkanMgmt(vulkanLibPaths []string, capLibPaths []string) (int, *C.vk_h - - C.vk_init(vkLib, capLib, &resp) - if resp.err != nil { -- slog.Debug("Unable to load vulkan", "library", vkLibPath, capLibPath, "error", C.GoString(resp.err)) -+ slog.Error("Unable to load vulkan", "library", vkLibPath, capLibPath, "error", C.GoString(resp.err)) - C.free(unsafe.Pointer(resp.err)) - } else { - return int(resp.num_devices), &resp.ch, vkLibPath, capLibPath -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt -new file mode 100644 -index 00000000..9501de73 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt -@@ -0,0 +1,92 @@ -+find_package(Vulkan COMPONENTS glslc REQUIRED) -+ -+if (Vulkan_FOUND) -+ message(STATUS "Vulkan found") -+ -+ ggml_add_backend_library(ggml-vulkan -+ ggml-vulkan.cpp -+ ../../include/ggml-vulkan.h -+ ) -+ -+ # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. -+ # If it's not, there will be an error to stderr. 
-+ # If it's supported, set a define to indicate that we should compile those shaders -+ execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" -+ OUTPUT_VARIABLE glslc_output -+ ERROR_VARIABLE glslc_error) -+ -+ if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") -+ message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") -+ else() -+ message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") -+ add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -+ endif() -+ -+ target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) -+ target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -+ -+ # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build -+ # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector -+ if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") -+ add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) -+ endif() -+ -+ if (GGML_VULKAN_CHECK_RESULTS) -+ add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) -+ endif() -+ -+ if (GGML_VULKAN_DEBUG) -+ add_compile_definitions(GGML_VULKAN_DEBUG) -+ endif() -+ -+ if (GGML_VULKAN_MEMORY_DEBUG) -+ add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG) -+ endif() -+ -+ if (GGML_VULKAN_SHADER_DEBUG_INFO) -+ add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) -+ endif() -+ -+ if (GGML_VULKAN_PERF) -+ add_compile_definitions(GGML_VULKAN_PERF) -+ endif() -+ -+ if (GGML_VULKAN_VALIDATE) -+ add_compile_definitions(GGML_VULKAN_VALIDATE) -+ endif() -+ -+ if (GGML_VULKAN_RUN_TESTS) -+ add_compile_definitions(GGML_VULKAN_RUN_TESTS) -+ endif() -+ -+ add_subdirectory(vulkan-shaders) -+ -+ set (_ggml_vk_genshaders_cmd vulkan-shaders-gen) -+ set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) -+ set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) -+ set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) -+ set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) -+ -+ file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") -+ -+ add_custom_command( -+ OUTPUT ${_ggml_vk_header} -+ ${_ggml_vk_source} -+ -+ COMMAND "$/${_ggml_vk_genshaders_cmd}" -+ --glslc ${Vulkan_GLSLC_EXECUTABLE} -+ --input-dir ${_ggml_vk_input_dir} -+ --output-dir ${_ggml_vk_output_dir} -+ --target-hpp ${_ggml_vk_header} -+ --target-cpp ${_ggml_vk_source} -+ --no-clean -+ -+ DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd} -+ COMMENT "Generate vulkan shaders" -+ ) -+ -+ target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header}) -+ -+else() -+ message(WARNING "Vulkan not found") -+endif() -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp -new file mode 100644 -index 00000000..d75cd6d6 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -0,0 +1,8745 @@ -+#include "ggml-vulkan.h" -+#include -+#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS) -+#include -+#include "ggml-cpu.h" -+#endif -+ -+#include -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include "ggml-impl.h" -+#include "ggml-backend-impl.h" -+ -+#include 
"ggml-vulkan-shaders.hpp" -+ -+#define VK_API_VERSION VK_API_VERSION_1_2 -+ -+#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) -+ -+#define VK_VENDOR_ID_AMD 0x1002 -+#define VK_VENDOR_ID_APPLE 0x106b -+#define VK_VENDOR_ID_INTEL 0x8086 -+#define VK_VENDOR_ID_NVIDIA 0x10de -+ -+#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32 -+ -+#define GGML_VK_MAX_NODES 8192 -+ -+#define MAX_VK_BUFFERS 256 -+ -+#define VK_CHECK(err, msg) \ -+ do { \ -+ vk::Result err_ = (err); \ -+ if (err_ != vk::Result::eSuccess) { \ -+ fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \ -+ #err, to_string(err_).c_str(), __FILE__, __LINE__); \ -+ exit(1); \ -+ } \ -+ } while (0) -+ -+#ifdef GGML_VULKAN_DEBUG -+#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl -+#else -+#define VK_LOG_DEBUG(msg) ((void) 0) -+#endif // GGML_VULKAN_DEBUG -+ -+struct ggml_backend_vk_context; -+ -+struct vk_queue { -+ uint32_t queue_family_index; -+ vk::Queue queue; -+ vk::CommandPool pool; -+ uint32_t cmd_buffer_idx; -+ std::vector cmd_buffers; -+ -+ vk::PipelineStageFlags stage_flags; -+ -+ bool transfer_only; -+}; -+ -+struct vk_pipeline_struct { -+ std::string name; -+ vk::ShaderModule shader_module; -+ vk::DescriptorSetLayout dsl; -+ std::vector descriptor_pools; -+ std::vector descriptor_sets; -+ uint32_t descriptor_set_idx; -+ vk::PipelineLayout layout; -+ vk::Pipeline pipeline; -+ uint32_t push_constant_size; -+ uint32_t parameter_count; -+ std::array wg_denoms; -+ uint32_t align; -+}; -+ -+typedef std::shared_ptr vk_pipeline; -+typedef std::weak_ptr vk_pipeline_ref; -+ -+static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); -+ -+struct vk_matmul_pipeline_struct { -+ vk_pipeline l, m, s; -+ vk_pipeline a_l, a_m, a_s; -+}; -+ -+typedef std::shared_ptr vk_matmul_pipeline; -+ -+struct vk_matmul_pipeline2 { -+ vk_matmul_pipeline2() { -+ f16acc = std::make_shared(); -+ f32acc = std::make_shared(); -+ } -+ vk_matmul_pipeline f32acc; -+ vk_matmul_pipeline f16acc; -+}; -+ -+struct vk_device_struct; -+typedef std::shared_ptr vk_device; -+typedef std::weak_ptr vk_device_ref; -+ -+struct vk_buffer_struct; -+typedef std::shared_ptr vk_buffer; -+typedef std::weak_ptr vk_buffer_ref; -+ -+struct ggml_backend_vk_buffer_type_context { -+ std::string name; -+ vk_device device; -+}; -+ -+static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); -+static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); -+static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); -+static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); -+static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); -+static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { -+ /* .get_name = */ ggml_backend_vk_buffer_type_name, -+ /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer, -+ /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment, -+ /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size, -+ /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size, -+ /* .is_host = */ NULL, -+}; -+ -+#ifdef GGML_VULKAN_MEMORY_DEBUG -+class vk_memory_logger; -+#endif -+#ifdef GGML_VULKAN_PERF -+class vk_perf_logger; -+#endif -+static void ggml_vk_destroy_buffer(vk_buffer& buf); -+ -+static constexpr uint32_t mul_mat_vec_max_cols = 8; -+ -+struct vk_device_struct { -+ std::mutex mutex; -+ -+ 
vk::PhysicalDevice physical_device; -+ vk::PhysicalDeviceProperties properties; -+ std::string name; -+ uint64_t max_memory_allocation_size; -+ bool fp16; -+ bool pipeline_robustness; -+ vk::Device device; -+ uint32_t vendor_id; -+ vk_queue compute_queue; -+ vk_queue transfer_queue; -+ bool single_queue; -+ uint32_t subgroup_size; -+ uint32_t shader_core_count; -+ bool uma; -+ bool float_controls_rte_fp16; -+ -+ bool subgroup_size_control; -+ uint32_t subgroup_min_size; -+ uint32_t subgroup_max_size; -+ bool subgroup_require_full_support; -+ -+ bool coopmat_support; -+ bool coopmat_acc_f32_support; -+ bool coopmat_acc_f16_support; -+ uint32_t coopmat_m; -+ uint32_t coopmat_n; -+ uint32_t coopmat_k; -+ bool coopmat2; -+ -+ size_t idx; -+ -+ bool mul_mat_l; -+ bool mul_mat_m; -+ bool mul_mat_s; -+ bool mul_mat_id_l; -+ bool mul_mat_id_m; -+ bool mul_mat_id_s; -+ -+ vk_matmul_pipeline pipeline_matmul_f32; -+ vk_matmul_pipeline pipeline_matmul_f32_f16; -+ vk_matmul_pipeline2 pipeline_matmul_f16; -+ vk_matmul_pipeline2 pipeline_matmul_f16_f32; -+ vk_pipeline pipeline_matmul_split_k_reduce; -+ -+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; -+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; -+ -+ vk_matmul_pipeline pipeline_matmul_id_f32; -+ vk_matmul_pipeline2 pipeline_matmul_id_f16; -+ vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; -+ -+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; -+ -+ vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; -+ vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; -+ vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; -+ vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; -+ -+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; -+ vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; -+ vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; -+ vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; -+ vk_pipeline pipeline_acc_f32; -+ vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; -+ vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; -+ vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; -+ vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; -+ vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; -+ vk_pipeline pipeline_upscale_f32; -+ vk_pipeline pipeline_scale_f32; -+ vk_pipeline pipeline_sqr_f32; -+ vk_pipeline pipeline_sin_f32; -+ vk_pipeline pipeline_cos_f32; -+ vk_pipeline pipeline_clamp_f32; -+ vk_pipeline pipeline_pad_f32; -+ vk_pipeline pipeline_repeat_f32; -+ vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; -+ vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; -+ vk_pipeline pipeline_norm_f32; -+ vk_pipeline pipeline_group_norm_f32; -+ vk_pipeline pipeline_rms_norm_f32; -+ vk_pipeline pipeline_gelu_f32; -+ vk_pipeline pipeline_gelu_quick_f32; -+ vk_pipeline pipeline_silu_f32; -+ vk_pipeline pipeline_relu_f32; -+ vk_pipeline pipeline_leaky_relu_f32; -+ vk_pipeline pipeline_tanh_f32; -+ vk_pipeline pipeline_diag_mask_inf_f32; -+ vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; -+ vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; -+ vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; -+ vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; -+ vk_pipeline pipeline_argsort_f32; -+ vk_pipeline pipeline_sum_rows_f32; -+ vk_pipeline 
pipeline_im2col_f32, pipeline_im2col_f32_f16; -+ vk_pipeline pipeline_timestep_embedding_f32; -+ vk_pipeline pipeline_pool2d_f32; -+ vk_pipeline pipeline_rwkv_wkv6_f32; -+ -+ // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} -+ vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; -+ vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; -+ vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; -+ vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; -+ vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; -+ vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; -+ -+ std::unordered_map pipelines; -+ std::unordered_map pipeline_descriptor_set_requirements; -+ -+ std::vector> pinned_memory; -+ -+ vk::Fence fence; -+ vk_buffer sync_staging; -+ -+ ggml_backend_buffer_type buffer_type; -+ -+#ifdef GGML_VULKAN_MEMORY_DEBUG -+ std::unique_ptr memory_logger; -+#endif -+#ifdef GGML_VULKAN_PERF -+ std::unique_ptr perf_logger; -+#endif -+ -+ ~vk_device_struct() { -+ VK_LOG_DEBUG("destroy device " << name); -+ -+ device.destroyFence(fence); -+ -+ ggml_vk_destroy_buffer(sync_staging); -+ -+ device.destroyCommandPool(compute_queue.pool); -+ if (!single_queue) { -+ device.destroyCommandPool(transfer_queue.pool); -+ } -+ -+ for (auto& pipeline : pipelines) { -+ if (pipeline.second.expired()) { -+ continue; -+ } -+ -+ vk_pipeline pl = pipeline.second.lock(); -+ ggml_vk_destroy_pipeline(device, pl); -+ } -+ pipelines.clear(); -+ -+ device.destroy(); -+ } -+}; -+ -+struct vk_buffer_struct { -+ vk::Buffer buffer = VK_NULL_HANDLE; -+ vk::DeviceMemory device_memory = VK_NULL_HANDLE; -+ vk::MemoryPropertyFlags memory_property_flags; -+ void * ptr; -+ size_t size = 0; -+ -+ vk_device device; -+ -+ ~vk_buffer_struct() { -+ if (size == 0) { -+ return; -+ } -+ VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")"); -+ -+ device->device.freeMemory(device_memory); -+ device->device.destroyBuffer(buffer); -+ } -+}; -+ -+struct vk_subbuffer { -+ vk_buffer buffer; -+ uint64_t offset; -+ uint64_t size; -+ -+ operator vk::DescriptorBufferInfo() const { -+ return { buffer->buffer, offset, size }; -+ } -+}; -+ -+struct vk_semaphore { -+ vk::Semaphore s; -+ uint64_t value; -+}; -+ -+struct vk_submission { -+ vk::CommandBuffer buffer; -+ std::vector wait_semaphores; -+ std::vector signal_semaphores; -+}; -+ -+typedef std::vector vk_sequence; -+ -+struct vk_mat_mat_push_constants { -+ uint32_t M; uint32_t N; uint32_t K; -+ uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; -+ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; -+ uint32_t k_split; -+ uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; -+}; -+struct vk_mat_vec_push_constants { -+ uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; -+ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; -+ uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; -+}; -+ -+struct vk_mat_mat_id_push_constants { -+ uint32_t M; uint32_t N; uint32_t K; -+ uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; -+ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; -+ uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; -+}; -+struct vk_mat_vec_id_push_constants { -+ uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; -+ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t 
batch_stride_d; -+ uint32_t nei0; uint32_t ne11; -+}; -+ -+struct vk_flash_attn_push_constants { -+ uint32_t N; -+ uint32_t KV; -+ -+ uint32_t ne1; -+ uint32_t ne2; -+ uint32_t ne3; -+ -+ uint32_t neq2; -+ uint32_t neq3; -+ uint32_t nek2; -+ uint32_t nek3; -+ uint32_t nev2; -+ uint32_t nev3; -+ uint32_t nem1; -+ -+ uint32_t nb02; -+ uint32_t nb03; -+ uint32_t nb12; -+ uint32_t nb13; -+ uint32_t nb22; -+ uint32_t nb23; -+ uint32_t nb31; -+ -+ float scale; -+ float max_bias; -+ float logit_softcap; -+ -+ uint32_t mask; -+ uint32_t n_head_log2; -+ float m0; -+ float m1; -+}; -+ -+struct vk_op_push_constants { -+ uint32_t KX; -+ uint32_t KY; -+ float param1; -+ float param2; -+}; -+ -+struct vk_op_unary_push_constants { -+ uint32_t ne; -+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; -+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; -+ uint32_t misalign_offsets; -+ float param1; float param2; -+ uint32_t ne0_012mp; uint32_t ne0_012L; -+ uint32_t ne0_01mp; uint32_t ne0_01L; -+ uint32_t ne0_0mp; uint32_t ne0_0L; -+ uint32_t ne1_012mp; uint32_t ne1_012L; -+ uint32_t ne1_01mp; uint32_t ne1_01L; -+ uint32_t ne1_0mp; uint32_t ne1_0L; -+}; -+static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); -+ -+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. -+// Precompute mp (m' in the paper) and L such that division -+// can be computed using a multiply (high 32b of 64b result) -+// and a shift: -+// -+// n/d = (mulhi(n, mp) + n) >> L; -+static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) -+{ -+ // compute L = ceil(log2(d)); -+ L = 0; -+ while (L < 32 && (uint32_t{1} << L) < d) { -+ L++; -+ } -+ -+ mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1); -+} -+ -+template void init_pushconst_fastdiv(T &p) { -+ GGML_UNUSED(p); -+ static_assert(!std::is_const::value, "unexpected type"); -+} -+ -+template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) { -+ // Compute magic values to divide by these six numbers. 
-+ init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L); -+ init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L); -+ init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L); -+ init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L); -+ init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L); -+ init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L); -+} -+ -+struct vk_op_binary_push_constants { -+ uint32_t ne; -+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; -+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; -+ uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23; -+ uint32_t misalign_offsets; -+ float param1; float param2; int32_t param3; -+}; -+ -+struct vk_op_diag_mask_push_constants { -+ uint32_t ncols; -+ uint32_t rows_per_channel; -+ int32_t n_past; -+}; -+ -+struct vk_op_rope_push_constants { -+ uint32_t ncols; -+ uint32_t n_dims; -+ float freq_scale; -+ uint32_t p_delta_rows; -+ float freq_base; -+ float ext_factor; -+ float attn_factor; -+ float corr_dims[2]; -+ float theta_scale; -+ uint32_t has_ff; -+}; -+ -+struct vk_op_soft_max_push_constants { -+ uint32_t KX; -+ uint32_t KY; -+ float scale; -+ float max_bias; -+ float m0; -+ float m1; -+ uint32_t n_head_log2; -+ uint32_t nrows_x; -+}; -+ -+struct vk_op_argsort_push_constants { -+ uint32_t ncols; -+ uint32_t ncols_pad; -+ int32_t order; -+}; -+ -+struct vk_op_im2col_push_constants { -+ uint32_t batch_offset; uint32_t offset_delta; -+ uint32_t IC; -+ uint32_t IW; uint32_t IH; -+ uint32_t OW; uint32_t OH; -+ uint32_t KW; uint32_t KH; -+ uint32_t pelements; -+ uint32_t CHW; -+ int32_t s0; int32_t s1; -+ int32_t p0; int32_t p1; -+ int32_t d0; int32_t d1; -+}; -+ -+struct vk_op_timestep_embedding_push_constants { -+ uint32_t nb1; -+ uint32_t dim; -+ uint32_t max_period; -+}; -+ -+struct vk_op_pool2d_push_constants { -+ uint32_t IW; uint32_t IH; -+ uint32_t OW; uint32_t OH; -+ uint32_t OC; -+ uint32_t pelements; -+ uint32_t op; -+ int32_t k0; int32_t k1; -+ int32_t s0; int32_t s1; -+ int32_t p0; int32_t p1; -+}; -+ -+struct vk_op_rwkv_wkv6_push_constants { -+ uint32_t B; -+ uint32_t T; -+ uint32_t C; -+ uint32_t H; -+}; -+ -+// Allow pre-recording command buffers -+struct vk_staging_memcpy { -+ vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} -+ -+ void * dst; -+ const void * src; -+ size_t n; -+}; -+ -+struct vk_op_upscale_push_constants { -+ uint32_t ne; uint32_t a_offset; uint32_t d_offset; -+ uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; -+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; -+ float sf0; float sf1; float sf2; float sf3; -+}; -+ -+struct vk_context_struct { -+ vk_submission * s; -+ std::vector seqs; -+ -+ int exit_tensor_idx; -+ -+ std::vector in_memcpys; -+ std::vector out_memcpys; -+ -+ vk_queue * q; -+}; -+typedef std::shared_ptr vk_context; -+typedef std::weak_ptr vk_context_ref; -+ -+struct ggml_vk_garbage_collector { -+ std::vector tl_semaphores; -+ std::vector semaphores; -+ std::vector events; -+ std::vector temp_buffers; -+ std::vector contexts; -+}; -+ -+#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG) -+#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl -+ -+static std::string format_size(size_t size) { -+ const size_t kib = 1024; -+ const size_t mib = kib * 
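As an aside on the init_fastdiv_values / fastdiv scheme above (the Granlund–Montgomery divide-by-constant trick from the cited paper): a minimal host-side sketch, using hypothetical helper names, that precomputes (mp, L) exactly as the patch does and checks the "mulhi plus shift" form against plain integer division. Only the two formulas come from the patch; the divisors and test values are arbitrary, and the add is done in 64 bits here purely to avoid overflow in the check.

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // same computation as init_fastdiv_values in the patch
    static void fastdiv_init(uint32_t d, uint32_t &mp, uint32_t &L) {
        L = 0;
        while (L < 32 && (uint32_t{1} << L) < d) L++;   // L = ceil(log2(d))
        mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
    }

    // n/d = (mulhi(n, mp) + n) >> L
    static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
        uint32_t hi = (uint32_t)(((uint64_t)n * mp) >> 32);   // high 32 bits of n*mp
        return (uint32_t)(((uint64_t)hi + n) >> L);
    }

    int main() {
        for (uint32_t d : {1u, 3u, 7u, 320u, 4096u}) {
            uint32_t mp, L;
            fastdiv_init(d, mp, L);
            for (uint32_t n : {0u, 1u, 5u, 12345u, 4294967295u}) {
                assert(fastdiv(n, mp, L) == n / d);
            }
        }
        puts("fastdiv matches integer division");
        return 0;
    }
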
1024; -+ const size_t gib = mib * 1024; -+ -+ std::ostringstream oss; -+ oss << std::fixed << std::setprecision(2); -+ -+ if (size >= gib) { -+ oss << static_cast(size) / gib << " GiB"; -+ } else if (size >= mib) { -+ oss << static_cast(size) / mib << " MiB"; -+ } else if (size >= kib) { -+ oss << static_cast(size) / kib << " KiB"; -+ } else { -+ oss << size << " B"; -+ } -+ -+ return oss.str(); -+} -+ -+static std::mutex log_mutex; -+ -+class vk_memory_logger { -+public: -+ vk_memory_logger(): total_device(0), total_host(0) {} -+ void log_allocation(vk_buffer_ref buf_ref, size_t size); -+ void log_deallocation(vk_buffer_ref buf_ref); -+ -+private: -+ std::map allocations; // Track allocations -+ size_t total_device; -+ size_t total_host; -+}; -+#else -+#define VK_LOG_MEMORY(msg) ((void) 0) -+#endif // GGML_VULKAN_MEMORY_DEBUG -+ -+#if defined(GGML_VULKAN_PERF) -+ -+class vk_perf_logger { -+public: -+ void print_timings() { -+ std::cerr << "----------------\nVulkan Timings:" << std::endl; -+ for (const auto& t : timings) { -+ uint64_t total = 0; -+ for (const auto& time : t.second) { -+ total += time; -+ } -+ std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl; -+ } -+ -+ timings.clear(); -+ } -+ -+ void log_timing(const ggml_tensor * node, uint64_t time) { -+ if (node->op == GGML_OP_UNARY) { -+ timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); -+ return; -+ } -+ if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { -+ const uint64_t m = node->src[0]->ne[1]; -+ const uint64_t n = node->src[1]->ne[1]; -+ const uint64_t k = node->src[1]->ne[0]; -+ std::string name = ggml_op_name(node->op); -+ if (n == 1) { -+ name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k); -+ } else { -+ name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); -+ } -+ timings[name].push_back(time); -+ return; -+ } -+ timings[ggml_op_name(node->op)].push_back(time); -+ } -+private: -+ std::map> timings; -+}; -+#endif // GGML_VULKAN_PERF -+ -+struct ggml_backend_vk_context { -+ std::string name; -+ -+ vk_device device; -+ -+ size_t semaphore_idx, event_idx; -+ ggml_vk_garbage_collector gc; -+ size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; -+ vk_buffer prealloc_x, prealloc_y, prealloc_split_k; -+ vk::Fence fence; -+ -+ vk_buffer buffer_pool[MAX_VK_BUFFERS]; -+ -+ vk_context_ref compute_ctx; -+ vk_context_ref transfer_ctx; -+ -+ std::vector tensor_ctxs; -+}; -+ -+static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT -+ -+static uint64_t vk_tensor_offset(const ggml_tensor * tensor) { -+ if (tensor->view_src) { -+ return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base; -+ } -+ return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; -+} -+ -+struct ggml_backend_vk_buffer_context { -+ vk_device_ref device; -+ vk_buffer dev_buffer; -+ std::string name; -+ -+ ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) : -+ device(device), -+ dev_buffer(dev_buffer), -+ name(name) { -+ } -+ -+ ~ggml_backend_vk_buffer_context() { -+ ggml_vk_destroy_buffer(dev_buffer); -+ } -+}; -+ -+#ifdef GGML_VULKAN_MEMORY_DEBUG -+void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { -+ std::lock_guard guard(log_mutex); -+ vk_buffer buf = buf_ref.lock(); -+ const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); -+ const std::string type = device ? 
"device" : "host"; -+ allocations[buf->buffer] = size; -+ total_device += device ? size : 0; -+ total_host += device ? 0 : size; -+ VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); -+} -+ -+void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { -+ if (buf_ref.expired() || buf_ref.lock()->size == 0) { -+ return; -+ } -+ -+ std::lock_guard guard(log_mutex); -+ vk_buffer buf = buf_ref.lock(); -+ const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); -+ std::string type = device ? "device" : "host"; -+ auto it = allocations.find(buf->buffer); -+ total_device -= device ? it->second : 0; -+ total_host -= device ? 0 : it->second; -+ if (it != allocations.end()) { -+ VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); -+ allocations.erase(it); -+ } else { -+ VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer); -+ } -+} -+#endif // GGML_VULKAN_MEMORY_DEBUG -+ -+struct vk_instance_t { -+ vk::Instance instance; -+ -+ std::vector device_indices; -+ vk_device devices[GGML_VK_MAX_DEVICES]; -+}; -+ -+static bool vk_instance_initialized = false; -+static vk_instance_t vk_instance; -+ -+#ifdef GGML_VULKAN_CHECK_RESULTS -+static size_t vk_skip_checks; -+static size_t vk_output_tensor; -+ -+static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); -+static void ggml_vk_check_results_0(ggml_tensor * tensor); -+static void ggml_vk_check_results_1(ggml_tensor * tensor); -+#endif -+ -+typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); -+ -+static void ggml_backend_vk_free(ggml_backend_t backend); -+ -+// variables to track number of compiles in progress -+static uint32_t compile_count = 0; -+static std::mutex compile_count_mutex; -+static std::condition_variable compile_count_cond; -+ -+static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, -+ uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector specialization_constants, -+ uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { -+ VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << -+ ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << -+ ", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); -+ GGML_ASSERT(parameter_count > 0); -+ GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT -+ -+ pipeline = std::make_shared(); -+ pipeline->name = name; -+ pipeline->parameter_count = parameter_count; -+ pipeline->push_constant_size = push_constant_size; -+ pipeline->wg_denoms = wg_denoms; -+ pipeline->align = align; -+ -+ vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); -+ pipeline->shader_module = 
device->device.createShaderModule(shader_module_create_info); -+ -+ std::vector dsl_binding; -+ std::vector dsl_binding_flags; -+ for (uint32_t i = 0; i < parameter_count; i++) { -+ dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); -+ dsl_binding_flags.push_back({}); -+ } -+ -+ vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; -+ -+ vk::PushConstantRange pcr( -+ vk::ShaderStageFlagBits::eCompute, -+ 0, -+ pipeline->push_constant_size -+ ); -+ -+ vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( -+ {}, -+ dsl_binding); -+ descriptor_set_layout_create_info.setPNext(&dslbfci); -+ pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); -+ -+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); -+ vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); -+ pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); -+ -+ pipeline->descriptor_set_idx = 0; -+ -+ vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr); -+ pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info); -+ -+ std::vector specialization_entries(specialization_constants.size()); -+ -+ for (size_t i = 0; i < specialization_constants.size(); i++) { -+ specialization_entries[i].constantID = i; -+ specialization_entries[i].offset = i * sizeof(uint32_t); -+ specialization_entries[i].size = sizeof(uint32_t); -+ } -+ -+ vk::SpecializationInfo specialization_info( -+ specialization_entries.size(), -+ specialization_entries.data(), -+ specialization_constants.size() * sizeof(uint32_t), -+ specialization_constants.data() -+ ); -+ -+ vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{}; -+ -+ if (device->subgroup_require_full_support && require_full_subgroups) { -+ pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT; -+ } -+ -+ vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( -+ pipeline_shader_stage_create_flags, -+ vk::ShaderStageFlagBits::eCompute, -+ pipeline->shader_module, -+ entrypoint.c_str(), -+ &specialization_info); -+ -+ vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info; -+ pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size; -+ if (device->subgroup_size_control && required_subgroup_size > 0) { -+ GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size); -+ pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info); -+ } -+ -+ vk::ComputePipelineCreateInfo compute_pipeline_create_info( -+ vk::PipelineCreateFlags{}, -+ pipeline_shader_create_info, -+ pipeline->layout); -+ -+ vk::PipelineRobustnessCreateInfoEXT rci; -+ -+ if (device->pipeline_robustness && disable_robustness) { -+ rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; -+ rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; -+ compute_pipeline_create_info.setPNext(&rci); -+ } -+ -+ pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; -+ -+ { -+ std::lock_guard 
guard(device->mutex); -+ device->pipelines.insert({ pipeline->name, pipeline }); -+ } -+ -+ { -+ std::lock_guard guard(compile_count_mutex); -+ assert(compile_count > 0); -+ compile_count--; -+ -+ // "Progress bar" for shader compiles -+ static uint32_t total_compile_count = 0; -+ if ((total_compile_count++ % 10) == 0) { -+ std::cerr << "."; -+ } -+ } -+ compile_count_cond.notify_all(); -+} -+ -+static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { -+ VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")"); -+ for (auto& pool : pipeline->descriptor_pools) { -+ device.destroyDescriptorPool(pool); -+ } -+ pipeline->descriptor_pools.clear(); -+ pipeline->descriptor_sets.clear(); -+ pipeline->descriptor_set_idx = 0; -+ -+ device.destroyDescriptorSetLayout(pipeline->dsl); -+ -+ device.destroyPipelineLayout(pipeline->layout); -+ -+ device.destroyShaderModule(pipeline->shader_module); -+ -+ device.destroyPipeline(pipeline->pipeline); -+} -+ -+static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { -+ VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); -+ device->pipeline_descriptor_set_requirements[pipeline->name] += n; -+} -+ -+static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { -+ std::lock_guard guard(device->mutex); -+ -+ for (auto& pair : device->pipeline_descriptor_set_requirements) { -+ vk_pipeline pipeline = device->pipelines.at(pair.first).lock(); -+ const uint64_t n = pair.second; -+ -+ VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")"); -+ -+ if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) { -+ // Enough descriptors are available -+ continue; -+ } -+ -+ uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size(); -+ uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; -+ uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; -+ -+ while (to_alloc > 0) { -+ const uint32_t alloc_count = std::min(pool_remaining, to_alloc); -+ to_alloc -= alloc_count; -+ pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; -+ -+ if (pool_idx >= pipeline->descriptor_pools.size()) { -+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); -+ vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); -+ pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); -+ } -+ -+ std::vector layouts(alloc_count); -+ for (uint32_t i = 0; i < alloc_count; i++) { -+ layouts[i] = pipeline->dsl; -+ } -+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data()); -+ std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); -+ pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end()); -+ -+ pool_idx++; -+ } -+ } -+} -+ -+static void ggml_pipeline_cleanup(vk_pipeline& pipeline) { -+ VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")"); -+ pipeline->descriptor_set_idx = 0; -+} -+ -+static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) { -+ VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); -+ std::lock_guard guard(device->mutex); -+ -+ 
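The pool-chunking arithmetic in ggml_pipeline_allocate_descriptor_sets above is easier to follow in isolation. Below is a minimal sketch, with a stand-in POOL_SIZE and made-up counts (none of these numbers come from the patch), that walks the same to_alloc / pool_remaining / pool_idx loop without touching Vulkan; in the real code each iteration creates a descriptor pool on demand and calls allocateDescriptorSets for alloc_count sets out of it.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static constexpr uint32_t POOL_SIZE = 32;   // stand-in for VK_DEVICE_DESCRIPTOR_POOL_SIZE

    int main() {
        uint32_t existing_sets = 40;            // sets already allocated (spans two pools)
        uint32_t needed_total  = 100;           // descriptor_set_idx + n from the requirements map

        if (existing_sets >= needed_total) return 0;   // enough descriptors already

        uint32_t to_alloc       = needed_total - existing_sets;
        uint32_t pool_remaining = POOL_SIZE - existing_sets % POOL_SIZE;
        uint32_t pool_idx       = existing_sets / POOL_SIZE;

        while (to_alloc > 0) {
            const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
            to_alloc      -= alloc_count;
            pool_remaining = POOL_SIZE;
            printf("pool %u: allocate %u sets\n", (unsigned) pool_idx, (unsigned) alloc_count);
            pool_idx++;
        }
        return 0;
    }
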
if (q.cmd_buffers.size() > q.cmd_buffer_idx) { -+ // Reuse command buffer -+ return q.cmd_buffers[q.cmd_buffer_idx++]; -+ } -+ -+ vk::CommandBufferAllocateInfo command_buffer_alloc_info( -+ q.pool, -+ vk::CommandBufferLevel::ePrimary, -+ 1); -+ const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); -+ auto buf = cmd_buffers.front(); -+ -+ q.cmd_buffers.push_back(buf); -+ q.cmd_buffer_idx++; -+ -+ return buf; -+} -+ -+static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector wait_semaphores, std::vector signal_semaphores) { -+ VK_LOG_DEBUG("ggml_vk_create_submission()"); -+ vk_submission s; -+ s.buffer = ggml_vk_create_cmd_buffer(device, q); -+ s.wait_semaphores = std::move(wait_semaphores); -+ s.signal_semaphores = std::move(signal_semaphores); -+ return s; -+} -+ -+static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { -+ if (ctx->seqs.empty()) { -+ if (fence) { -+ ctx->q->queue.submit({}, fence); -+ } -+ return; -+ } -+ VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")"); -+ -+ std::vector> tl_wait_vals; -+ std::vector> tl_signal_vals; -+ std::vector> tl_wait_semaphores; -+ std::vector> tl_signal_semaphores; -+ std::vector tl_submit_infos; -+ std::vector submit_infos; -+ int idx = -1; -+ std::vector> stage_flags; -+ -+ size_t reserve = 0; -+ -+ for (const auto& sequence : ctx->seqs) { -+ reserve += sequence.size(); -+ } -+ -+ // Pre-reserve vectors to prevent reallocation, which invalidates pointers -+ tl_wait_semaphores.reserve(reserve); -+ tl_wait_vals.reserve(reserve); -+ tl_signal_semaphores.reserve(reserve); -+ tl_signal_vals.reserve(reserve); -+ tl_submit_infos.reserve(reserve); -+ submit_infos.reserve(reserve); -+ stage_flags.reserve(reserve); -+ -+ for (const auto& sequence : ctx->seqs) { -+ for (const auto& submission : sequence) { -+ stage_flags.push_back({}); -+ idx++; -+ tl_wait_vals.push_back({}); -+ tl_wait_semaphores.push_back({}); -+ tl_signal_vals.push_back({}); -+ tl_signal_semaphores.push_back({}); -+ for (size_t i = 0; i < submission.wait_semaphores.size(); i++) { -+ stage_flags[idx].push_back(ctx->q->stage_flags); -+ tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value); -+ tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s); -+ } -+ for (size_t i = 0; i < submission.signal_semaphores.size(); i++) { -+ tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value); -+ tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s); -+ } -+ tl_submit_infos.push_back({ -+ (uint32_t) submission.wait_semaphores.size(), -+ tl_wait_vals[idx].data(), -+ (uint32_t) submission.signal_semaphores.size(), -+ tl_signal_vals[idx].data(), -+ }); -+ tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo; -+ tl_submit_infos[idx].pNext = nullptr; -+ vk::SubmitInfo si{ -+ (uint32_t) submission.wait_semaphores.size(), -+ tl_wait_semaphores[idx].data(), -+ stage_flags[idx].data(), -+ 1, -+ &submission.buffer, -+ (uint32_t) submission.signal_semaphores.size(), -+ tl_signal_semaphores[idx].data(), -+ }; -+ si.setPNext(&tl_submit_infos[idx]); -+ submit_infos.push_back(si); -+ } -+ } -+ -+ ctx->q->queue.submit(submit_infos, fence); -+ -+ ctx->seqs.clear(); -+} -+ -+static uint32_t ggml_vk_find_queue_family_index(std::vector& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) { -+ VK_LOG_DEBUG("ggml_vk_find_queue_family_index()"); -+ const uint32_t qfsize = 
queue_family_props.size(); -+ -+ // Try with avoid preferences first -+ for (uint32_t i = 0; i < qfsize; i++) { -+ if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) { -+ return i; -+ } -+ } -+ -+ // Fall back to only required -+ for (size_t i = 0; i < qfsize; i++) { -+ if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) { -+ return i; -+ } -+ } -+ -+ // Fall back to reusing compute queue -+ for (size_t i = 0; i < qfsize; i++) { -+ if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) { -+ return i; -+ } -+ } -+ -+ // Fall back to ignoring min_num_queries -+ for (size_t i = 0; i < qfsize; i++) { -+ if (queue_family_props[i].queueFlags & required) { -+ return i; -+ } -+ } -+ -+ // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations. -+ // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional. -+ if (compute_index >= 0) { -+ return compute_index; -+ } -+ -+ std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl; -+ -+ for(auto &q_family : queue_family_props) { -+ std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl; -+ } -+ abort(); -+} -+ -+static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) { -+ VK_LOG_DEBUG("ggml_vk_create_queue()"); -+ std::lock_guard guard(device->mutex); -+ -+ q.queue_family_index = queue_family_index; -+ q.transfer_only = transfer_only; -+ -+ vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index); -+ q.pool = device->device.createCommandPool(command_pool_create_info_compute); -+ -+ q.cmd_buffer_idx = 0; -+ -+ q.queue = device->device.getQueue(queue_family_index, queue_index); -+ -+ q.stage_flags = stage_flags; -+} -+ -+static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) { -+ vk_context result = std::make_shared(); -+ VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); -+ ctx->gc.contexts.emplace_back(result); -+ result->q = &q; -+ return result; -+} -+ -+static vk_context ggml_vk_create_temporary_context(vk_queue& q) { -+ vk_context result = std::make_shared(); -+ VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); -+ result->q = &q; -+ return result; -+} -+ -+static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) { -+ VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); -+ vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 }; -+ vk::SemaphoreCreateInfo ci{}; -+ ci.setPNext(&tci); -+ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); -+ ctx->gc.semaphores.push_back({ semaphore, 0 }); -+ return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1]; -+} -+ -+static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) { -+ VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); -+ if 
(ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) { -+ vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; -+ vk::SemaphoreCreateInfo ci{}; -+ ci.setPNext(&tci); -+ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); -+ ctx->gc.tl_semaphores.push_back({ semaphore, 0 }); -+ } -+ return &ctx->gc.tl_semaphores[ctx->semaphore_idx++]; -+} -+ -+static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { -+ if (ctx->event_idx >= ctx->gc.events.size()) { -+ ctx->gc.events.push_back(ctx->device->device.createEvent({})); -+ } -+ return ctx->gc.events[ctx->event_idx++]; -+} -+ -+static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) { -+ VK_LOG_DEBUG("ggml_vk_queue_cleanup()"); -+ std::lock_guard guard(device->mutex); -+ -+ // Requires command buffers to be done -+ device->device.resetCommandPool(q.pool); -+ q.cmd_buffer_idx = 0; -+} -+ -+static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { -+ for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { -+ vk::MemoryType memory_type = mem_props->memoryTypes[i]; -+ if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) && -+ (flags & memory_type.propertyFlags) == flags && -+ mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) { -+ return static_cast(i); -+ } -+ } -+ return UINT32_MAX; -+} -+ -+static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { -+ VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")"); -+ if (size > device->max_memory_allocation_size) { -+ throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); -+ } -+ -+ std::lock_guard guard(device->mutex); -+ -+ vk_buffer buf = std::make_shared(); -+ -+ if (size == 0) { -+ buf->size = 0; -+ return buf; -+ } -+ -+ vk::BufferCreateInfo buffer_create_info{ -+ vk::BufferCreateFlags(), -+ size, -+ vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst, -+ vk::SharingMode::eExclusive, -+ 0, -+ nullptr, -+ }; -+ -+ buf->buffer = device->device.createBuffer(buffer_create_info); -+ -+ vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); -+ -+ vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); -+ -+ uint32_t memory_type_index = UINT32_MAX; -+ -+ memory_type_index = find_properties(&mem_props, &mem_req, req_flags); -+ buf->memory_property_flags = req_flags; -+ -+ if (memory_type_index == UINT32_MAX && fallback_flags) { -+ memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); -+ buf->memory_property_flags = fallback_flags; -+ } -+ -+ if (memory_type_index == UINT32_MAX) { -+ device->device.destroyBuffer(buf->buffer); -+ throw vk::OutOfDeviceMemoryError("No suitable memory type found"); -+ } -+ -+ try { -+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); -+ } catch (const vk::SystemError& e) { -+ if (buf->memory_property_flags != fallback_flags) { -+ // Try again with fallback flags -+ memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); -+ buf->memory_property_flags = fallback_flags; -+ -+ try { -+ buf->device_memory = device->device.allocateMemory({ mem_req.size, 
memory_type_index }); -+ } -+ catch (const vk::SystemError& e) { -+ device->device.destroyBuffer(buf->buffer); -+ throw e; -+ } -+ } else { -+ // Out of Host/Device memory, clean up buffer -+ device->device.destroyBuffer(buf->buffer); -+ throw e; -+ } -+ } -+ buf->ptr = nullptr; -+ -+ if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { -+ buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); -+ } -+ -+ device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); -+ -+ buf->device = device; -+ buf->size = size; -+ -+#ifdef GGML_VULKAN_MEMORY_DEBUG -+ device->memory_logger->log_allocation(buf, size); -+#endif -+ -+ return buf; -+} -+ -+static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { -+ try { -+ return ggml_vk_create_buffer(device, size, req_flags, fallback_flags); -+ } catch (const vk::SystemError& e) { -+ std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl; -+ std::cerr << "ggml_vulkan: " << e.what() << std::endl; -+ throw e; -+ } -+} -+ -+static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { -+ vk_buffer buf; -+ try { -+ if (device->uma) { -+ // Fall back to host memory type -+ buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); -+ } else { -+ // use rebar if available, otherwise fallback to device only visible memory -+ buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); -+ } -+ } catch (const vk::SystemError& e) { -+ std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; -+ std::cerr << "ggml_vulkan: " << e.what() << std::endl; -+ throw e; -+ } -+ -+ return buf; -+} -+ -+static void ggml_vk_destroy_buffer(vk_buffer& buf) { -+ if (buf == nullptr) { -+ return; -+ } -+ -+#ifdef GGML_VULKAN_MEMORY_DEBUG -+ if (buf->device != nullptr) { -+ buf->device->memory_logger->log_deallocation(buf); -+ } -+#endif -+ -+ buf.reset(); -+} -+ -+static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { -+ return { buf, 0, VK_WHOLE_SIZE }; -+} -+ -+static void ggml_vk_sync_buffers(vk_context& ctx) { -+ VK_LOG_DEBUG("ggml_vk_sync_buffers()"); -+ -+ const bool transfer_queue = ctx->q->transfer_only; -+ -+ ctx->s->buffer.pipelineBarrier( -+ ctx->q->stage_flags, -+ ctx->q->stage_flags, -+ {}, -+ { { -+ { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, -+ { !transfer_queue ? 
(vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) } -+ } }, -+ {}, -+ {} -+ ); -+} -+ -+static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events) { -+ VK_LOG_DEBUG("ggml_vk_wait_events()"); -+ if (events.empty()) { -+ return; -+ } -+ -+ ctx->s->buffer.waitEvents( -+ events, -+ ctx->q->stage_flags, -+ ctx->q->stage_flags, -+ {}, -+ {}, -+ {} -+ ); -+} -+ -+// number of rows/cols for flash attention shader -+static constexpr uint32_t flash_attention_num_small_rows = 32; -+static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { -+ GGML_UNUSED(clamp); -+ -+ // small rows, large cols -+ if (small_rows) { -+ return {flash_attention_num_small_rows, 128}; -+ } -+ // small cols to reduce register count -+ if (ggml_is_quantized(type) || D == 256) { -+ return {64, 32}; -+ } -+ return {64, 64}; -+}; -+ -+static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id) { -+ // Needs to be kept up to date on shader changes -+ const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; -+ const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); -+ const uint32_t warps = warptile[0] / warptile[10]; -+ -+ const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; -+ const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; -+ const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; -+ -+ return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize; -+} -+ -+static void ggml_vk_load_shaders(vk_device& device) { -+ VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); -+ -+ std::cerr << "ggml_vulkan: Compiling shaders"; -+ -+ // some shaders have a minimum subgroup size -+ const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); -+ const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); -+ -+ // mulmat -+ std::vector l_warptile, m_warptile, s_warptile, -+ l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, -+ l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, -+ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; -+ std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, -+ l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, -+ l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, -+ l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; -+ -+ uint32_t l_align, m_align, s_align; -+ if (device->coopmat2) { -+ // spec constants and tile sizes for non-quant matmul/matmul_id -+ l_warptile = { 256, 128, 256, 64 }; -+ m_warptile = { 256, 128, 128, 64 }; -+ s_warptile = { 128, 64, 64, 64 }; -+ l_wg_denoms = {128, 256, 1 }; -+ m_wg_denoms = {128, 128, 1 }; -+ s_wg_denoms = { 64, 64, 1 }; -+ -+ // spec constants and tile sizes for quant matmul (non-Qi_K) -+ l_warptile_mmq = { 256, 128, 256, 64 }; -+ m_warptile_mmq = { 256, 128, 128, 64 }; -+ s_warptile_mmq = { 256, 128, 128, 64 }; -+ l_mmq_wg_denoms = { 128, 256, 1 }; -+ m_mmq_wg_denoms = { 128, 128, 1 }; -+ s_mmq_wg_denoms = { 128, 128, 1 }; -+ -+ // spec constants and tile sizes for quant matmul (Qi_K) -+ l_warptile_mmq_k = { 256, 128, 512, 16 }; -+ m_warptile_mmq_k = { 256, 128, 256, 16 }; -+ s_warptile_mmq_k = { 256, 32, 128, 64 }; -+ l_mmq_wg_denoms_k = { 128, 512, 1 }; -+ 
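A rough illustration of the shared-memory estimate that ggml_vk_matmul_shmem_support above performs when deciding which matmul tile sizes a device can run. The warptile values and the 48 KiB limit below are only examples (not taken from any particular device); the formula for load_bufs, mmid_row_ids and coopmat_stage mirrors the patch, shown here for the non-coopmat fp32 case.

    #include <cstdint>
    #include <cstdio>

    int main() {
        // warptile layout used by the backend: [0]=workgroup size, [1]=BM, [2]=BN,
        // [3]=BK, ..., [10]=subgroup size (example values only)
        const uint32_t warptile[11] = { 128, 128, 128, 16, 64, 64, 2, 4, 4, 1, 32 };

        const bool     coopmat_support = false;
        const bool     fp16            = false;
        const bool     mul_mat_id      = false;
        const uint32_t max_shmem       = 49152;   // e.g. 48 KiB maxComputeSharedMemorySize

        const uint32_t bank_conflict_offset = coopmat_support ? 8 : 1;
        const uint32_t type_size            = fp16 ? 2 : 4;
        const uint32_t warps                = warptile[0] / warptile[10];

        const uint32_t load_bufs     = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
        const uint32_t mmid_row_ids  = mul_mat_id ? 3072 * (uint32_t) sizeof(uint32_t) : 0;
        const uint32_t coopmat_stage = coopmat_support ? warptile[7] * warptile[8] / warps * (uint32_t) sizeof(float) : 0;

        const uint32_t needed = load_bufs + mmid_row_ids + coopmat_stage;
        printf("shared memory needed: %u bytes, limit: %u -> %s\n",
               (unsigned) needed, (unsigned) max_shmem,
               needed <= max_shmem ? "supported" : "too large");
        return 0;
    }
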
m_mmq_wg_denoms_k = { 128, 256, 1 }; -+ s_mmq_wg_denoms_k = { 32, 128, 1 }; -+ -+ // spec constants and tile sizes for quant matmul_id -+ l_warptile_mmqid = { 256, 128, 128, 16 }; -+ m_warptile_mmqid = { 256, 128, 64, 16 }; -+ s_warptile_mmqid = { 256, 64, 64, 16 }; -+ l_mmqid_wg_denoms = { 128, 128, 1 }; -+ m_mmqid_wg_denoms = { 128, 64, 1 }; -+ s_mmqid_wg_denoms = { 64, 64, 1 }; -+ -+ l_align = 128; -+ m_align = 64; -+ s_align = 32; -+ } else { -+ // Matrix cores require different warp group sizes -+ const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4; -+ const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4; -+ const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2; -+ const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4; -+ const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2; -+ const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2; -+ const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1; -+ const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; -+ const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; -+ -+ l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; -+ m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; -+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; -+ -+ l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; -+ m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; -+ s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; -+ -+ l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; -+ m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; -+ s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; -+ l_align = 128; -+ m_align = 64; -+ s_align = 32; -+ -+ // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders -+ // and tile sizes, this should handle 16KB, 32KB, and 48KB+. -+ // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders. -+ // But the numbers happen to work out for 32KB shared memory size that when using the medium -+ // size there's enough room for everything, and we assert for this. -+ uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); -+ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { -+ l_warptile = m_warptile; -+ l_wg_denoms = m_wg_denoms; -+ shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); -+ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); -+ } -+ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { -+ // assert mul_mat_mat_id shaders will fit. 
-+ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); -+ } -+ -+ shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); -+ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { -+ if (device->properties.limits.maxComputeSharedMemorySize == 32768) { -+ l_warptile_mmq = m_warptile_mmq; -+ l_mmq_wg_denoms = m_mmq_wg_denoms; -+ } else { -+ l_warptile_mmq = s_warptile_mmq; -+ l_mmq_wg_denoms = s_mmq_wg_denoms; -+ } -+ shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); -+ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); -+ } -+ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { -+ // assert mul_mat_mat_id shaders will fit. -+ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); -+ } -+ // Disable medium and large matrix multiplication if not enough shared memory is available -+ // Check mmq warptiles as the largest configuration -+ // Throw an error if not enough for any matrix multiplication is available -+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) { -+ std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl; -+ throw std::runtime_error("Shared memory size too small for matrix multiplication."); -+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) { -+ device->mul_mat_m = false; -+ device->mul_mat_l = false; -+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) { -+ device->mul_mat_l = false; -+ } -+ -+ // Disable mul_mat_id if not enough shared memory is available -+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) { -+ device->mul_mat_id_s = false; -+ device->mul_mat_id_m = false; -+ device->mul_mat_id_l = false; -+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) { -+ device->mul_mat_id_m = false; -+ device->mul_mat_id_l = false; -+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) { -+ device->mul_mat_id_l = false; -+ } -+ } -+ -+ device->pipeline_matmul_f32 = std::make_shared(); -+ device->pipeline_matmul_f32_f16 = std::make_shared(); -+ -+ device->pipeline_matmul_id_f32 = std::make_shared(); -+ -+ std::vector> compiles; -+ auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, -+ uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, -+ uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { -+ { -+ // wait until fewer than N compiles are in progress -+ uint32_t N = std::max(1u, std::thread::hardware_concurrency()); -+ std::unique_lock guard(compile_count_mutex); -+ while (compile_count >= N) { -+ compile_count_cond.wait(guard); -+ } -+ compile_count++; -+ } -+ compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, -+ parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size)); -+ }; -+ -+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -+ if (device->coopmat2) { -+ -+ auto const &fa_wg_denoms = [&](uint32_t D, 
uint32_t clamp, ggml_type type, bool small_rows) -> std::array { -+ return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; -+ }; -+ -+ auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { -+ // For large number of rows, 128 invocations seems to work best. -+ // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we -+ // can't use 256 for D==80. -+ uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128; -+ auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); -+ return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; -+ }; -+ -+#define CREATE_FA2(TYPE, NAMELC, D) \ -+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ -+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ -+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ -+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ -+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ -+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ -+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ -+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, 
"main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ -+ -+#define CREATE_FA(TYPE, NAMELC) \ -+ CREATE_FA2(TYPE, NAMELC, 64) \ -+ CREATE_FA2(TYPE, NAMELC, 80) \ -+ CREATE_FA2(TYPE, NAMELC, 96) \ -+ CREATE_FA2(TYPE, NAMELC, 112) \ -+ CREATE_FA2(TYPE, NAMELC, 128) \ -+ CREATE_FA2(TYPE, NAMELC, 256) -+ -+ CREATE_FA(GGML_TYPE_F16, f16) -+ CREATE_FA(GGML_TYPE_Q4_0, q4_0) -+ CREATE_FA(GGML_TYPE_Q4_1, q4_1) -+ CREATE_FA(GGML_TYPE_Q5_0, q5_0) -+ CREATE_FA(GGML_TYPE_Q5_1, q5_1) -+ CREATE_FA(GGML_TYPE_Q8_0, q8_0) -+ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently -+ //CREATE_FA(GGML_TYPE_Q2_K, q2_k) -+ //CREATE_FA(GGML_TYPE_Q3_K, q3_k) -+ //CREATE_FA(GGML_TYPE_Q4_K, q4_k) -+ //CREATE_FA(GGML_TYPE_Q5_K, q5_k) -+ //CREATE_FA(GGML_TYPE_Q6_K, q6_k) -+ CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) -+#undef CREATE_FA -+ -+ // Create 6 variants, {s,m,l}x{unaligned,aligned} -+#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ -+ -+ // Create 2 variants, {f16,f32} accumulator -+#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ -+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ -+ CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ -+ -+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) -+ -+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) -+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) -+ -+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+ 
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) -+#undef CREATE_MM -+#undef CREATE_MM2 -+ } else -+#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -+ if (device->coopmat_support) { -+ // Create 6 variants, {s,m,l}x{unaligned,aligned} -+#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ -+ if (device->mul_mat ## ID ## _l) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ -+ if (device->mul_mat ## ID ## _m) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ -+ if (device->mul_mat ## ID ## _s) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ -+ if (device->mul_mat ## ID ## _l) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ -+ if (device->mul_mat ## ID ## _m) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ -+ if (device->mul_mat ## ID ## _s) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ -+ -+ // Create 2 variants, {f16,f32} accumulator -+#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ -+ if (device->coopmat_acc_f16_support) { \ -+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ -+ } \ -+ if (device->coopmat_acc_f32_support) { \ -+ CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ -+ } \ -+ -+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ -+ if (device->coopmat_acc_f16_support) { -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ } else { -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ 
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ } -+ -+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. -+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { -+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); -+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); -+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); -+ -+ if (device->coopmat_acc_f16_support) { -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ } else { -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, 
warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ } -+ } -+#undef CREATE_MM2 -+#undef CREATE_MM -+ } else if (device->fp16) { -+ // Create 6 variants, {s,m,l}x{unaligned,aligned} -+#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ -+ if (device->mul_mat ## ID ## _l) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ -+ if (device->mul_mat ## ID ## _m) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ -+ if (device->mul_mat ## ID ## _s) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ -+ if (device->mul_mat ## ID ## _l) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ -+ if (device->mul_mat ## ID ## _m) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ -+ if (device->mul_mat ## ID ## _s) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ -+ -+ // Create 2 variants, {f16,f32} accumulator -+#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ -+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ -+ CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ -+ -+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ -+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. 
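(Aside, not part of the hunk: the CREATE_MM/CREATE_MM2 macros above build every pipeline name and SPIR-V symbol by preprocessor stringizing and pasting, including the case where the F16ACC and ID arguments are deliberately left empty. A minimal, self-contained sketch of just that mechanism, using dummy symbols and a dummy create_pipeline rather than the patch's real helper:)

#include <cstdio>

// Dummy stand-ins for the generated SPIR-V length symbols the real macro pastes together.
static const int matmul_f32_f32_len         = 111;
static const int matmul_q4_0_f32_f16acc_len = 222;

static void create_pipeline(const char *name, int blob_len) {
    std::printf("%s -> blob of %d bytes\n", name, blob_len);
}

// Same trick as CREATE_MM above: # stringizes, ## pastes, and an empty F16ACC argument
// simply disappears from both the name string and the pasted identifier.
#define DEMO_MM(NAMELC, F16ACC) \
    create_pipeline(#NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len);

int main() {
    DEMO_MM(matmul_f32_f32, )          // create_pipeline("matmul_f32_f32" "" "_l", matmul_f32_f32_len)
    DEMO_MM(matmul_q4_0_f32, _f16acc)  // create_pipeline("matmul_q4_0_f32" "_f16acc" "_l", matmul_q4_0_f32_f16acc_len)
    return 0;
}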
-+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { -+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); -+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); -+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ } -+#undef CREATE_MM2 -+#undef CREATE_MM -+ } else { -+ // Create 6 variants, {s,m,l}x{unaligned,aligned} -+#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ -+ if (device->mul_mat ## ID ## _l) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ -+ if (device->mul_mat ## ID ## _m) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ -+ if (device->mul_mat ## ID ## _s) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ -+ if (device->mul_mat ## ID ## _l) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 
l_align); \ -+ if (device->mul_mat ## ID ## _m) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ -+ if (device->mul_mat ## ID ## _s) \ -+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ -+ -+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); -+ -+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. 
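(Aside: the guard that follows keys off device->mul_mat_id_s/_m/_l, which are set elsewhere from the device's shared-memory limit, as the comment above notes. A hypothetical sketch of that kind of check; the row-id budget and tile sizes here are illustrative assumptions, not values taken from the patch:)

#include <cstdint>

// Illustration only: does a mul_mat_id result tile plus its expert row-id scratch fit in
// workgroup shared memory? shared_mem would come from
// VkPhysicalDeviceLimits::maxComputeSharedMemorySize; 4096 row ids and an f16 accumulator
// tile are assumptions made for this sketch.
static bool mul_mat_id_fits(uint64_t shared_mem, uint32_t tile_m, uint32_t tile_n) {
    const uint64_t row_ids_bytes = 4096u * sizeof(uint32_t);       // row-id scratch
    const uint64_t result_bytes  = uint64_t(tile_m) * tile_n * 2u; // f16 result tile
    return row_ids_bytes + result_bytes <= shared_mem;
}
// e.g. an "l" variant might require mul_mat_id_fits(limits.maxComputeSharedMemorySize, 128, 128)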
-+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { -+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); -+ CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); -+ CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -+ } -+#undef CREATE_MM -+ } -+ -+ // mul mat vec -+ -+ // the number of rows computed per shader depends on GPU model and quant -+ uint32_t rm_stdq = 1; -+ uint32_t rm_kq = 2; -+ if (device->vendor_id == VK_VENDOR_ID_AMD) { -+ if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN -+ rm_stdq = 2; -+ rm_kq = 4; -+ } -+ } else if (device->vendor_id == VK_VENDOR_ID_INTEL) -+ rm_stdq = 2; -+ -+ for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, 
{device->subgroup_size, 2*rm_stdq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); 
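(Aside: each ggml_vk_create_pipeline call in this loop ends with an initializer list such as {device->subgroup_size, 2*rm_stdq, i+1}; those values are shader specialization constants. A sketch of how a uint32_t triple like that is typically handed to Vulkan; the helper's real plumbing is not shown in this hunk:)

#include <vulkan/vulkan.h>
#include <array>
#include <cstdint>

// Build a VkSpecializationInfo for three consecutive uint32_t specialization constants
// (IDs 0..2). The caller must keep `values` and `entries` alive while the pipeline is
// created, since the returned struct only points at them.
static VkSpecializationInfo make_spec_info(const std::array<uint32_t, 3> &values,
                                           std::array<VkSpecializationMapEntry, 3> &entries) {
    for (uint32_t id = 0; id < 3; ++id) {
        entries[id] = { /*constantID*/ id,
                        /*offset*/     id * (uint32_t)sizeof(uint32_t),
                        /*size*/       sizeof(uint32_t) };
    }
    VkSpecializationInfo info = {};
    info.mapEntryCount = (uint32_t)entries.size();
    info.pMapEntries   = entries.data();
    info.dataSize      = values.size() * sizeof(uint32_t);
    info.pData         = values.data();
    return info;  // point VkPipelineShaderStageCreateInfo::pSpecializationInfo at this
}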
-+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); -+ ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); -+ } -+ -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], 
"mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); -+ -+ // dequant shaders -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); -+ -+ // get_rows -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, 
sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, 
group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", 
concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, 
"main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); -+ -+ if (device->float_controls_rte_fp16) { -+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); -+ } else { -+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); -+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); -+ } -+ -+ ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); -+ if (device->float_controls_rte_fp16) { -+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); -+ } else { -+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); -+ } -+ -+ ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); -+ -+ ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); -+ -+ for (auto &c : compiles) { -+ c.wait(); -+ } -+ std::cerr << "Done!" 
<< std::endl; -+} -+ -+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); -+ -+static vk_device ggml_vk_get_device(size_t idx) { -+ VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); -+ -+ if (vk_instance.devices[idx] == nullptr) { -+ VK_LOG_DEBUG("Initializing new vk_device"); -+ vk_device device = std::make_shared(); -+ vk_instance.devices[idx] = device; -+ -+#ifdef GGML_VULKAN_MEMORY_DEBUG -+ device->memory_logger = std::unique_ptr(new vk_memory_logger()); -+#endif -+#ifdef GGML_VULKAN_PERF -+ device->perf_logger = std::unique_ptr(new vk_perf_logger()); -+#endif -+ -+ size_t dev_num = vk_instance.device_indices[idx]; -+ -+ std::vector physical_devices = vk_instance.instance.enumeratePhysicalDevices(); -+ -+ if (dev_num >= physical_devices.size()) { -+ std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl; -+ throw std::runtime_error("Device not found"); -+ } -+ -+ device->physical_device = physical_devices[dev_num]; -+ const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); -+ -+ bool fp16_storage = false; -+ bool fp16_compute = false; -+ bool maintenance4_support = false; -+ bool sm_builtins = false; -+ bool amd_shader_core_properties2 = false; -+ bool pipeline_robustness = false; -+ bool coopmat2_support = false; -+ device->coopmat_support = false; -+ -+ // Check if maintenance4 is supported -+ for (const auto& properties : ext_props) { -+ if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { -+ maintenance4_support = true; -+ } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { -+ fp16_storage = true; -+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { -+ fp16_compute = true; -+ } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { -+ sm_builtins = true; -+ } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) { -+ amd_shader_core_properties2 = true; -+ } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { -+ pipeline_robustness = true; -+ } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { -+ device->subgroup_size_control = true; -+ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && -+ !getenv("GGML_VK_DISABLE_COOPMAT")) { -+ device->coopmat_support = true; -+ device->coopmat_m = 0; -+ device->coopmat_n = 0; -+ device->coopmat_k = 0; -+ } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && -+ !getenv("GGML_VK_DISABLE_COOPMAT2")) { -+ coopmat2_support = true; -+ } -+ } -+ -+ vk::PhysicalDeviceProperties2 props2; -+ vk::PhysicalDeviceMaintenance3Properties props3; -+ vk::PhysicalDeviceMaintenance4Properties props4; -+ vk::PhysicalDeviceSubgroupProperties subgroup_props; -+ vk::PhysicalDeviceDriverProperties driver_props; -+ vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; -+ vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; -+ vk::PhysicalDeviceVulkan12Properties vk12_props; -+ vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; -+ -+ props2.pNext = &props3; -+ props3.pNext = &subgroup_props; -+ subgroup_props.pNext = &driver_props; -+ driver_props.pNext = &vk12_props; -+ -+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; -+ -+ if (maintenance4_support) { -+ last_struct->pNext = 
(VkBaseOutStructure *)&props4; -+ last_struct = (VkBaseOutStructure *)&props4; -+ } -+ if (sm_builtins) { -+ last_struct->pNext = (VkBaseOutStructure *)&sm_props; -+ last_struct = (VkBaseOutStructure *)&sm_props; -+ } -+ if (amd_shader_core_properties2) { -+ last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; -+ last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; -+ } -+ if (device->subgroup_size_control) { -+ last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props; -+ last_struct = (VkBaseOutStructure *)&subgroup_size_control_props; -+ } -+ -+#if defined(VK_NV_cooperative_matrix2) -+ vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; -+ if (coopmat2_support) { -+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props; -+ last_struct = (VkBaseOutStructure *)&coopmat2_props; -+ } -+#endif -+ -+ device->physical_device.getProperties2(&props2); -+ device->properties = props2.properties; -+ -+ const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); -+ -+ if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { -+ device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); -+ } else if (maintenance4_support) { -+ device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); -+ } else { -+ device->max_memory_allocation_size = props3.maxMemoryAllocationSize; -+ } -+ -+ device->vendor_id = device->properties.vendorID; -+ device->subgroup_size = subgroup_props.subgroupSize; -+ device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; -+ if (sm_builtins) { -+ device->shader_core_count = sm_props.shaderSMCount; -+ } else if (amd_shader_core_properties2) { -+ device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; -+ } else { -+ device->shader_core_count = 0; -+ } -+ device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; -+ -+ const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; -+ -+ device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; -+ -+ if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) { -+ device->coopmat_support = false; -+ } -+ -+ std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); -+ -+ // Try to find a non-graphics compute queue and transfer-focused queues -+ const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); -+ const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); -+ -+ const float priorities[] = { 1.0f, 1.0f }; -+ device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; -+ -+ std::vector device_queue_create_infos; -+ if (compute_queue_family_index != transfer_queue_family_index) { -+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); -+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1}); -+ } else if(!device->single_queue) { -+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities}); -+ } else { -+ 
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); -+ } -+ vk::DeviceCreateInfo device_create_info; -+ std::vector device_extensions; -+ vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); -+ -+ VkPhysicalDeviceFeatures2 device_features2; -+ device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; -+ device_features2.pNext = nullptr; -+ device_features2.features = (VkPhysicalDeviceFeatures)device_features; -+ -+ VkPhysicalDeviceVulkan11Features vk11_features; -+ vk11_features.pNext = nullptr; -+ vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; -+ device_features2.pNext = &vk11_features; -+ -+ VkPhysicalDeviceVulkan12Features vk12_features; -+ vk12_features.pNext = nullptr; -+ vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; -+ vk11_features.pNext = &vk12_features; -+ -+ last_struct = (VkBaseOutStructure *)&vk12_features; -+ -+ VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; -+ pl_robustness_features.pNext = nullptr; -+ pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; -+ pl_robustness_features.pipelineRobustness = VK_FALSE; -+ -+ if (pipeline_robustness) { -+ last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features; -+ last_struct = (VkBaseOutStructure *)&pl_robustness_features; -+ device_extensions.push_back("VK_EXT_pipeline_robustness"); -+ } -+ -+ VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; -+ subgroup_size_control_features.pNext = nullptr; -+ subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; -+ subgroup_size_control_features.computeFullSubgroups = false; -+ subgroup_size_control_features.subgroupSizeControl = false; -+ -+ if (device->subgroup_size_control) { -+ last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features; -+ last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; -+ } -+ -+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; -+ coopmat_features.pNext = nullptr; -+ coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; -+ coopmat_features.cooperativeMatrix = VK_FALSE; -+ -+ if (device->coopmat_support) { -+ last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; -+ last_struct = (VkBaseOutStructure *)&coopmat_features; -+ } -+ -+#if defined(VK_NV_cooperative_matrix2) -+ VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; -+ coopmat2_features.pNext = nullptr; -+ coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; -+ if (coopmat2_support) { -+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; -+ last_struct = (VkBaseOutStructure *)&coopmat2_features; -+ device_extensions.push_back("VK_NV_cooperative_matrix2"); -+ } -+#endif -+ -+ vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); -+ -+ device->fp16 = device->fp16 && vk12_features.shaderFloat16; -+ -+ device->pipeline_robustness = pl_robustness_features.pipelineRobustness; -+ -+ if (device->subgroup_size_control) { -+ device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; -+ device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; -+ } -+ -+ device->subgroup_size_control = device->subgroup_size_control && -+ (subgroup_size_control_props.requiredSubgroupSizeStages & 
vk::ShaderStageFlagBits::eCompute) && -+ subgroup_size_control_features.subgroupSizeControl; -+ -+ if (device->subgroup_size_control) { -+ device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; -+ device_extensions.push_back("VK_EXT_subgroup_size_control"); -+ } -+ -+ device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; -+ -+ if (coopmat2_support) { -+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -+ if (coopmat2_features.cooperativeMatrixWorkgroupScope && -+ coopmat2_features.cooperativeMatrixFlexibleDimensions && -+ coopmat2_features.cooperativeMatrixReductions && -+ coopmat2_features.cooperativeMatrixConversions && -+ coopmat2_features.cooperativeMatrixPerElementOperations && -+ coopmat2_features.cooperativeMatrixTensorAddressing && -+ coopmat2_features.cooperativeMatrixBlockLoads && -+ vk12_features.bufferDeviceAddress) { -+ -+ std::vector flexible_dimensions; -+ uint32_t count = 0; -+ -+ PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV -+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV = -+ (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV) -+ vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV"); -+ -+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr); -+ -+ VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {}; -+ empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV; -+ flexible_dimensions.resize(count, empty_prop); -+ -+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data()); -+ -+ bool found_fp16_128 = false, -+ found_fp16_256 = false, -+ found_fp32_128 = false, -+ found_fp32_256 = false; -+ // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 -+ // with 32x16x16 and 256 with 32x32x16. -+ for (auto &prop : flexible_dimensions) { -+ if (prop.saturatingAccumulation == VK_FALSE && -+ prop.scope == VK_SCOPE_WORKGROUP_KHR && -+ prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && -+ prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { -+ -+ if (prop.workgroupInvocations == 128 && -+ prop.MGranularity <= 32 && -+ prop.NGranularity <= 16 && -+ prop.KGranularity <= 16) { -+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && -+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { -+ found_fp16_128 = true; -+ } -+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && -+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { -+ found_fp32_128 = true; -+ } -+ } -+ if (prop.workgroupInvocations == 256 && -+ prop.MGranularity <= 32 && -+ prop.NGranularity <= 32 && -+ prop.KGranularity <= 16) { -+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && -+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { -+ found_fp16_256 = true; -+ } -+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && -+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { -+ found_fp32_256 = true; -+ } -+ } -+ } -+ } -+ if (found_fp16_128 && found_fp16_256 && -+ found_fp32_128 && found_fp32_256 && -+ coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { -+ device->coopmat2 = true; -+ } -+ } -+#endif -+ } -+ -+ if (!vk11_features.storageBuffer16BitAccess) { -+ std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." 
<< std::endl; -+ throw std::runtime_error("Unsupported device"); -+ } -+ -+ device_extensions.push_back("VK_KHR_16bit_storage"); -+ -+#ifdef GGML_VULKAN_VALIDATE -+ device_extensions.push_back("VK_KHR_shader_non_semantic_info"); -+#endif -+ -+ if (device->fp16) { -+ device_extensions.push_back("VK_KHR_shader_float16_int8"); -+ } -+ -+ if (device->coopmat_support) { -+ // Query supported shapes -+ std::vector cm_props; -+ -+ PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR = -+ (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"); -+ -+ uint32_t cm_props_num; -+ -+ pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr); -+ -+ cm_props.resize(cm_props_num); -+ -+ for (auto& prop : cm_props) { -+ prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; -+ } -+ -+ pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data()); -+ -+ VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size()); -+ -+ for (auto& prop : cm_props) { -+ VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope)); -+ -+ if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 && -+ (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 && -+ (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup -+ ) { -+ if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 && -+ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) { -+ // coopmat sizes not set yet -+ if (device->coopmat_m == 0) { -+ device->coopmat_acc_f32_support = true; -+ device->coopmat_m = prop.MSize; -+ device->coopmat_n = prop.NSize; -+ device->coopmat_k = prop.KSize; -+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { -+ // Only enable if shape is identical -+ device->coopmat_acc_f32_support = true; -+ } -+ } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && -+ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { -+ // coopmat sizes not set yet -+ if (device->coopmat_m == 0) { -+ device->coopmat_acc_f16_support = true; -+ device->coopmat_m = prop.MSize; -+ device->coopmat_n = prop.NSize; -+ device->coopmat_k = prop.KSize; -+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { -+ // Only enable if shape is identical -+ device->coopmat_acc_f16_support = true; -+ } -+ } -+ } -+ } -+ -+ if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { -+ // No suitable matmul mode found -+ GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. 
Disabling matrix cores.\n"); -+ device->coopmat_support = false; -+ } -+ } -+ -+ if (device->coopmat_support) { -+ device_extensions.push_back("VK_KHR_cooperative_matrix"); -+ } -+ -+ device->name = GGML_VK_NAME + std::to_string(idx); -+ -+ device_create_info = { -+ vk::DeviceCreateFlags(), -+ device_queue_create_infos, -+ {}, -+ device_extensions -+ }; -+ device_create_info.setPNext(&device_features2); -+ device->device = device->physical_device.createDevice(device_create_info); -+ -+ // Queues -+ ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false); -+ -+ // Shaders -+ // Disable matmul tile sizes early if performance low or not supported -+ switch (device->vendor_id) { -+#ifndef GGML_VULKAN_RUN_TESTS -+ case VK_VENDOR_ID_AMD: -+ case VK_VENDOR_ID_INTEL: -+ device->mul_mat_l = false; -+ device->mul_mat_m = true; -+ device->mul_mat_s = true; -+ device->mul_mat_id_l = false; -+ device->mul_mat_id_m = true; -+ device->mul_mat_id_s = true; -+ break; -+ case VK_VENDOR_ID_APPLE: -+ device->mul_mat_l = false; -+ device->mul_mat_m = true; -+ device->mul_mat_s = false; -+ device->mul_mat_id_l = false; -+ device->mul_mat_id_m = true; -+ device->mul_mat_id_s = false; -+ break; -+#endif -+ default: -+ device->mul_mat_l = true; -+ device->mul_mat_m = true; -+ device->mul_mat_s = true; -+ device->mul_mat_id_l = true; -+ device->mul_mat_id_m = true; -+ device->mul_mat_id_s = true; -+ break; -+ } -+ -+ ggml_vk_load_shaders(device); -+ -+ if (!device->single_queue) { -+ const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; -+ ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); -+ } else { -+ // TODO: Use pointer or reference to avoid copy -+ device->transfer_queue = device->compute_queue; -+ } -+ -+ device->buffer_type = { -+ /* .iface = */ ggml_backend_vk_buffer_type_interface, -+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx), -+ /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device }, -+ }; -+ -+ device->fence = device->device.createFence({}); -+ -+ device->idx = idx; -+ -+ return device; -+ } -+ -+ return vk_instance.devices[idx]; -+} -+ -+static void ggml_vk_print_gpu_info(size_t idx) { -+ GGML_ASSERT(idx < vk_instance.device_indices.size()); -+ size_t dev_num = vk_instance.device_indices[idx]; -+ VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")"); -+ GGML_ASSERT(vk_instance_initialized); -+ -+ std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); -+ -+ if (dev_num >= devices.size()) { -+ std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." 
<< std::endl; -+ throw std::runtime_error("Device not found"); -+ } -+ -+ vk::PhysicalDevice physical_device = devices[dev_num]; -+ std::vector ext_props = physical_device.enumerateDeviceExtensionProperties(); -+ -+ vk::PhysicalDeviceProperties2 props2; -+ vk::PhysicalDeviceMaintenance3Properties props3; -+ vk::PhysicalDeviceSubgroupProperties subgroup_props; -+ vk::PhysicalDeviceDriverProperties driver_props; -+ props2.pNext = &props3; -+ props3.pNext = &subgroup_props; -+ subgroup_props.pNext = &driver_props; -+ physical_device.getProperties2(&props2); -+ -+ const size_t subgroup_size = subgroup_props.subgroupSize; -+ const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; -+ -+ bool fp16_storage = false; -+ bool fp16_compute = false; -+ bool coopmat_support = false; -+ bool coopmat2_support = false; -+ -+ for (auto properties : ext_props) { -+ if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { -+ fp16_storage = true; -+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { -+ fp16_compute = true; -+ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && -+ !getenv("GGML_VK_DISABLE_COOPMAT")) { -+ coopmat_support = true; -+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -+ } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && -+ !getenv("GGML_VK_DISABLE_COOPMAT2")) { -+ coopmat2_support = true; -+#endif -+ } -+ } -+ -+ if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) { -+ coopmat_support = false; -+ } -+ -+ const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); -+ bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; -+ -+ bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; -+ -+ vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures(); -+ -+ VkPhysicalDeviceFeatures2 device_features2; -+ device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; -+ device_features2.pNext = nullptr; -+ device_features2.features = (VkPhysicalDeviceFeatures)device_features; -+ -+ VkPhysicalDeviceVulkan11Features vk11_features; -+ vk11_features.pNext = nullptr; -+ vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; -+ device_features2.pNext = &vk11_features; -+ -+ VkPhysicalDeviceVulkan12Features vk12_features; -+ vk12_features.pNext = nullptr; -+ vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; -+ vk11_features.pNext = &vk12_features; -+ -+ // Pointer to the last chain element -+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; -+ -+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; -+ coopmat_features.pNext = nullptr; -+ coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; -+ coopmat_features.cooperativeMatrix = VK_FALSE; -+ -+ if (coopmat_support) { -+ last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; -+ last_struct = (VkBaseOutStructure *)&coopmat_features; -+ } -+ -+ vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); -+ -+ fp16 = fp16 && vk12_features.shaderFloat16; -+ -+ coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix; -+ -+ std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? 
"KHR_coopmat" : "none"; -+ -+ std::string device_name = props2.properties.deviceName.data(); -+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n", -+ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str()); -+ -+ if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { -+ GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); -+ } -+} -+ -+static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); -+static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); -+ -+void ggml_vk_instance_init() { -+ if (vk_instance_initialized) { -+ return; -+ } -+ VK_LOG_DEBUG("ggml_vk_instance_init()"); -+ -+ vk_instance_initialized = true; -+ -+ vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; -+ -+ const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); -+ const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); -+#ifdef __APPLE__ -+ const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); -+#endif -+ -+ std::vector layers; -+ -+ if (validation_ext) { -+ layers.push_back("VK_LAYER_KHRONOS_validation"); -+ } -+ std::vector extensions; -+ if (validation_ext) { -+ extensions.push_back("VK_EXT_validation_features"); -+ } -+#ifdef __APPLE__ -+ if (portability_enumeration_ext) { -+ extensions.push_back("VK_KHR_portability_enumeration"); -+ } -+#endif -+ vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); -+#ifdef __APPLE__ -+ if (portability_enumeration_ext) { -+ instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; -+ } -+#endif -+ -+ std::vector features_enable; -+ vk::ValidationFeaturesEXT validation_features; -+ -+ if (validation_ext) { -+ features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; -+ validation_features = { -+ features_enable, -+ {}, -+ }; -+ validation_features.setPNext(nullptr); -+ instance_create_info.setPNext(&validation_features); -+ GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); -+ } -+ vk_instance.instance = vk::createInstance(instance_create_info); -+ -+ size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); -+ -+ // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan -+ char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); -+ if (devices_env != nullptr) { -+ std::string devices(devices_env); -+ std::replace(devices.begin(), devices.end(), ',', ' '); -+ -+ std::stringstream ss(devices); -+ size_t tmp; -+ while (ss >> tmp) { -+ if(tmp >= num_available_devices) { -+ std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl; -+ throw std::runtime_error("Invalid Vulkan device index"); -+ } -+ vk_instance.device_indices.push_back(tmp); -+ } -+ } else { -+ std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); -+ -+ // Make sure at least one device exists -+ if (devices.empty()) { -+ std::cerr << "ggml_vulkan: Error: No devices found." 
<< std::endl; -+ GGML_ABORT("fatal error"); -+ } -+ -+ // Default to using all dedicated GPUs -+ for (size_t i = 0; i < devices.size(); i++) { -+ vk::PhysicalDeviceProperties2 new_props; -+ vk::PhysicalDeviceDriverProperties new_driver; -+ vk::PhysicalDeviceIDProperties new_id; -+ new_props.pNext = &new_driver; -+ new_driver.pNext = &new_id; -+ devices[i].getProperties2(&new_props); -+ -+ if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) { -+ // Check if there are two physical devices corresponding to the same GPU -+ auto old_device = std::find_if( -+ vk_instance.device_indices.begin(), -+ vk_instance.device_indices.end(), -+ [&devices, &new_id](const size_t k){ -+ vk::PhysicalDeviceProperties2 old_props; -+ vk::PhysicalDeviceIDProperties old_id; -+ old_props.pNext = &old_id; -+ devices[k].getProperties2(&old_props); -+ return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); -+ } -+ ); -+ if (old_device == vk_instance.device_indices.end()) { -+ vk_instance.device_indices.push_back(i); -+ } else { -+ // There can be two physical devices corresponding to the same GPU if there are 2 different drivers -+ // This can cause error when splitting layers aross the devices, need to keep only 1 -+ VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID"); -+ -+ vk::PhysicalDeviceProperties2 old_props; -+ vk::PhysicalDeviceDriverProperties old_driver; -+ old_props.pNext = &old_driver; -+ devices[*old_device].getProperties2(&old_props); -+ -+ std::map driver_priorities {}; -+ int old_priority = std::numeric_limits::max(); -+ int new_priority = std::numeric_limits::max(); -+ -+ // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id -+ // Smaller number -> higher priority -+ switch (old_props.properties.vendorID) { -+ case VK_VENDOR_ID_AMD: -+ driver_priorities[vk::DriverId::eMesaRadv] = 1; -+ driver_priorities[vk::DriverId::eAmdOpenSource] = 2; -+ driver_priorities[vk::DriverId::eAmdProprietary] = 3; -+ break; -+ case VK_VENDOR_ID_INTEL: -+ driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1; -+ driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2; -+ break; -+ case VK_VENDOR_ID_NVIDIA: -+ driver_priorities[vk::DriverId::eNvidiaProprietary] = 1; -+#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235 -+ driver_priorities[vk::DriverId::eMesaNvk] = 2; -+#endif -+ break; -+ } -+ -+ if (driver_priorities.count(old_driver.driverID)) { -+ old_priority = driver_priorities[old_driver.driverID]; -+ } -+ if (driver_priorities.count(new_driver.driverID)) { -+ new_priority = driver_priorities[new_driver.driverID]; -+ } -+ -+ if (new_priority < old_priority) { -+ auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device); -+ vk_instance.device_indices.erase(r, vk_instance.device_indices.end()); -+ vk_instance.device_indices.push_back(i); -+ -+ VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName); -+ } -+ else { -+ VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl); -+ } -+ } -+ } -+ } -+ -+ // If no dedicated GPUs found, fall back to GPU 0 -+ if (vk_instance.device_indices.empty()) { -+ vk_instance.device_indices.push_back(0); -+ } -+ } -+ GGML_LOG_DEBUG("ggml_vulkan: 
Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); -+ -+ for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { -+ ggml_vk_print_gpu_info(i); -+ } -+} -+ -+static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { -+ VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")"); -+ ggml_vk_instance_init(); -+ GGML_ASSERT(idx < vk_instance.device_indices.size()); -+ -+ ctx->name = GGML_VK_NAME + std::to_string(idx); -+ -+ ctx->device = ggml_vk_get_device(idx); -+ -+ ctx->semaphore_idx = 0; -+ ctx->event_idx = 0; -+ -+ ctx->prealloc_size_x = 0; -+ ctx->prealloc_size_y = 0; -+ ctx->prealloc_size_split_k = 0; -+ -+ ctx->fence = ctx->device->device.createFence({}); -+ -+#ifdef GGML_VULKAN_CHECK_RESULTS -+ const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); -+ vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); -+ const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR"); -+ vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor)); -+#endif -+} -+ -+static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) { -+ VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); -+ switch (type) { -+ case GGML_TYPE_F32: -+ case GGML_TYPE_Q4_0: -+ case GGML_TYPE_Q4_1: -+ case GGML_TYPE_Q5_0: -+ case GGML_TYPE_Q5_1: -+ case GGML_TYPE_Q8_0: -+ case GGML_TYPE_Q2_K: -+ case GGML_TYPE_Q3_K: -+ case GGML_TYPE_Q4_K: -+ case GGML_TYPE_Q5_K: -+ case GGML_TYPE_Q6_K: -+ case GGML_TYPE_IQ4_NL: -+ break; -+ default: -+ return nullptr; -+ } -+ -+ return ctx->device->pipeline_dequant[type]; -+} -+ -+static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { -+ VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); -+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_matmul_f32; -+ } -+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { -+ return ctx->device->pipeline_matmul_f32_f16; -+ } -+ if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { -+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_matmul_f16_f32.f16acc; -+ } -+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { -+ return ctx->device->pipeline_matmul_f16.f16acc; -+ } -+ } else { -+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_matmul_f16_f32.f32acc; -+ } -+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { -+ return ctx->device->pipeline_matmul_f16.f32acc; -+ } -+ } -+ -+ if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { -+ return nullptr; -+ } -+ -+ switch (src0_type) { -+ case GGML_TYPE_Q4_0: -+ case GGML_TYPE_Q4_1: -+ case GGML_TYPE_Q5_0: -+ case GGML_TYPE_Q5_1: -+ case GGML_TYPE_Q8_0: -+ case GGML_TYPE_Q2_K: -+ case GGML_TYPE_Q3_K: -+ case GGML_TYPE_Q4_K: -+ case GGML_TYPE_Q5_K: -+ case GGML_TYPE_Q6_K: -+ case GGML_TYPE_IQ4_NL: -+ break; -+ default: -+ return nullptr; -+ } -+ -+ if (ctx->device->coopmat2) { -+ assert(src1_type == GGML_TYPE_F16); -+ return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc; -+ } -+ return ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; -+} -+ -+static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { -+ VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); -+ GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); -+ GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); -+ -+ switch (a_type) { -+ case GGML_TYPE_F32: -+ case GGML_TYPE_F16: -+ case GGML_TYPE_Q4_0: -+ case GGML_TYPE_Q4_1: -+ case GGML_TYPE_Q5_0: -+ case GGML_TYPE_Q5_1: -+ case GGML_TYPE_Q8_0: -+ case GGML_TYPE_Q2_K: -+ case GGML_TYPE_Q3_K: -+ case GGML_TYPE_Q4_K: -+ case GGML_TYPE_Q5_K: -+ case GGML_TYPE_Q6_K: -+ case GGML_TYPE_IQ4_NL: -+ break; -+ default: -+ return nullptr; -+ } -+ -+ return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1]; -+} -+ -+static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { -+ VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()"); -+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_matmul_id_f32; -+ } -+ if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { -+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_matmul_id_f16_f32.f16acc; -+ } -+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { -+ return ctx->device->pipeline_matmul_id_f16.f16acc; -+ } -+ } else { -+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_matmul_id_f16_f32.f32acc; -+ } -+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { -+ return ctx->device->pipeline_matmul_id_f16.f32acc; -+ } -+ } -+ -+ GGML_ASSERT(src1_type == GGML_TYPE_F32); -+ -+ switch (src0_type) { -+ case GGML_TYPE_Q4_0: -+ case GGML_TYPE_Q4_1: -+ case GGML_TYPE_Q5_0: -+ case GGML_TYPE_Q5_1: -+ case GGML_TYPE_Q8_0: -+ case GGML_TYPE_Q2_K: -+ case GGML_TYPE_Q3_K: -+ case GGML_TYPE_Q4_K: -+ case GGML_TYPE_Q5_K: -+ case GGML_TYPE_Q6_K: -+ case GGML_TYPE_IQ4_NL: -+ break; -+ default: -+ return nullptr; -+ } -+ -+ return ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; -+} -+ -+static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { -+ VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); -+ GGML_ASSERT(b_type == GGML_TYPE_F32); -+ -+ switch (a_type) { -+ case GGML_TYPE_F32: -+ case GGML_TYPE_F16: -+ case GGML_TYPE_Q4_0: -+ case GGML_TYPE_Q4_1: -+ case GGML_TYPE_Q5_0: -+ case GGML_TYPE_Q5_1: -+ case GGML_TYPE_Q8_0: -+ case GGML_TYPE_Q2_K: -+ case GGML_TYPE_Q3_K: -+ case GGML_TYPE_Q4_K: -+ case GGML_TYPE_Q5_K: -+ case GGML_TYPE_Q6_K: -+ case GGML_TYPE_IQ4_NL: -+ break; -+ default: -+ return nullptr; -+ } -+ -+ return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; -+} -+ -+static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { -+ VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")"); -+ VK_LOG_MEMORY("ggml_vk_pool_malloc"); -+ -+ int best_i = -1; -+ size_t best_size = std::numeric_limits::max(); //smallest unused buffer that fits our needs -+ int worst_i = -1; -+ size_t worst_size = 0; //largest unused buffer seen so far -+ for (int i = 0; i < MAX_VK_BUFFERS; ++i) { -+ vk_buffer &b = ctx->buffer_pool[i]; -+ if (b != nullptr && b->size >= size && b->size < best_size) { -+ best_i = i; -+ best_size = b->size; -+ } -+ if (b != nullptr && b->size > worst_size) { -+ worst_i = i; -+ worst_size = b->size; -+ } -+ } -+ if(best_i != -1) { -+ //found the smallest buffer that fits our needs -+ vk_buffer b = ctx->buffer_pool[best_i]; -+ ctx->buffer_pool[best_i].reset(); -+ return b; -+ } -+ if(worst_i != -1) { -+ //no buffer that fits our needs, resize largest one to save memory -+ vk_buffer& b = ctx->buffer_pool[worst_i]; -+ ggml_vk_destroy_buffer(b); -+ } -+ -+ return ggml_vk_create_buffer_device(ctx->device, size); -+} -+ -+static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) { -+ VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")"); -+ for (int i = 0; i < MAX_VK_BUFFERS; ++i) { -+ vk_buffer& b = ctx->buffer_pool[i]; -+ if (b == nullptr) { -+ b = buffer; -+ return; -+ } -+ } -+ std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl; -+ ggml_vk_destroy_buffer(buffer); -+} -+ -+// Returns an available temporary buffer that may only be used temporarily, it will be reused -+static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) { -+ // Try to find existing temp buffer with enough capacity -+ for (auto& buffer : ctx->gc.temp_buffers) { -+ if (buffer->size >= size) { -+ return buffer; -+ } -+ } -+ -+ VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")"); -+ -+ // Otherwise create new buffer -+ vk_buffer buf = ggml_vk_pool_malloc(ctx, size); -+ ctx->gc.temp_buffers.push_back(buf); -+ -+ return buf; -+} -+ -+static void * ggml_vk_host_malloc(vk_device& device, size_t size) { -+ VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); -+ vk_buffer buf = ggml_vk_create_buffer(device, size, -+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, -+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); -+ -+ if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) { -+ fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n", -+ size/1024.0/1024.0); -+ device->device.freeMemory(buf->device_memory); 
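// Editor's note: illustrative sketch only, not part of the patch. It shows the
// pinned-memory bookkeeping pattern used by the ggml_vk_host_malloc /
// ggml_vk_host_get pair in this patch: every pinned host allocation is recorded
// as (base pointer, size, backing buffer), and an arbitrary pointer is later
// resolved back to its buffer plus byte offset by a linear range check.
// The names PinnedPool and Handle are hypothetical stand-ins, not patch symbols.
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <vector>

struct Handle { int id; };  // stand-in for the patch's vk_buffer

struct PinnedPool {
    // (base pointer, size in bytes, backing buffer handle)
    std::vector<std::tuple<void *, size_t, Handle>> ranges;

    void add(void * base, size_t size, Handle h) { ranges.emplace_back(base, size, h); }

    // Returns true and fills buf/offset if ptr falls inside a registered range.
    bool find(const void * ptr, Handle & buf, size_t & offset) const {
        const uint8_t * p = static_cast<const uint8_t *>(ptr);
        for (const auto & [base, size, h] : ranges) {
            const uint8_t * lo = static_cast<const uint8_t *>(base);
            if (p >= lo && p < lo + size) {
                buf    = h;
                offset = static_cast<size_t>(p - lo);
                return true;
            }
        }
        return false;  // not pinned; callers fall back to a staging copy
    }
};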
-+ device->device.destroyBuffer(buf->buffer); -+ return nullptr; -+ } -+ -+ device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); -+ -+ return buf->ptr; -+} -+ -+static void ggml_vk_host_free(vk_device& device, void* ptr) { -+ if (ptr == nullptr) { -+ return; -+ } -+ VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); -+ vk_buffer buf; -+ size_t index; -+ for (size_t i = 0; i < device->pinned_memory.size(); i++) { -+ const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); -+ const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); -+ if (ptr >= addr && ptr < endr) { -+ buf = std::get<2>(device->pinned_memory[i]); -+ index = i; -+ break; -+ } -+ } -+ if (buf == nullptr) { -+ fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n"); -+ return; -+ } -+ -+ ggml_vk_destroy_buffer(buf); -+ -+ device->pinned_memory.erase(device->pinned_memory.begin() + index); -+} -+ -+static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { -+ buf = nullptr; -+ buf_offset = 0; -+ for (size_t i = 0; i < device->pinned_memory.size(); i++) { -+ const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); -+ const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); -+ if (ptr >= addr && ptr < endr) { -+ buf = std::get<2>(device->pinned_memory[i]); -+ buf_offset = ((const uint8_t *)ptr) - addr; -+ break; -+ } -+ } -+} -+ -+static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) { -+ vk_submission s; -+ s.buffer = ggml_vk_create_cmd_buffer(device, q); -+ if (one_time) { -+ s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); -+ } else { -+ s.buffer.begin({ vk::CommandBufferUsageFlags{} }); -+ } -+ -+ return s; -+} -+ -+ -+ -+static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array elements) { -+ const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); -+ const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); -+ const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); -+ VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {"; -+ for (auto& buffer : descriptor_buffer_infos) { -+ std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; -+ } -+ std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); -+ GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size()); -+ GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count); -+ -+ vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++]; -+ vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; -+ ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); -+ -+ subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants); -+ subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); -+ subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, -+ pipeline->layout, -+ 0, -+ { descriptor_set }, -+ {}); -+ subctx->s->buffer.dispatch(wg0, wg1, wg2); -+} -+ -+static void ggml_vk_end_submission(vk_submission& s, 
std::vector wait_semaphores, std::vector signal_semaphores) { -+ s.buffer.end(); -+ -+ s.wait_semaphores = std::move(wait_semaphores); -+ s.signal_semaphores = std::move(signal_semaphores); -+} -+ -+static void ggml_vk_ctx_end(vk_context& ctx) { -+ VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")"); -+ if (ctx->s == nullptr) { -+ return; -+ } -+ -+ ctx->s->buffer.end(); -+ ctx->s = nullptr; -+} -+ -+static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { -+ VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")"); -+ if (subctx->s != nullptr) { -+ ggml_vk_ctx_end(subctx); -+ } -+ -+ subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) }); -+ subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); -+} -+ -+static size_t ggml_vk_align_size(size_t width, size_t align) { -+ VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")"); -+ return CEIL_DIV(width, align) * align; -+} -+ -+static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector* memcpys = nullptr) { -+ if (memcpys == nullptr) { -+ memcpy(dst, src, size); -+ } else { -+ memcpys->emplace_back(dst, src, size); -+ } -+} -+ -+static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) { -+ if (device->sync_staging == nullptr || device->sync_staging->size < size) { -+ VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")"); -+ ggml_vk_destroy_buffer(device->sync_staging); -+ device->sync_staging = ggml_vk_create_buffer_check(device, size, -+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, -+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); -+ } -+} -+ -+static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) { -+ VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")"); -+ GGML_ASSERT(!ggml_is_contiguous(tensor)); -+ // Buffer is already mapped -+ if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { -+ std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." 
<< std::endl; -+ GGML_ABORT("fatal error"); -+ } -+ // Check if src is pinned memory -+ vk_buffer buf = nullptr; -+ size_t buf_offset = 0; -+ ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset); -+ -+ const uint64_t ne0 = tensor->ne[0]; -+ const uint64_t ne1 = tensor->ne[1]; -+ const uint64_t ne2 = tensor->ne[2]; -+ const uint64_t ne3 = tensor->ne[3]; -+ const uint64_t nb0 = tensor->nb[0]; -+ const uint64_t nb1 = tensor->nb[1]; -+ const uint64_t nb2 = tensor->nb[2]; -+ const uint64_t nb3 = tensor->nb[3]; -+ const ggml_type type = tensor->type; -+ const uint64_t ts = ggml_type_size(type); -+ const uint64_t bs = ggml_blck_size(type); -+ -+ const uint64_t dstnb0 = ts; -+ const uint64_t dstnb1 = dstnb0*(ne0/bs); -+ const uint64_t dstnb2 = dstnb1*ne1; -+ const uint64_t dstnb3 = dstnb2*ne2; -+ -+ const uint64_t ne = ggml_nelements(tensor); -+ -+ if (buf != nullptr) { -+ // Memory is pinned, use as staging buffer -+ std::vector slices; -+ -+ for (uint64_t i3 = 0; i3 < ne3; i3++) { -+ for (uint64_t i2 = 0; i2 < ne2; i2++) { -+ // Find longest contiguous slice -+ if (ne1*nb1 == dstnb2) { -+ slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 }); -+ } else { -+ for (uint64_t i1 = 0; i1 < ne1; i1++) { -+ if (ne0*nb0/bs == dstnb1) { -+ slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 }); -+ } else { -+ const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; -+ const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; -+ for (uint64_t i0 = 0; i0 < ne0; i0++) { -+ slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ ggml_vk_sync_buffers(subctx); -+ subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); -+ return; -+ } -+ -+ if (!sync_staging) { -+ GGML_ABORT("Asynchronous write to non-pinned memory not supported"); -+ } -+ -+ // Staging buffer required -+ vk_buffer& staging = ctx->device->sync_staging; -+ const uint64_t copy_size = ts*ne/bs; -+ ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); -+ VkBufferCopy buf_copy{ 0, offset, copy_size }; -+ -+ ggml_vk_sync_buffers(subctx); -+ vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); -+ -+ for (uint64_t i3 = 0; i3 < ne3; i3++) { -+ for (uint64_t i2 = 0; i2 < ne2; i2++) { -+ // Find longest contiguous slice -+ if (ne1*nb1 == dstnb2) { -+ deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys); -+ } else { -+ for (uint64_t i1 = 0; i1 < ne1; i1++) { -+ if (ne0*nb0/bs == dstnb1) { -+ deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys); -+ } else { -+ const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; -+ const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1; -+ for (uint64_t i0 = 0; i0 < ne0; i0++) { -+ deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys); -+ } -+ } -+ } -+ } -+ } -+ } -+} -+ -+static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { -+ VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); -+ // Buffer is already mapped -+ 
if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { -+ std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl; -+ GGML_ABORT("fatal error"); -+ } -+ // Check if src is pinned memory -+ vk_buffer buf = nullptr; -+ size_t buf_offset = 0; -+ ggml_vk_host_get(dst->device, src, buf, buf_offset); -+ -+ if (buf != nullptr) { -+ // Memory is pinned, use as staging buffer -+ std::vector slices(1); -+ if (width == spitch) { -+ // Only do single write if stride is equal -+ slices[0].srcOffset = buf_offset; -+ slices[0].dstOffset = offset; -+ slices[0].size = width * height; -+ } else { -+ slices.resize(height); -+ for (size_t i = 0; i < height; i++) { -+ slices[i].srcOffset = buf_offset + i * spitch; -+ slices[i].dstOffset = offset + i * width; -+ slices[i].size = width; -+ } -+ } -+ -+ ggml_vk_sync_buffers(subctx); -+ subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); -+ return; -+ } -+ VK_LOG_DEBUG("STAGING"); -+ -+ if (!sync_staging) { -+ GGML_ABORT("Asynchronous write to non-pinned memory not supported"); -+ } -+ -+ // Staging buffer required -+ const size_t copy_size = width*height; -+ ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); -+ -+ vk_buffer& staging_buffer = dst->device->sync_staging; -+ -+ VkBufferCopy buf_copy = { -+ 0, -+ offset, -+ copy_size}; -+ -+ ggml_vk_sync_buffers(subctx); -+ vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); -+ -+ if (width == spitch) { -+ deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); -+ } else { -+ for (size_t i = 0; i < height; i++) { -+ deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); -+ } -+ } -+} -+ -+static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { -+ VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); -+ return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); -+} -+ -+static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { -+ VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); -+ // Buffer is already mapped -+ if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { -+ GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); -+ -+ for (size_t i = 0; i < height; i++) { -+ memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); -+ } -+ } else { -+ vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); -+ ggml_vk_ctx_begin(dst->device, subctx); -+ ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); -+ ggml_vk_ctx_end(subctx); -+ -+ for (auto& cpy : subctx->in_memcpys) { -+ memcpy(cpy.dst, cpy.src, cpy.n); -+ } -+ -+ ggml_vk_submit(subctx, dst->device->fence); -+ VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); -+ dst->device->device.resetFences({ dst->device->fence }); -+ } -+} -+ -+static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { -+ VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); -+ ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); -+} -+ -+static void 
ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { -+ VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")"); -+ GGML_ASSERT(width > 0); -+ GGML_ASSERT(height > 0); -+ GGML_ASSERT(src != nullptr); -+ -+ // TODO: staging_offset is not used -+ -+ // Check if dst is pinned memory -+ vk_buffer buf = nullptr; -+ size_t buf_offset = 0; -+ ggml_vk_host_get(src->device, dst, buf, buf_offset); -+ -+ std::vector slices(1); -+ if (width == spitch && width == dpitch) { -+ // Only do single write if stride is equal -+ slices[0].srcOffset = offset; -+ slices[0].dstOffset = buf_offset; -+ slices[0].size = width * height; -+ } else { -+ slices.resize(height); -+ for (size_t i = 0; i < height; i++) { -+ slices[i].srcOffset = offset + i * spitch; -+ slices[i].dstOffset = buf_offset + i * dpitch; -+ slices[i].size = width; -+ } -+ } -+ -+ if (buf != nullptr) { -+ // Memory is pinned, use as staging buffer -+ ggml_vk_sync_buffers(subctx); -+ subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); -+ -+ return; -+ } -+ VK_LOG_DEBUG("STAGING"); -+ -+ if (!sync_staging) { -+ GGML_ABORT("Asynchronous read from non-pinned memory not supported"); -+ } -+ -+ // Fall back to staging buffer -+ const size_t copy_size = dpitch * height; -+ ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); -+ -+ vk_buffer& staging_buffer = src->device->sync_staging; -+ -+ ggml_vk_sync_buffers(subctx); -+ subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); -+ -+ deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); -+} -+ -+static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) { -+ return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); -+} -+ -+static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { -+ VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); -+ -+ // If the device is not an UMA device the memory is host-accessible through rebar. While writing -+ // through PCIe is sufficient fast reading back data from PCIe is slower than going through -+ // the HW device to host copy path. 
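// Editor's note: illustrative sketch only, not part of the patch. It spells out
// the non-UMA readback path described in the comment above: record a
// device-to-staging copy, submit it and wait on a fence, then memcpy from the
// persistently mapped staging allocation. All handles are assumed to be created
// elsewhere, the command buffer is assumed to already be in the recording state,
// the staging memory is assumed host-visible and host-coherent, and error
// handling of the VkResult return values is omitted for brevity.
#include <cstdint>
#include <cstring>
#include <vulkan/vulkan.h>

static void read_back_via_staging(VkDevice device, VkQueue queue, VkCommandBuffer cmd,
                                  VkFence fence, VkBuffer src, VkBuffer staging,
                                  void * staging_mapped, VkDeviceSize src_offset,
                                  void * dst, VkDeviceSize size) {
    VkBufferCopy region = { src_offset, 0, size };  // srcOffset, dstOffset, size
    vkCmdCopyBuffer(cmd, src, staging, 1, &region);
    vkEndCommandBuffer(cmd);

    VkSubmitInfo submit = {};
    submit.sType              = VK_STRUCTURE_TYPE_SUBMIT_INFO;
    submit.commandBufferCount = 1;
    submit.pCommandBuffers    = &cmd;
    vkQueueSubmit(queue, 1, &submit, fence);
    vkWaitForFences(device, 1, &fence, VK_TRUE, UINT64_MAX);
    vkResetFences(device, 1, &fence);

    // Host-coherent staging memory needs no explicit invalidate before the copy.
    memcpy(dst, staging_mapped, (size_t) size);
}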
-+ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { -+ GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); -+ -+ memcpy(dst, (uint8_t *) src->ptr + offset, size); -+ } else { -+ vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); -+ ggml_vk_ctx_begin(src->device, subctx); -+ ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); -+ ggml_vk_ctx_end(subctx); -+ -+ ggml_vk_submit(subctx, src->device->fence); -+ VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); -+ src->device->device.resetFences({ src->device->fence }); -+ -+ for (auto& cpy : subctx->out_memcpys) { -+ memcpy(cpy.dst, cpy.src, cpy.n); -+ } -+ } -+} -+ -+static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { -+ VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); -+ // Make sure both buffers are on same device -+ GGML_ASSERT(src->device == dst->device); -+ -+ VkBufferCopy bc{ src_offset, dst_offset, size }; -+ -+ vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); -+} -+ -+static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { -+ if (src->device == dst->device) { -+ VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); -+ // Copy within the device -+ vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); -+ ggml_vk_ctx_begin(src->device, subctx); -+ ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); -+ ggml_vk_ctx_end(subctx); -+ ggml_vk_submit(subctx, src->device->fence); -+ VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); -+ src->device->device.resetFences({ src->device->fence }); -+ } else { -+ VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); -+ // Copy device to device -+ ggml_vk_ensure_sync_staging_buffer(src->device, size); -+ ggml_vk_ensure_sync_staging_buffer(dst->device, size); -+ -+ // Copy to src staging buffer -+ ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); -+ // memcpy to dst staging buffer -+ memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size); -+ // Copy to dst buffer -+ ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size); -+ } -+} -+ -+static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { -+ VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); -+ -+ vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); -+ ggml_vk_ctx_begin(dst->device, subctx); -+ subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); -+ ggml_vk_ctx_end(subctx); -+ -+ ggml_vk_submit(subctx, dst->device->fence); -+ VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); -+ dst->device->device.resetFences({ dst->device->fence }); -+} -+ -+static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) { -+ VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); -+ -+ uint32_t split_k = 1; -+ if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) { -+ // If k is 
'large' and the SMs will fill less than halfway, use split_k. -+ uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); -+ uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); -+ if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) { -+ split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); -+ // Clamp to 2 or 4 -+ split_k = std::min(split_k, 4u); -+ if (split_k == 3) { -+ split_k = 2; -+ } -+ } -+ } -+ -+ return split_k; -+} -+ -+static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { -+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); -+ -+ if (ctx->device->coopmat2) { -+ if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) { -+ return aligned ? mmp->a_l : mmp->l; -+ } -+ if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) { -+ return aligned ? mmp->a_m : mmp->m; -+ } -+ return aligned ? mmp->a_s : mmp->s; -+ } -+ -+ if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) { -+ return aligned ? mmp->a_s : mmp->s; -+ } -+ if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) { -+ return aligned ? mmp->a_m : mmp->m; -+ } -+ return aligned ? mmp->a_l : mmp->l; -+} -+ -+static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { -+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); -+ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align; -+} -+ -+static void ggml_vk_matmul( -+ ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, -+ vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, -+ uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, -+ uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, -+ uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) { -+ VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? 
split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")"); -+ ggml_vk_sync_buffers(subctx); -+ if (split_k == 1) { -+ const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 }; -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); -+ return; -+ } -+ -+ GGML_ASSERT(batch_stride_d == m * n); -+ -+ const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 }; -+ // Make sure enough workgroups get assigned for split k to work -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); -+ ggml_vk_sync_buffers(subctx); -+ const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; -+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); -+} -+ -+static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { -+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); -+ -+ if (ctx->device->coopmat2) { -+ if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) { -+ return aligned ? mmp->a_l : mmp->l; -+ } -+ if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) { -+ return aligned ? mmp->a_m : mmp->m; -+ } -+ return aligned ? mmp->a_s : mmp->s; -+ } -+ -+ if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) { -+ return aligned ? mmp->a_s : mmp->s; -+ } -+ if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) { -+ return aligned ? mmp->a_m : mmp->m; -+ } -+ return aligned ? 
mmp->a_l : mmp->l; -+} -+ -+static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { -+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); -+ return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align; -+} -+ -+static void ggml_vk_matmul_id( -+ ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, -+ vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, -+ uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, -+ uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, -+ uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) { -+ VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << -+ "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << -+ "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << -+ "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); -+ ggml_vk_sync_buffers(subctx); -+ const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, -+ nei0, nei1, nbi1, ne11 }; -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as }); -+} -+ -+static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { -+ return -+ tensor->nb[0] == ggml_type_size(tensor->type) && -+ tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && -+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; -+} -+ -+static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { -+ -+ // Choose "contiguous copy" shader if src/dst are contiguous -+ bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); -+ -+ if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { -+ if (contig) { -+ return ctx->device->pipeline_contig_cpy_f32_f32; -+ } else { -+ return ctx->device->pipeline_cpy_f32_f32; -+ } -+ } -+ if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) { -+ if (contig) { -+ return ctx->device->pipeline_contig_cpy_f32_f16; -+ } else { -+ return ctx->device->pipeline_cpy_f32_f16; -+ } -+ } -+ if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) { -+ if (contig) { -+ return ctx->device->pipeline_contig_cpy_f16_f16; -+ } else { -+ return ctx->device->pipeline_cpy_f16_f16; -+ } -+ } -+ -+ std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; -+ GGML_ABORT("fatal error"); -+} -+ -+static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) { -+ VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" 
<< tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), "; -+ std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); -+ const int tensor_type_size = ggml_type_size(tensor->type); -+ -+ const uint32_t ne = ggml_nelements(tensor); -+ std::array elements; -+ -+ if (ne > 262144) { -+ elements = { 512, 512, CEIL_DIV(ne, 262144) }; -+ } else if (ne > 512) { -+ elements = { 512, CEIL_DIV(ne, 512), 1 }; -+ } else { -+ elements = { ne, 1, 1 }; -+ } -+ -+ vk_op_unary_push_constants pc = { -+ (uint32_t)ne, -+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, -+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]), -+ 0, -+ 0.0f, 0.0f, -+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -+ }; -+ init_pushconst_fastdiv(pc); -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); -+} -+ -+static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; -+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; -+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; -+ std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); -+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT -+ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT -+ -+ const uint64_t ne00 = src0->ne[0]; -+ const uint64_t ne01 = src0->ne[1]; -+ const uint64_t ne02 = src0->ne[2]; -+ const uint64_t ne03 = src0->ne[3]; -+ -+ const uint64_t ne10 = src1->ne[0]; -+ const uint64_t ne11 = src1->ne[1]; -+ const uint64_t ne12 = src1->ne[2]; -+ const uint64_t ne13 = src1->ne[3]; -+ -+ const uint64_t ne20 = dst->ne[0]; -+ const uint64_t ne21 = dst->ne[1]; -+ -+ const uint64_t r2 = ne12 / ne02; -+ const uint64_t r3 = ne13 / ne03; -+ -+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; -+ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; -+ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; -+ -+ vk_buffer d_Qx = nullptr; -+ size_t qx_buf_offset = 0; -+ vk_buffer d_Qy = nullptr; -+ size_t qy_buf_offset = 0; -+ -+ bool src0_uma = false; -+ bool src1_uma = false; -+ -+ if (ctx->device->uma) { -+ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); -+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); -+ src0_uma = d_Qx != nullptr; -+ src1_uma = d_Qy != nullptr; -+ } -+ -+ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); -+ // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf -+ const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || -+ !ggml_vk_dim01_contiguous(src1); -+ -+ const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; -+ -+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); -+ -+ const bool qx_needs_dequant = mmp == nullptr || x_non_contig; -+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; -+ -+ if (qx_needs_dequant) { -+ // Fall back to dequant + f16 mulmat -+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); -+ } -+ -+ // Not implemented -+ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT -+ -+ const int x_ne = ne01 * ne00; -+ const int y_ne = ne11 * ne10; -+ const int d_ne = ne11 * ne01; -+ -+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); -+ const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; -+ -+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned); -+ -+ const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); -+ -+ const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); -+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); -+ const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; -+ const uint64_t y_sz = y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; -+ const uint64_t d_sz = sizeof(float) * d_ne; -+ -+ vk_pipeline to_fp16_vk_0 = nullptr; -+ vk_pipeline to_fp16_vk_1 = nullptr; -+ -+ if (x_non_contig) { -+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); -+ } else { -+ to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); -+ } -+ if (y_non_contig) { -+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); -+ } else { -+ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); -+ } -+ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT -+ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT -+ -+ if (dryrun) { -+ const uint64_t x_sz_upd = x_sz * ne02 * ne03; -+ const uint64_t y_sz_upd = y_sz * ne12 * ne13; -+ const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; -+ if ( -+ (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || -+ (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || -+ (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) { -+ GGML_ABORT("Requested preallocation size is too large"); -+ } -+ if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { -+ ctx->prealloc_size_x = x_sz_upd; -+ } -+ if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { -+ ctx->prealloc_size_y = y_sz_upd; -+ } -+ if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { -+ ctx->prealloc_size_split_k = split_k_size; -+ } -+ -+ // Request descriptor sets -+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); -+ if (qx_needs_dequant) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); -+ } -+ if (qy_needs_dequant) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); -+ } -+ if (split_k > 1) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); -+ } -+ return; -+ } -+ -+ vk_buffer d_D = dst_buf_ctx->dev_buffer; -+ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; -+ GGML_ASSERT(d_D != nullptr); -+ GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); -+ vk_buffer d_X; -+ uint64_t x_buf_offset = 0; -+ vk_buffer d_Y; -+ uint64_t y_buf_offset = 0; -+ if (!src0_uma) { -+ d_Qx = src0_buf_ctx->dev_buffer; -+ qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; -+ GGML_ASSERT(d_Qx != nullptr); -+ } -+ if (!src1_uma) { -+ d_Qy = src1_buf_ctx->dev_buffer; -+ qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; -+ GGML_ASSERT(d_Qy != nullptr); -+ } -+ if (qx_needs_dequant) { -+ d_X = ctx->prealloc_x; -+ GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); -+ } else { -+ d_X = d_Qx; -+ x_buf_offset = qx_buf_offset; -+ GGML_ASSERT(qx_sz == x_sz); -+ } -+ if (qy_needs_dequant) { -+ d_Y = ctx->prealloc_y; -+ GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); -+ } else { -+ d_Y = d_Qy; -+ y_buf_offset = qy_buf_offset; -+ GGML_ASSERT(qy_sz == y_sz); -+ } -+ -+ if (x_non_contig) { -+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); -+ } else if (qx_needs_dequant) { -+ const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * 
ne02 * ne03), 1, 1}); -+ } -+ if (y_non_contig) { -+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); -+ } -+ -+ uint32_t stride_batch_x = ne00*ne01; -+ uint32_t stride_batch_y = ne10*ne11; -+ -+ if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { -+ stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); -+ } -+ -+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { -+ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); -+ } -+ -+ // compute -+ ggml_vk_matmul( -+ ctx, subctx, pipeline, -+ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, -+ { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, -+ ne01, ne11, ne10, -+ ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, -+ split_k, ne12*ne13, ne02, ne12, r2, r3 -+ ); // NOLINT -+} -+ -+static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; -+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; -+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; -+ std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)"); -+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT -+ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT -+ -+ const uint64_t ne00 = src0->ne[0]; -+ const uint64_t ne01 = src0->ne[1]; -+ const uint64_t ne02 = src0->ne[2]; -+ const uint64_t ne03 = src0->ne[3]; -+ -+ const uint64_t ne10 = src1->ne[0]; -+ const uint64_t ne11 = src1->ne[1]; -+ const uint64_t ne12 = src1->ne[2]; -+ const uint64_t ne13 = src1->ne[3]; -+ -+ const uint64_t ne20 = dst->ne[0]; -+ const uint64_t ne21 = dst->ne[1]; -+ const uint64_t ne22 = dst->ne[2]; -+ const uint64_t ne23 = dst->ne[3]; -+ -+ const uint64_t r2 = ne12 / ne02; -+ const uint64_t r3 = ne13 / ne03; -+ -+ // batch_n indicates that we need to compute a few vector results, and this assumes -+ // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides. 
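The comment above describes the batch_n trick used by this matrix-vector path: when ne11 > 1 but ne12 = ne13 = 1, a single dispatch produces several vector results, so the batch strides are repurposed as row strides (A is shared across results, while B and D advance by one row each). Below is a minimal standalone sketch of that stride selection in plain C++; the struct and function names are hypothetical, and only the ne* extents correspond to the surrounding code, which performs the equivalent stride_batch_x/y/d assignments a little further down.

#include <cstdint>
#include <cstdio>

// Hypothetical container for the three batch strides passed to the shader.
struct vec_strides { uint32_t batch_x, batch_y, batch_d; };

// When batch_n is set, A is reused for every column of B, and the "batch"
// strides of B and D simply step one row per result.
static vec_strides pick_strides(uint64_t ne00, uint64_t ne01,
                                uint64_t ne10, uint64_t ne11,
                                uint64_t ne20, uint64_t ne21) {
    const bool batch_n = ne11 > 1; // several vector results in one dispatch
    vec_strides s;
    s.batch_x = batch_n ? 0u : (uint32_t)(ne00 * ne01);
    s.batch_y = batch_n ? (uint32_t)ne10 : (uint32_t)(ne10 * ne11);
    s.batch_d = batch_n ? (uint32_t)ne20 : (uint32_t)(ne20 * ne21);
    return s;
}

int main() {
    // Example: 4 columns of B with ne12 = ne13 = 1, so the batch strides
    // degenerate into the row strides of B and D.
    vec_strides s = pick_strides(4096, 4096, 4096, 4, 4096, 4);
    std::printf("batch_x=%u batch_y=%u batch_d=%u\n", s.batch_x, s.batch_y, s.batch_d);
    return 0;
}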
-+ GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1); -+ bool batch_n = ne11 > 1; -+ -+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; -+ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; -+ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; -+ -+ vk_buffer d_Qx = nullptr; -+ size_t qx_buf_offset = 0; -+ vk_buffer d_Qy = nullptr; -+ size_t qy_buf_offset = 0; -+ -+ bool src0_uma = false; -+ bool src1_uma = false; -+ -+ if (ctx->device->uma) { -+ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); -+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); -+ src0_uma = d_Qx != nullptr; -+ src1_uma = d_Qy != nullptr; -+ } -+ -+ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); -+ const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); -+ -+ const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; -+ -+ const bool qx_needs_dequant = x_non_contig; -+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; -+ -+ // Not implemented -+ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT -+ -+ const uint64_t x_ne = ne01 * ne00; -+ const uint64_t y_ne = ne11 * ne10; -+ const uint64_t d_ne = ne11 * ne01; -+ -+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); -+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); -+ const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; -+ const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; -+ const uint64_t d_sz = sizeof(float) * d_ne; -+ -+ vk_pipeline to_fp16_vk_0 = nullptr; -+ vk_pipeline to_fp16_vk_1 = nullptr; -+ if (x_non_contig) { -+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); -+ } -+ if (y_non_contig) { -+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); -+ } else { -+ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); -+ } -+ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11); -+ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT -+ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT -+ GGML_ASSERT(dmmv != nullptr); -+ -+ if (dryrun) { -+ const uint64_t x_sz_upd = x_sz * ne02 * ne03; -+ const uint64_t y_sz_upd = y_sz * ne12 * ne13; -+ if ( -+ (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || -+ (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { -+ GGML_ABORT("Requested preallocation size is too large"); -+ } -+ if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { -+ ctx->prealloc_size_x = x_sz_upd; -+ } -+ if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { -+ ctx->prealloc_size_y = y_sz_upd; -+ } -+ -+ // Request descriptor sets -+ if (qx_needs_dequant) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); -+ } -+ if (qy_needs_dequant) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); -+ } -+ ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); -+ return; -+ } -+ -+ vk_buffer d_D = dst_buf_ctx->dev_buffer; -+ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; -+ GGML_ASSERT(d_D != nullptr); -+ vk_buffer d_X; -+ uint64_t x_buf_offset = 0; -+ vk_buffer d_Y; -+ uint64_t y_buf_offset = 0; -+ if(!src0_uma) { -+ d_Qx = src0_buf_ctx->dev_buffer; -+ qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; -+ GGML_ASSERT(d_Qx != nullptr); -+ } -+ if(!src1_uma) { -+ d_Qy = src1_buf_ctx->dev_buffer; -+ qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; -+ GGML_ASSERT(d_Qy != nullptr); -+ } -+ if (qx_needs_dequant) { -+ d_X = ctx->prealloc_x; -+ } else { -+ d_X = d_Qx; -+ x_buf_offset = qx_buf_offset; -+ GGML_ASSERT(qx_sz == x_sz); -+ } -+ if (qy_needs_dequant) { -+ d_Y = ctx->prealloc_y; -+ } else { -+ d_Y = d_Qy; -+ y_buf_offset = qy_buf_offset; -+ GGML_ASSERT(qy_sz == y_sz); -+ } -+ -+ if (x_non_contig) { -+ GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); -+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); -+ } -+ if (y_non_contig) { -+ GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); -+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); -+ } -+ -+ // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride -+ uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01; -+ uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11); -+ uint32_t stride_batch_d = batch_n ? 
ne20 : (ne20*ne21); -+ -+ if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { -+ stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); -+ } -+ -+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { -+ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); -+ } -+ -+ const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; -+ -+ uint32_t groups_x = ne01; -+ uint32_t groups_z = 1; -+ -+ if (ne01 > max_groups_x) { -+ groups_z = 64; -+ groups_x = CEIL_DIV(groups_x, groups_z); -+ } -+ -+ // compute -+ const vk_mat_vec_push_constants pc = { -+ (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, -+ stride_batch_x, stride_batch_y, stride_batch_d, -+ (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, -+ }; -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, -+ { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, -+ sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); -+} -+ -+static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; -+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; -+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; -+ std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); -+ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); -+ GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT -+ GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT -+ GGML_ASSERT(src0->type == GGML_TYPE_F16); -+ GGML_ASSERT(src1->type == GGML_TYPE_F32); -+ -+ const uint64_t ne00 = src0->ne[0]; -+ const uint64_t ne01 = src0->ne[1]; -+ const uint64_t ne02 = src0->ne[2]; -+ // const uint64_t ne03 = src0->ne[3]; -+ -+ const uint64_t ne10 = src1->ne[0]; -+ const uint64_t ne11 = src1->ne[1]; -+ const uint64_t ne12 = src1->ne[2]; -+ // const uint64_t ne13 = src1->ne[3]; -+ -+ GGML_ASSERT(ne11 == 1); -+ -+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; -+ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; -+ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; -+ -+ vk_buffer d_Qy = nullptr; -+ size_t qy_buf_offset = 0; -+ -+ bool src1_uma = false; -+ -+ if (ctx->device->uma) { -+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); -+ src1_uma = d_Qy != nullptr; -+ } -+ -+ const uint64_t x_ne = ne00 * ne01 * ne02; -+ const uint64_t y_ne = ne10 * ne11 * ne12; -+ const uint64_t d_ne = ne01 * ne11 * ne12; -+ -+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); -+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); -+ const uint64_t d_sz = sizeof(float) * d_ne; -+ -+ if (dryrun) { -+ // Request descriptor sets -+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1); -+ return; -+ } -+ -+ vk_buffer d_D = dst_buf_ctx->dev_buffer; -+ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; -+ GGML_ASSERT(d_D != nullptr); -+ vk_buffer d_Qx = src0_buf_ctx->dev_buffer; -+ const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; -+ GGML_ASSERT(d_Qx != nullptr); -+ if (!src1_uma) { -+ d_Qy = src1_buf_ctx->dev_buffer; -+ qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; -+ GGML_ASSERT(d_Qx != nullptr); -+ } -+ -+ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; -+ const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; -+ -+ const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; -+ const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; -+ -+ // compute -+ const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); -+} -+ -+static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * 
src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; -+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; -+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; -+ std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); -+ GGML_ASSERT(!ggml_is_transposed(src0)); -+ GGML_ASSERT(!ggml_is_transposed(src1)); -+ GGML_ASSERT(!ggml_is_permuted(src0)); -+ GGML_ASSERT(src0->type == GGML_TYPE_F16); -+ GGML_ASSERT(src1->type == GGML_TYPE_F32); -+ -+ const uint64_t ne00 = src0->ne[0]; -+ const uint64_t ne01 = src0->ne[1]; -+ const uint64_t ne02 = src0->ne[2]; -+ // const uint64_t ne03 = src0->ne[3]; -+ -+ const uint64_t nb01 = src0->nb[1]; -+ const uint64_t nb02 = src0->nb[2]; -+ -+ // const uint64_t ne10 = src1->ne[0]; -+ const uint64_t ne11 = src1->ne[1]; -+ const uint64_t ne12 = src1->ne[2]; -+ // const uint64_t ne13 = src1->ne[3]; -+ -+ GGML_ASSERT(ne11 == 1); -+ -+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; -+ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; -+ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; -+ -+ vk_buffer d_Qy = nullptr; -+ size_t qy_buf_offset = 0; -+ -+ bool src1_uma = false; -+ -+ if (ctx->device->uma) { -+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); -+ src1_uma = d_Qy != nullptr; -+ } -+ -+ const uint64_t d_ne = ne01 * ne11 * ne12; -+ -+ const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); -+ const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); -+ -+ const uint64_t qx_sz = ggml_nbytes(src0); -+ const uint64_t qy_sz = ggml_nbytes(src1); -+ const uint64_t d_sz = sizeof(float) * d_ne; -+ -+ if (dryrun) { -+ // Request descriptor sets -+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); -+ return; -+ } -+ -+ vk_buffer d_D = dst_buf_ctx->dev_buffer; -+ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; -+ GGML_ASSERT(d_D != nullptr); -+ vk_buffer d_Qx = src0_buf_ctx->dev_buffer; -+ const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; -+ GGML_ASSERT(d_Qx != nullptr); -+ if (!src1_uma) { -+ d_Qy = src1_buf_ctx->dev_buffer; -+ qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; -+ GGML_ASSERT(d_Qx != nullptr); -+ } -+ -+ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; -+ const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; -+ -+ const uint64_t d_buffer_offset = (d_buf_offset / 
ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; -+ const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; -+ -+ // compute -+ const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, -+ { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); -+} -+ -+static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); -+ if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && -+ // detect 0213 permutation, and batch size of 1 -+ src0->nb[0] <= src0->nb[2] && -+ src0->nb[2] <= src0->nb[1] && -+ src0->nb[1] <= src0->nb[3] && -+ src1->nb[0] <= src1->nb[2] && -+ src1->nb[2] <= src1->nb[1] && -+ src1->nb[1] <= src1->nb[3] && -+ src0->ne[3] == 1 && -+ src1->ne[3] == 1) { -+ ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun); -+ } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && -+ !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { -+ ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun); -+ // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) -+ // when ne12 and ne13 are one. 
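The comment above states when the batched matrix-vector kernel applies, and the branch that follows combines it with the two specialised f16 x f32 vector paths. The sketch below condenses that routing into one function; the tensor inspections (types, permutations, contiguity) are replaced by precomputed flags, and max_vec_cols is a hypothetical stand-in for the mul_mat_vec_max_cols limit referenced in the next branch.

#include <cstdint>
#include <cstdio>

// Which specialised kernel the multiplication is routed to.
enum class mm_path { p021_f16_f32, nc_f16_f32, mat_vec, mat_mat };

static mm_path route(bool src0_f16, bool perm_0213_batch1, bool src0_noncontig_unpermuted,
                     bool src0_type_supported, uint64_t dst_cols, uint64_t ne12_x_ne13,
                     uint32_t max_vec_cols) {
    if (src0_f16 && perm_0213_batch1 && dst_cols == 1) {
        return mm_path::p021_f16_f32;   // fully permuted f16 x f32, single output column
    }
    if (src0_f16 && src0_noncontig_unpermuted && dst_cols == 1) {
        return mm_path::nc_f16_f32;     // non-contiguous f16 x f32, single output column
    }
    if ((dst_cols == 1 || (dst_cols <= max_vec_cols && ne12_x_ne13 == 1)) && src0_type_supported) {
        return mm_path::mat_vec;        // (batched) matrix-vector kernel
    }
    return mm_path::mat_mat;            // general matrix-matrix kernel
}

int main() {
    // Three output columns, no broadcast batches: still eligible for mat_vec.
    mm_path p = route(false, false, false, true, /*dst_cols=*/3, /*ne12*ne13=*/1, /*max_vec_cols=*/8);
    std::printf("path=%d\n", (int)p); // prints 2 (mat_vec)
    return 0;
}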
-+ } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && -+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { -+ ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); -+ } else { -+ ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); -+ } -+} -+ -+static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { -+ VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; -+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; -+ std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; -+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); -+ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT -+ GGML_ASSERT(ids->type == GGML_TYPE_I32); -+ -+ const uint64_t ne00 = src0->ne[0]; -+ const uint64_t ne01 = src0->ne[1]; -+ const uint64_t ne02 = src0->ne[2]; -+ const uint64_t ne03 = src0->ne[3]; -+ -+ const uint64_t ne10 = src1->ne[0]; -+ const uint64_t ne11 = src1->ne[1]; -+ const uint64_t ne12 = src1->ne[2]; -+ const uint64_t ne13 = src1->ne[3]; -+ -+ const uint64_t nei0 = ids->ne[0]; -+ const uint64_t nei1 = ids->ne[1]; -+ GGML_ASSERT(nei0 * nei1 <= 3072); -+ -+ const uint32_t nbi1 = ids->nb[1]; -+ const uint32_t nbi2 = ids->nb[2]; -+ -+ const uint64_t ne20 = dst->ne[0]; -+ const uint64_t ne21 = dst->ne[1]; -+ const uint64_t ne22 = dst->ne[2]; -+ const uint64_t ne23 = dst->ne[3]; -+ -+ const uint64_t n_as = ne02; -+ -+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; -+ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; -+ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; -+ ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; -+ -+ vk_buffer d_Qx = nullptr; -+ size_t qx_buf_offset = 0; -+ vk_buffer d_Qy = nullptr; -+ size_t qy_buf_offset = 0; -+ vk_buffer d_ids = nullptr; -+ size_t ids_buf_offset = 0; -+ -+ bool src0_uma = false; -+ bool src1_uma = false; -+ bool ids_uma = false; -+ -+ if (ctx->device->uma) { -+ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); -+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); -+ ggml_vk_host_get(ctx->device, ids->data, d_ids, 
ids_buf_offset); -+ src0_uma = d_Qx != nullptr; -+ src1_uma = d_Qy != nullptr; -+ ids_uma = d_ids != nullptr; -+ } -+ -+ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); -+ const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); -+ -+ const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; -+ -+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); -+ -+ const bool qx_needs_dequant = mmp == nullptr || x_non_contig; -+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; -+ -+ if (qx_needs_dequant) { -+ GGML_ABORT("fatal error"); -+ } -+ -+ // Not implemented -+ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT -+ -+ const uint64_t x_ne = ne01 * ne00; -+ const uint64_t y_ne = ne11 * ne10; -+ const uint64_t d_ne = ne21 * ne20; -+ -+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1)); -+ const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; -+ -+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned); -+ -+ const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); -+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); -+ const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; -+ const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; -+ const uint64_t ids_sz = nbi2; -+ const uint64_t d_sz = sizeof(float) * d_ne; -+ -+ vk_pipeline to_fp16_vk_0 = nullptr; -+ vk_pipeline to_fp16_vk_1 = nullptr; -+ -+ if (x_non_contig) { -+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); -+ } else { -+ to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); -+ } -+ if (y_non_contig) { -+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); -+ } else { -+ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); -+ } -+ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT -+ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT -+ -+ if (dryrun) { -+ const uint64_t x_sz_upd = x_sz * ne02 * ne03; -+ const uint64_t y_sz_upd = y_sz * ne12 * ne13; -+ if ( -+ (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || -+ (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { -+ GGML_ABORT("Requested preallocation size is too large"); -+ } -+ if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { -+ ctx->prealloc_size_x = x_sz_upd; -+ } -+ if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { -+ ctx->prealloc_size_y = y_sz_upd; -+ } -+ -+ // Request descriptor sets -+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); -+ if (qx_needs_dequant) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); -+ } -+ if (qy_needs_dequant) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); -+ } -+ return; -+ } -+ -+ vk_buffer d_D = dst_buf_ctx->dev_buffer; -+ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; -+ GGML_ASSERT(d_D != nullptr); -+ vk_buffer d_X; -+ uint64_t x_buf_offset = 0; -+ vk_buffer d_Y; -+ uint64_t y_buf_offset = 0; -+ if (!src0_uma) { -+ d_Qx = src0_buf_ctx->dev_buffer; -+ qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; -+ GGML_ASSERT(d_Qx != nullptr); -+ } -+ if (!src1_uma) { -+ d_Qy = src1_buf_ctx->dev_buffer; -+ qy_buf_offset = 
vk_tensor_offset(src1) + src1->view_offs; -+ GGML_ASSERT(d_Qy != nullptr); -+ } -+ if (!ids_uma) { -+ d_ids = ids_buf_ctx->dev_buffer; -+ ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; -+ GGML_ASSERT(d_ids != nullptr); -+ } -+ if (qx_needs_dequant) { -+ d_X = ctx->prealloc_x; -+ GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); -+ } else { -+ d_X = d_Qx; -+ x_buf_offset = qx_buf_offset; -+ GGML_ASSERT(qx_sz == x_sz); -+ } -+ if (qy_needs_dequant) { -+ d_Y = ctx->prealloc_y; -+ GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); -+ } else { -+ d_Y = d_Qy; -+ y_buf_offset = qy_buf_offset; -+ GGML_ASSERT(qy_sz == y_sz); -+ } -+ -+ if (x_non_contig) { -+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); -+ } else if (qx_needs_dequant) { -+ const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, -+ { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); -+ } -+ if (y_non_contig) { -+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); -+ } -+ -+ uint32_t stride_batch_x = ne00*ne01; -+ uint32_t stride_batch_y = ne10*ne11; -+ -+ if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { -+ stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); -+ } -+ -+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { -+ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); -+ } -+ -+ // compute -+ ggml_vk_matmul_id( -+ ctx, subctx, pipeline, -+ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, -+ { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, -+ ne01, ne21, ne10, ne10, ne10, ne01, -+ stride_batch_x, stride_batch_y, ne20*ne21, -+ n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11 -+ ); // NOLINT -+} -+ -+static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { -+ VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; -+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; -+ std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; -+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" 
<< dst->nb[3]; -+ std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); -+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT -+ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT -+ GGML_ASSERT(ids->type == GGML_TYPE_I32); -+ -+ const uint64_t ne00 = src0->ne[0]; -+ const uint64_t ne01 = src0->ne[1]; -+ const uint64_t ne02 = src0->ne[2]; -+ const uint64_t ne03 = src0->ne[3]; -+ -+ const uint64_t ne10 = src1->ne[0]; -+ const uint64_t ne11 = src1->ne[1]; -+ const uint64_t ne12 = src1->ne[2]; -+ const uint64_t ne13 = src1->ne[3]; -+ -+ const uint64_t nei0 = ids->ne[0]; -+ const uint64_t nei1 = ids->ne[1]; -+ -+ const uint64_t nbi2 = ids->nb[2]; -+ -+ GGML_ASSERT(nei1 == 1); -+ -+ const uint64_t ne20 = dst->ne[0]; -+ const uint64_t ne21 = dst->ne[1]; -+ const uint64_t ne22 = dst->ne[2]; -+ const uint64_t ne23 = dst->ne[3]; -+ -+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; -+ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; -+ ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; -+ ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; -+ -+ vk_buffer d_Qx = nullptr; -+ size_t qx_buf_offset = 0; -+ vk_buffer d_Qy = nullptr; -+ size_t qy_buf_offset = 0; -+ vk_buffer d_ids = nullptr; -+ size_t ids_buf_offset = 0; -+ -+ bool src0_uma = false; -+ bool src1_uma = false; -+ bool ids_uma = false; -+ -+ if (ctx->device->uma) { -+ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); -+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); -+ ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); -+ src0_uma = d_Qx != nullptr; -+ src1_uma = d_Qy != nullptr; -+ ids_uma = d_ids != nullptr; -+ } -+ -+ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); -+ const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); -+ -+ const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; -+ -+ const bool qx_needs_dequant = x_non_contig; -+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; -+ -+ // Not implemented -+ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT -+ -+ const uint64_t x_ne = ne01 * ne00; -+ const uint64_t y_ne = ne11 * ne10; -+ const uint64_t d_ne = ne21 * ne20; -+ -+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); -+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); -+ const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; -+ const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; -+ const uint64_t ids_sz = nbi2; -+ const uint64_t d_sz = sizeof(float) * d_ne; -+ -+ vk_pipeline to_fp16_vk_0 = nullptr; -+ vk_pipeline to_fp16_vk_1 = nullptr; -+ if (x_non_contig) { -+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); -+ } -+ if (y_non_contig) { -+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); -+ } else { -+ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); -+ } -+ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); -+ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT -+ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT -+ GGML_ASSERT(dmmv != nullptr); -+ -+ if (dryrun) { -+ const uint64_t x_sz_upd = x_sz * ne02 * ne03; -+ const uint64_t y_sz_upd = y_sz * ne12 * ne13; -+ if ( -+ (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || -+ (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { -+ GGML_ABORT("Requested preallocation size is too large"); -+ } -+ if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { -+ ctx->prealloc_size_x = x_sz_upd; -+ } -+ if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { -+ ctx->prealloc_size_y = y_sz_upd; -+ } -+ -+ // Request descriptor sets -+ if (qx_needs_dequant) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); -+ } -+ if (qy_needs_dequant) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); -+ } -+ ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); -+ return; -+ } -+ -+ vk_buffer d_D = dst_buf_ctx->dev_buffer; -+ const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; -+ GGML_ASSERT(d_D != nullptr); -+ vk_buffer d_X; -+ uint64_t x_buf_offset = 0; -+ vk_buffer d_Y; -+ uint64_t y_buf_offset = 0; -+ if(!src0_uma) { -+ d_Qx = src0_buf_ctx->dev_buffer; -+ qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; -+ GGML_ASSERT(d_Qx != nullptr); -+ } -+ if(!src1_uma) { -+ d_Qy = src1_buf_ctx->dev_buffer; -+ qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; -+ GGML_ASSERT(d_Qy != nullptr); -+ } -+ if(!ids_uma) { -+ d_ids = ids_buf_ctx->dev_buffer; -+ ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; -+ GGML_ASSERT(d_ids != nullptr); -+ } -+ if (qx_needs_dequant) { -+ d_X = ctx->prealloc_x; -+ } else { -+ d_X = d_Qx; -+ x_buf_offset = qx_buf_offset; -+ GGML_ASSERT(qx_sz == x_sz); -+ } -+ if (qy_needs_dequant) { -+ d_Y = ctx->prealloc_y; -+ } else { -+ d_Y = d_Qy; -+ y_buf_offset = qy_buf_offset; -+ GGML_ASSERT(qy_sz == y_sz); -+ } -+ -+ if (x_non_contig) { -+ GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); -+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); -+ } -+ if (y_non_contig) { -+ GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); -+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); -+ } -+ -+ uint32_t stride_batch_y = ne10*ne11; -+ -+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { -+ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); -+ } -+ -+ const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; -+ -+ uint32_t groups_x = ne01; -+ uint32_t groups_z = 1; -+ -+ if (ne01 > max_groups_x) { -+ groups_z = 64; -+ groups_x = 
CEIL_DIV(groups_x, groups_z); -+ } -+ -+ // compute -+ const vk_mat_vec_id_push_constants pc = { -+ (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, -+ (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), -+ (uint32_t)nei0, (uint32_t)ne11, -+ }; -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, -+ { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, -+ vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, -+ sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z }); -+} -+ -+static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { -+ VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")"); -+ if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { -+ ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); -+ } else { -+ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); -+ } -+} -+ -+static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { -+ VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; -+ std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; -+ std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; -+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; -+ std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); -+ -+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne) -+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb) -+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne) -+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb) -+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne) -+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb) -+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) -+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb) -+ -+ const uint32_t nem1 = mask ? mask->ne[1] : 0; -+ const uint32_t nbm1 = mask ? 
mask->nb[1] : 0; -+ -+ const uint32_t D = neq0; -+ const uint32_t N = neq1; -+ const uint32_t KV = nek1; -+ -+ GGML_ASSERT(ne0 == D); -+ GGML_ASSERT(ne2 == N); -+ -+ // input tensor rows must be contiguous -+ GGML_ASSERT(nbq0 == ggml_type_size(q->type)); -+ GGML_ASSERT(nbk0 == ggml_type_size(k->type)); -+ GGML_ASSERT(nbv0 == ggml_type_size(v->type)); -+ -+ GGML_ASSERT(neq0 == D); -+ GGML_ASSERT(nek0 == D); -+ GGML_ASSERT(nev0 == D); -+ -+ GGML_ASSERT(neq1 == N); -+ GGML_ASSERT(nev0 == D); -+ -+ GGML_ASSERT(nev1 == nek1); -+ -+ // dst cannot be transposed or permuted -+ GGML_ASSERT(nb0 == sizeof(float)); -+ GGML_ASSERT(nb0 <= nb1); -+ GGML_ASSERT(nb1 <= nb2); -+ GGML_ASSERT(nb2 <= nb3); -+ -+ assert(dst->type == GGML_TYPE_F32); -+ assert(q->type == GGML_TYPE_F32); -+ assert(k->type == v->type); -+ -+ vk_pipeline *pipelines; -+ // XXX TODO other backends may be changing accumulator precision to default to f32 soon -+ bool f32acc = dst->op_params[3] == GGML_PREC_F32; -+ bool small_rows = N <= flash_attention_num_small_rows; -+ switch (D) { -+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; -+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; -+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; -+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; -+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; -+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; -+ default: -+ assert(!"unsupported D value"); -+ return; -+ } -+ assert(pipelines); -+ -+ bool aligned = (KV % pipelines[1]->align) == 0; -+ vk_pipeline pipeline = pipelines[aligned]; -+ assert(pipeline); -+ -+ if (dryrun) { -+ // Request descriptor sets -+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); -+ return; -+ } -+ -+ float scale = 1.0f; -+ float max_bias = 0.0f; -+ float logit_softcap = 0.0f; -+ -+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); -+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); -+ memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); -+ -+ if (logit_softcap != 0) { -+ scale /= logit_softcap; -+ } -+ -+ const uint32_t n_head_kv = neq2; -+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); -+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); -+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); -+ -+ ggml_vk_sync_buffers(subctx); -+ -+ vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; -+ size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; -+ -+ bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false; -+ -+ if (ctx->device->uma) { -+ ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); -+ ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset); -+ ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset); -+ ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset); -+ Q_uma = d_Q != nullptr; -+ K_uma = d_K != nullptr; -+ V_uma = d_V != nullptr; -+ D_uma = d_D != nullptr; -+ if (mask) { -+ ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset); -+ M_uma = d_M != nullptr; -+ } -+ } -+ -+ -+ ggml_backend_vk_buffer_context * d_buf_ctx = 
(ggml_backend_vk_buffer_context *)dst->buffer->context; -+ ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context; -+ ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; -+ ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; -+ -+ if (!Q_uma) { -+ d_Q = q_buf_ctx->dev_buffer; -+ q_buf_offset = vk_tensor_offset(q) + q->view_offs; -+ } -+ if (!K_uma) { -+ d_K = k_buf_ctx->dev_buffer; -+ k_buf_offset = vk_tensor_offset(k) + k->view_offs; -+ } -+ if (!V_uma) { -+ d_V = v_buf_ctx->dev_buffer; -+ v_buf_offset = vk_tensor_offset(v) + v->view_offs; -+ } -+ if (!D_uma) { -+ d_D = d_buf_ctx->dev_buffer; -+ d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; -+ } -+ -+ if (!M_uma) { -+ d_M = d_Q; -+ m_buf_offset = q_buf_offset; -+ if (mask) { -+ ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context; -+ d_M = m_buf_ctx->dev_buffer; -+ m_buf_offset = vk_tensor_offset(mask) + mask->view_offs; -+ } -+ } -+ -+ const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 }; -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, -+ { -+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, -+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, -+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, -+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, -+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, -+ }, -+ sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); -+} -+ -+static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { -+ switch (op) { -+ case GGML_OP_GET_ROWS: -+ GGML_ASSERT(src1->type == GGML_TYPE_I32); -+ if (dst->type == GGML_TYPE_F16) { -+ return ctx->device->pipeline_get_rows[src0->type]; -+ } -+ if (dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_get_rows_f32[src0->type]; -+ } -+ return nullptr; -+ case GGML_OP_ACC: -+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_acc_f32; -+ } -+ return nullptr; -+ case GGML_OP_ADD: -+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32; -+ } -+ if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { -+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; -+ } -+ return nullptr; -+ case GGML_OP_MUL: -+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; -+ } -+ return nullptr; -+ case GGML_OP_DIV: -+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32; -+ } -+ return nullptr; -+ case GGML_OP_CONCAT: -+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_concat_f32; -+ } -+ if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { -+ return ctx->device->pipeline_concat_f16; -+ } -+ if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { -+ return ctx->device->pipeline_concat_i32; -+ } -+ return nullptr; -+ case GGML_OP_UPSCALE: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_upscale_f32; -+ } -+ return nullptr; -+ case GGML_OP_SCALE: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_scale_f32; -+ } -+ return nullptr; -+ case GGML_OP_SQR: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_sqr_f32; -+ } -+ return nullptr; -+ case GGML_OP_SIN: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_sin_f32; -+ } -+ return nullptr; -+ case GGML_OP_COS: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_cos_f32; -+ } -+ return nullptr; -+ case GGML_OP_CLAMP: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_clamp_f32; -+ } -+ return nullptr; -+ case GGML_OP_PAD: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_pad_f32; -+ } -+ return nullptr; -+ case GGML_OP_REPEAT: -+ if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { -+ return ctx->device->pipeline_repeat_f32; -+ } -+ return nullptr; -+ case GGML_OP_CPY: -+ case GGML_OP_CONT: -+ case GGML_OP_DUP: -+ return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); -+ case GGML_OP_NORM: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_norm_f32; -+ } -+ return nullptr; -+ case GGML_OP_GROUP_NORM: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_group_norm_f32; -+ } -+ return nullptr; -+ case GGML_OP_RMS_NORM: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_rms_norm_f32; -+ } -+ return nullptr; -+ case GGML_OP_UNARY: -+ switch (ggml_get_unary_op(dst)) { -+ case GGML_UNARY_OP_SILU: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_silu_f32; -+ } -+ break; -+ case GGML_UNARY_OP_GELU: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_gelu_f32; -+ } -+ break; -+ case GGML_UNARY_OP_GELU_QUICK: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_gelu_quick_f32; -+ } -+ break; -+ case GGML_UNARY_OP_RELU: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_relu_f32; -+ } -+ break; -+ case GGML_UNARY_OP_TANH: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_tanh_f32; -+ } -+ break; -+ default: -+ break; -+ } -+ return nullptr; -+ case GGML_OP_DIAG_MASK_INF: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_diag_mask_inf_f32; -+ } -+ return nullptr; -+ case 
GGML_OP_SOFT_MAX: -+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); -+ -+ if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { -+ return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; -+ } -+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { -+ return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; -+ } -+ return nullptr; -+ case GGML_OP_ROPE: -+ { -+ const int mode = ((const int32_t *) dst->op_params)[2]; -+ const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; -+ -+ if (is_neox) { -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_rope_neox_f32; -+ } -+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { -+ return ctx->device->pipeline_rope_neox_f16; -+ } -+ } else { -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_rope_norm_f32; -+ } -+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { -+ return ctx->device->pipeline_rope_norm_f16; -+ } -+ } -+ return nullptr; -+ } -+ case GGML_OP_ARGSORT: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { -+ return ctx->device->pipeline_argsort_f32; -+ } -+ return nullptr; -+ case GGML_OP_SUM_ROWS: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_sum_rows_f32; -+ } -+ return nullptr; -+ case GGML_OP_IM2COL: -+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_im2col_f32; -+ } -+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { -+ return ctx->device->pipeline_im2col_f32_f16; -+ } -+ return nullptr; -+ case GGML_OP_TIMESTEP_EMBEDDING: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_timestep_embedding_f32; -+ } -+ return nullptr; -+ case GGML_OP_POOL_2D: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_pool2d_f32; -+ } -+ return nullptr; -+ case GGML_OP_RWKV_WKV6: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_rwkv_wkv6_f32; -+ } -+ return nullptr; -+ case GGML_OP_LEAKY_RELU: -+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { -+ return ctx->device->pipeline_leaky_relu_f32; -+ } -+ return nullptr; -+ default: -+ return nullptr; -+ } -+ -+ GGML_UNUSED(src2); -+} -+ -+static bool ggml_vk_op_supports_incontiguous(ggml_op op) { -+ switch (op) { -+ case GGML_OP_CPY: -+ case GGML_OP_GET_ROWS: -+ case GGML_OP_ADD: -+ case GGML_OP_MUL: -+ case GGML_OP_DIV: -+ case GGML_OP_CONCAT: -+ case GGML_OP_UPSCALE: -+ case GGML_OP_SQR: -+ case GGML_OP_SIN: -+ case GGML_OP_COS: -+ case GGML_OP_CLAMP: -+ case GGML_OP_PAD: -+ case GGML_OP_REPEAT: -+ return true; -+ default: -+ return false; -+ } -+} -+ -+static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t) -+{ -+ return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));; -+} -+ -+template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { -+ GGML_UNUSED(p); -+ GGML_UNUSED(src0); -+ GGML_UNUSED(src1); -+ GGML_UNUSED(src2); -+ GGML_UNUSED(dst); 
-+ static_assert(!std::is_const::value, "unexpected type"); -+ GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); -+ GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); -+ GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0); -+ GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); -+} -+ -+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { -+ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); -+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); -+ -+ p.misalign_offsets = (a_offset << 16) | d_offset; -+ -+ GGML_UNUSED(src1); -+ GGML_UNUSED(src2); -+} -+ -+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { -+ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); -+ const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); -+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); -+ -+ GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0)); -+ -+ p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; -+ -+ GGML_UNUSED(src2); -+} -+ -+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { -+ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); -+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); -+ -+ p.a_offset = a_offset; -+ p.d_offset = d_offset; -+ -+ GGML_UNUSED(src1); -+ GGML_UNUSED(src2); -+} -+ -+template -+static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { -+ VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; -+ if (src1 != nullptr) { -+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; -+ } -+ if (src2 != nullptr) { -+ std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; -+ } -+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; -+ std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? 
"dryrun" : "") << ")"); -+ GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT -+ GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT -+ GGML_ASSERT(dst->buffer != nullptr); -+ const uint64_t ne00 = src0->ne[0]; -+ const uint64_t ne01 = src0->ne[1]; -+ const uint64_t ne02 = src0->ne[2]; -+ const uint64_t ne03 = src0->ne[3]; -+ const uint64_t ne0 = ne00 * ne01; -+ -+ const bool use_src1 = src1 != nullptr; -+ const uint64_t ne10 = use_src1 ? src1->ne[0] : 0; -+ const uint64_t ne11 = use_src1 ? src1->ne[1] : 0; -+ const uint64_t ne12 = use_src1 ? src1->ne[2] : 0; -+ const uint64_t ne13 = use_src1 ? src1->ne[3] : 0; -+ const uint64_t ne1 = ne10 * ne11; -+ // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0; -+ -+ const bool use_src2 = src2 != nullptr; -+ const uint64_t ne20 = use_src2 ? src2->ne[0] : 0; -+ const uint64_t ne21 = use_src2 ? src2->ne[1] : 0; -+ const uint64_t ne22 = use_src2 ? src2->ne[2] : 0; -+ const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; -+ const uint64_t ne2 = ne20 * ne21; -+ -+ const uint64_t ned0 = dst->ne[0]; -+ const uint64_t ned1 = dst->ne[1]; -+ const uint64_t ned2 = dst->ne[2]; -+ const uint64_t ned3 = dst->ne[3]; -+ const uint64_t ned = ned0 * ned1; -+ -+ init_pushconst_fastdiv(pc); -+ -+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); -+ -+ if (pipeline == nullptr) { -+ std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type); -+ if (src1 != nullptr) { -+ std::cerr << " and " << ggml_type_name(src1->type); -+ } -+ std::cerr << " to " << ggml_type_name(dst->type) << std::endl; -+ GGML_ABORT("fatal error"); -+ } -+ -+ if (dryrun) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); -+ return; -+ } -+ -+ const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); -+ -+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; -+ ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; -+ ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; -+ ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; -+ -+ vk_buffer d_X = nullptr; -+ size_t x_buf_offset = 0; -+ vk_buffer d_Y = nullptr; -+ size_t y_buf_offset = 0; -+ vk_buffer d_Z = nullptr; -+ size_t z_buf_offset = 0; -+ -+ bool src0_uma = false; -+ bool src1_uma = false; -+ bool src2_uma = false; -+ -+ if (ctx->device->uma) { -+ ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset); -+ src0_uma = d_X != nullptr; -+ if (use_src1) { -+ ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset); -+ src1_uma = d_Y != nullptr; -+ } -+ if (use_src2) { -+ ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset); -+ src2_uma = d_Z != nullptr; -+ } -+ } -+ -+ uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0; -+ uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0; -+ uint64_t z_sz = use_src2 ? 
ggml_type_size(src2->type) * ne2 : 0; -+ uint64_t d_sz = ggml_type_size(dst->type) * ned; -+ -+ vk_buffer d_D = dst_buf_ctx->dev_buffer; -+ -+ // Workaround for tiny tensor inputs on ROPE -+ if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) { -+ y_sz = VK_WHOLE_SIZE; -+ } -+ -+ GGML_ASSERT(d_D != nullptr); -+ uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; -+ if(!src0_uma) { -+ d_X = src0_buf_ctx->dev_buffer; -+ x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; -+ GGML_ASSERT(d_X != nullptr); -+ } -+ if (use_src1 && !src1_uma) { -+ d_Y = src1_buf_ctx->dev_buffer; -+ y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; -+ GGML_ASSERT(d_Y != nullptr); -+ } -+ if (use_src2 && !src2_uma) { -+ d_Z = src2_buf_ctx->dev_buffer; -+ z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; -+ GGML_ASSERT(d_Z != nullptr); -+ } -+ // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets. -+ init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst); -+ x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); -+ y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); -+ z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); -+ d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); -+ -+ if (op_supports_incontiguous) { -+ x_sz = ggml_nbytes(src0); -+ y_sz = use_src1 ? ggml_nbytes(src1) : 0; -+ z_sz = use_src2 ? ggml_nbytes(src2) : 0; -+ d_sz = ggml_nbytes(dst); -+ -+ if (x_buf_offset + x_sz >= d_X->size) { -+ x_sz = VK_WHOLE_SIZE; -+ } -+ if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { -+ y_sz = VK_WHOLE_SIZE; -+ } -+ if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { -+ z_sz = VK_WHOLE_SIZE; -+ } -+ if (d_buf_offset + d_sz >= d_D->size) { -+ d_sz = VK_WHOLE_SIZE; -+ } -+ } -+ -+ std::array elements; -+ -+ // Single call if dimension 2 is contiguous -+ GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); -+ -+ switch (op) { -+ case GGML_OP_NORM: -+ case GGML_OP_RMS_NORM: -+ case GGML_OP_SOFT_MAX: -+ case GGML_OP_SUM_ROWS: -+ { -+ const uint32_t nr = ggml_nrows(src0); -+ if (nr > 262144) { -+ elements = { 512, 512, CEIL_DIV(nr, 262144) }; -+ } else if (nr > 512) { -+ elements = { 512, CEIL_DIV(nr, 512), 1 }; -+ } else { -+ elements = { nr, 1, 1 }; -+ } -+ } break; -+ case GGML_OP_GROUP_NORM: -+ { -+ const uint32_t num_groups = dst->op_params[0]; -+ elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; -+ } break; -+ case GGML_OP_DIAG_MASK_INF: -+ case GGML_OP_ROPE: -+ elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; -+ break; -+ case GGML_OP_GET_ROWS: -+ elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; -+ break; -+ case GGML_OP_ARGSORT: -+ elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; -+ break; -+ case GGML_OP_IM2COL: -+ { -+ const bool is_2D = dst->op_params[6] == 1; -+ -+ const uint32_t IC = src1->ne[is_2D ? 2 : 1]; -+ -+ const uint32_t KH = is_2D ? src0->ne[1] : 1; -+ const uint32_t KW = src0->ne[0]; -+ -+ const uint32_t OH = is_2D ? dst->ne[2] : 1; -+ const uint32_t OW = dst->ne[1]; -+ -+ const uint32_t batch = src1->ne[is_2D ? 
3 : 2]; -+ -+ elements = { OW * KW * KH, OH, batch * IC }; -+ } break; -+ case GGML_OP_TIMESTEP_EMBEDDING: -+ { -+ const uint32_t dim = dst->op_params[0]; -+ uint32_t half_ceil = (dim + 1) / 2; -+ elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; -+ } break; -+ case GGML_OP_POOL_2D: -+ { -+ const uint32_t N = dst->ne[3]; -+ const uint32_t OC = dst->ne[2]; -+ const uint32_t OH = dst->ne[1]; -+ const uint32_t OW = dst->ne[0]; -+ elements = { N * OC * OH * OW, 1, 1}; -+ } break; -+ case GGML_OP_ADD: -+ case GGML_OP_DIV: -+ case GGML_OP_MUL: -+ case GGML_OP_SCALE: -+ case GGML_OP_SQR: -+ case GGML_OP_SIN: -+ case GGML_OP_COS: -+ case GGML_OP_CLAMP: -+ case GGML_OP_PAD: -+ case GGML_OP_REPEAT: -+ case GGML_OP_CPY: -+ case GGML_OP_CONCAT: -+ case GGML_OP_UPSCALE: -+ case GGML_OP_UNARY: -+ { -+ const uint32_t ne = ggml_nelements(dst); -+ if (ne > 262144) { -+ elements = { 512, 512, CEIL_DIV(ne, 262144) }; -+ } else if (ne > 512) { -+ elements = { 512, CEIL_DIV(ne, 512), 1 }; -+ } else { -+ elements = { ne, 1, 1 }; -+ } -+ } break; -+ default: -+ elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; -+ break; -+ } -+ -+ if (!op_supports_incontiguous) { -+ if (x_sz != VK_WHOLE_SIZE) { -+ x_sz *= ne02 * ne03; -+ } -+ if (use_src1 && y_sz != VK_WHOLE_SIZE) { -+ y_sz *= ne12 * ne13; -+ } -+ if (use_src2 && z_sz != VK_WHOLE_SIZE) { -+ z_sz *= ne22 * ne23; -+ } -+ if (d_sz != VK_WHOLE_SIZE) { -+ d_sz *= ned2 * ned3; -+ } -+ } -+ -+ if (op == GGML_OP_SOFT_MAX) { -+ // Empty src1 is possible in soft_max, but the shader needs a buffer -+ vk_subbuffer subbuf_y; -+ if (use_src1) { -+ subbuf_y = { d_Y, y_buf_offset, y_sz }; -+ } else { -+ subbuf_y = { d_X, 0, x_sz }; -+ } -+ -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); -+ } else if (op == GGML_OP_ROPE) { -+ // Empty src2 is possible in rope, but the shader needs a buffer -+ vk_subbuffer subbuf_z; -+ if (use_src2) { -+ subbuf_z = { d_Z, z_buf_offset, z_sz }; -+ } else { -+ subbuf_z = { d_X, 0, x_sz }; -+ } -+ -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); -+ } else if (op == GGML_OP_IM2COL) { -+ // im2col uses only src1 and dst buffers -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); -+ } else if (use_src2) { -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); -+ } else if (use_src1) { -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); -+ } else { -+ ggml_vk_sync_buffers(subctx); -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); -+ } -+} -+ -+static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const 
ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t src1_type_size = ggml_type_size(src1->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t src1_type_size = ggml_type_size(src1->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 -+ int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 -+ // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused -+ int offset = dst->op_params[3] / 4; // offset in bytes -+ -+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, offset, -+ }, dryrun); -+} -+ -+static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t src1_type_size = ggml_type_size(src1->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / 
src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t src1_type_size = ggml_type_size(src1->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t src1_type_size = ggml_type_size(src1->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) { -+ const ggml_tensor * k = dst->src[0]; -+ const ggml_tensor * v = dst->src[1]; -+ const ggml_tensor * r = dst->src[2]; -+ const ggml_tensor * tf = dst->src[3]; -+ const ggml_tensor * td = dst->src[4]; -+ const ggml_tensor * state = dst->src[5]; -+ -+ GGML_ASSERT(!ggml_is_quantized(k->type)); -+ GGML_ASSERT(!ggml_is_quantized(v->type)); -+ GGML_ASSERT(!ggml_is_quantized(r->type)); -+ GGML_ASSERT(!ggml_is_quantized(tf->type)); -+ GGML_ASSERT(!ggml_is_quantized(td->type)); -+ 
GGML_ASSERT(!ggml_is_quantized(state->type)); -+ GGML_ASSERT(dst->buffer != nullptr); -+ -+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); -+ GGML_ASSERT(pipeline != nullptr); -+ -+ if (dryrun) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); -+ return; -+ } -+ -+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; -+ ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; -+ ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; -+ ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; -+ ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; -+ ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; -+ ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; -+ -+ ggml_vk_sync_buffers(subctx); -+ -+ vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr; -+ size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0; -+ bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; -+ -+ if (ctx->device->uma) { -+ ggml_vk_host_get(ctx->device, k->data, d_K, k_offset); -+ ggml_vk_host_get(ctx->device, v->data, d_V, v_offset); -+ ggml_vk_host_get(ctx->device, r->data, d_R, r_offset); -+ ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset); -+ ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset); -+ ggml_vk_host_get(ctx->device, state->data, d_State, state_offset); -+ ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); -+ -+ K_uma = d_K != nullptr; -+ V_uma = d_V != nullptr; -+ R_uma = d_R != nullptr; -+ TF_uma = d_TF != nullptr; -+ TD_uma = d_TD != nullptr; -+ STATE_uma = d_State != nullptr; -+ DST_uma = d_D != nullptr; -+ } -+ -+ if (!K_uma) { -+ d_K = k_buf_ctx->dev_buffer; -+ k_offset = vk_tensor_offset(k) + k->view_offs; -+ } -+ if (!V_uma) { -+ d_V = v_buf_ctx->dev_buffer; -+ v_offset = vk_tensor_offset(v) + v->view_offs; -+ } -+ if (!R_uma) { -+ d_R = r_buf_ctx->dev_buffer; -+ r_offset = vk_tensor_offset(r) + r->view_offs; -+ } -+ if (!TF_uma) { -+ d_TF = tf_buf_ctx->dev_buffer; -+ tf_offset = vk_tensor_offset(tf) + tf->view_offs; -+ } -+ if (!TD_uma) { -+ d_TD = td_buf_ctx->dev_buffer; -+ td_offset = vk_tensor_offset(td) + td->view_offs; -+ } -+ if (!STATE_uma) { -+ d_State = state_buf_ctx->dev_buffer; -+ state_offset = vk_tensor_offset(state) + state->view_offs; -+ } -+ if (!DST_uma) { -+ d_D = dst_buf_ctx->dev_buffer; -+ dst_offset = vk_tensor_offset(dst) + dst->view_offs; -+ } -+ -+ const uint64_t k_size = ggml_nbytes(k); -+ const uint64_t v_size = ggml_nbytes(v); -+ const uint64_t r_size = ggml_nbytes(r); -+ const uint64_t tf_size = ggml_nbytes(tf); -+ const uint64_t td_size = ggml_nbytes(td); -+ const uint64_t state_size = ggml_nbytes(state); -+ const uint64_t dst_size = ggml_nbytes(dst); -+ -+ std::array elements = { -+ (uint32_t)(pc.B * pc.H), -+ 1, -+ 1 -+ }; -+ -+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { -+ vk_subbuffer{ d_K, k_offset, k_size }, -+ vk_subbuffer{ d_V, v_offset, v_size }, -+ vk_subbuffer{ d_R, r_offset, r_size }, -+ vk_subbuffer{ d_TF, tf_offset, tf_size }, -+ vk_subbuffer{ d_TD, td_offset, 
td_size }, -+ vk_subbuffer{ d_State, state_offset, state_size }, -+ vk_subbuffer{ d_D, dst_offset, dst_size } -+ }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); -+} -+ -+static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { -+ const size_t seq_length = dst->src[0]->ne[3]; -+ const size_t n_embed = dst->ne[0]; -+ const size_t n_heads = dst->src[0]->ne[2]; -+ const size_t n_seqs = dst->src[5]->ne[1]; -+ -+ ggml_vk_op_f32_rwkv6( -+ ctx, subctx, dst, -+ { -+ (uint32_t)n_seqs, -+ (uint32_t)seq_length, -+ (uint32_t)n_embed, -+ (uint32_t)n_heads, -+ }, -+ dryrun -+ ); -+} -+ -+static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ int * op_params = (int *)dst->op_params; -+ -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t src1_type_size = ggml_type_size(src1->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, { -+ (uint32_t)ggml_nelements(dst), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, op_params[0], -+ }, dryrun); -+} -+ -+static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ -+ const float sf0 = (float)dst->ne[0] / src0->ne[0]; -+ const float sf1 = (float)dst->ne[1] / src0->ne[1]; -+ const float sf2 = (float)dst->ne[2] / src0->ne[2]; -+ const float sf3 = (float)dst->ne[3] / src0->ne[3]; -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { -+ (uint32_t)ggml_nelements(dst), 0, 0, -+ (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], -+ sf0, sf1, sf2, sf3, -+ }, dryrun); -+} -+ -+static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ float * op_params = (float *)dst->op_params; -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t) 
dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ op_params[0], 0.0f, -+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, -+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, -+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, -+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ float * op_params = (float *)dst->op_params; -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t dst_type_size = 
ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ op_params[0], op_params[1], -+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, { -+ (uint32_t)ggml_nelements(dst), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, -+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, { -+ (uint32_t)ggml_nelements(dst), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, -+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t src0_type_size = ggml_type_size(src0->type); -+ const uint32_t dst_type_size = ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { -+ (uint32_t)ggml_nelements(src0), -+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, -+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) 
dst->nb[3] / dst_type_size, -+ 0, -+ 0.0f, 0.0f, -+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -+ }, dryrun); -+} -+ -+static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ float * op_params = (float *)dst->op_params; -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); -+} -+ -+static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ const int * int_op_params = (const int *)dst->op_params; -+ const float * float_op_params = (const float *)dst->op_params; -+ -+ const uint32_t num_groups = int_op_params[0]; -+ const float eps = float_op_params[1]; -+ const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); -+} -+ -+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ float * op_params = (float *)dst->op_params; -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); -+} -+ -+static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); -+} -+ -+static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ int32_t * op_params = (int32_t *)dst->op_params; -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); -+} -+ -+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ float * op_params = (float *)dst->op_params; -+ -+ float scale = op_params[0]; -+ float max_bias = op_params[1]; -+ -+ const uint32_t ncols = (uint32_t)src0->ne[0]; -+ const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); -+ const uint32_t nrows_y = (uint32_t)src0->ne[1]; -+ -+ const uint32_t n_head_kv = nrows_x/nrows_y; -+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); -+ -+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); -+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { -+ ncols, -+ src1 != nullptr ? 
nrows_y : (uint32_t)0, -+ scale, max_bias, -+ m0, m1, -+ n_head_log2, -+ nrows_x, -+ }, dryrun); -+} -+ -+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { -+ const int n_dims = ((int32_t *) dst->op_params)[1]; -+ // const int mode = ((int32_t *) dst->op_params)[2]; -+ // const int n_ctx = ((int32_t *) dst->op_params)[3]; -+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; -+ const float freq_base = ((float *) dst->op_params)[5]; -+ const float freq_scale = ((float *) dst->op_params)[6]; -+ const float ext_factor = ((float *) dst->op_params)[7]; -+ const float attn_factor = ((float *) dst->op_params)[8]; -+ const float beta_fast = ((float *) dst->op_params)[9]; -+ const float beta_slow = ((float *) dst->op_params)[10]; -+ -+ float corr_dims[2]; -+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); -+ -+ const float theta_scale = powf(freq_base, -2.0f/n_dims); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { -+ (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], -+ freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, -+ src2 != nullptr, -+ }, dryrun); -+} -+ -+static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ int32_t * op_params = (int32_t *)dst->op_params; -+ -+ uint32_t ncols = src0->ne[0]; -+ -+ uint32_t ncols_pad = 1; -+ while (ncols_pad < ncols) { -+ ncols_pad *= 2; -+ } -+ -+ GGML_ASSERT(ncols_pad <= 1024); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { -+ ncols, -+ ncols_pad, -+ op_params[0], -+ }, dryrun); -+} -+ -+static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); -+} -+ -+static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { -+ const int32_t s0 = dst->op_params[0]; -+ const int32_t s1 = dst->op_params[1]; -+ const int32_t p0 = dst->op_params[2]; -+ const int32_t p1 = dst->op_params[3]; -+ const int32_t d0 = dst->op_params[4]; -+ const int32_t d1 = dst->op_params[5]; -+ -+ const bool is_2D = dst->op_params[6] == 1; -+ -+ const uint32_t IC = src1->ne[is_2D ? 2 : 1]; -+ const uint32_t IH = is_2D ? src1->ne[1] : 1; -+ const uint32_t IW = src1->ne[0]; -+ -+ const uint32_t KH = is_2D ? src0->ne[1] : 1; -+ const uint32_t KW = src0->ne[0]; -+ -+ const uint32_t OH = is_2D ? dst->ne[2] : 1; -+ const uint32_t OW = dst->ne[1]; -+ -+ const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 -+ const uint32_t batch_offset = src1->nb[is_2D ? 
3 : 2] / 4; // nb is byte offset, src is type float32 -+ -+ const uint32_t pelements = OW * KW * KH; -+ -+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { -+ batch_offset, offset_delta, -+ IC, IW, IH, OW, OH, KW, KH, -+ pelements, -+ IC * KH * KW, -+ s0, s1, p0, p1, d0, d1, -+ }, dryrun); -+} -+ -+static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ const uint32_t dim = dst->op_params[0]; -+ const uint32_t max_period = dst->op_params[1]; -+ const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { -+ nb1, dim, max_period, -+ }, dryrun); -+} -+ -+static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ uint32_t op = static_cast(dst->op_params[0]); -+ const int32_t k1 = dst->op_params[1]; -+ const int32_t k0 = dst->op_params[2]; -+ const int32_t s1 = dst->op_params[3]; -+ const int32_t s0 = dst->op_params[4]; -+ const int32_t p1 = dst->op_params[5]; -+ const int32_t p0 = dst->op_params[6]; -+ -+ const uint32_t IH = src0->ne[1]; -+ const uint32_t IW = src0->ne[0]; -+ -+ const uint32_t N = dst->ne[3]; -+ -+ const uint32_t OC = dst->ne[2]; -+ const uint32_t OH = dst->ne[1]; -+ const uint32_t OW = dst->ne[0]; -+ -+ const uint32_t parallel_elements = N * OC * OH * OW; -+ -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, { -+ IW, IH, OW, OH, OC, -+ parallel_elements, -+ op, -+ k0, k1, s0, s1, p0, p1, -+ }, dryrun); -+} -+ -+static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { -+ const float * op_params = (const float *)dst->op_params; -+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); -+} -+ -+#ifdef GGML_VULKAN_RUN_TESTS -+static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) { -+ if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) { -+ return; -+ } -+ i0 = std::max(i0, 5); -+ i1 = std::max(i1, 5); -+ i2 = std::max(i2, 0); -+ fprintf(stderr, " "); -+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { -+ fprintf(stderr, "%7d ", idx1); -+ } -+ fprintf(stderr, "\n"); -+ for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { -+ fprintf(stderr, "%7d: ", idx0); -+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { -+ if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) { -+ float val; -+ if (type == GGML_TYPE_F32) { -+ val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0); -+ } else if (type == GGML_TYPE_F16) { -+ val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0)); -+ } else { -+ GGML_ABORT("fatal error"); -+ } -+ fprintf(stderr, "% 7.2f ", val); -+ } else { -+ fprintf(stderr, " "); -+ } -+ } -+ fprintf(stderr, "\n"); -+ } -+} -+ -+template -+static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) { -+ VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")"); -+ const size_t x_ne = m * k * batch; -+ const size_t y_ne = k * n * batch; -+ const size_t d_ne = m * n * batch; -+ -+ vk_pipeline 
p;
-+    std::string shname;
-+    if (shader_size == 0) {
-+        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f32->a_s;
-+            shname = "F32_ALIGNED_S";
-+        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f32_f16->a_s;
-+            shname = "F32_F16_ALIGNED_S";
-+        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s;
-+            shname = "F16_F32_ALIGNED_S";
-+        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f16.f32acc->a_s;
-+            shname = "F16_ALIGNED_S";
-+        } else {
-+            GGML_ABORT("fatal error");
-+        }
-+    } else if (shader_size == 1) {
-+        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f32->a_m;
-+            shname = "F32_ALIGNED_M";
-+        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f32_f16->a_m;
-+            shname = "F32_F16_ALIGNED_M";
-+        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m;
-+            shname = "F16_F32_ALIGNED_M";
-+        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f16.f32acc->a_m;
-+            shname = "F16_ALIGNED_M";
-+        } else {
-+            GGML_ABORT("fatal error");
-+        }
-+    } else if (shader_size == 2) {
-+        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f32->a_l;
-+            shname = "F32_ALIGNED_L";
-+        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f32_f16->a_l;
-+            shname = "F32_F16_ALIGNED_L";
-+        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l;
-+            shname = "F16_F32_ALIGNED_L";
-+        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+            p = ctx->device->pipeline_matmul_f16.f32acc->a_l;
-+            shname = "F16_ALIGNED_L";
-+        } else {
-+            GGML_ABORT("fatal error");
-+        }
-+    } else {
-+        GGML_ASSERT(0);
-+    }
-+
-+    const size_t kpad = ggml_vk_align_size(k, p->align);
-+
-+    if (k != kpad) {
-+        if (shader_size == 0) {
-+            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f32->s;
-+                shname = "F32_S";
-+            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f32_f16->s;
-+                shname = "F32_F16_S";
-+            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f16_f32.f32acc->s;
-+                shname = "F16_F32_S";
-+            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f16.f32acc->s;
-+                shname = "F16_S";
-+            }
-+        } else if (shader_size == 1) {
-+            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f32->m;
-+                shname = "F32_M";
-+            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f32_f16->m;
-+                shname = "F32_F16_M";
-+            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f16_f32.f32acc->m;
-+                shname = "F16_F32_M";
-+            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f16.f32acc->m;
-+                shname = "F16_M";
-+            }
-+        } else if (shader_size == 2) {
-+            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f32->l;
-+                shname = "F32_L";
-+            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f32_f16->l;
-+                shname = "F32_F16_L";
-+            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f16_f32.f32acc->l;
-+                shname = "F16_F32_L";
-+            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+                p = ctx->device->pipeline_matmul_f16.f32acc->l;
-+                shname = "F16_L";
-+            }
-+        }
-+    }
-+
-+    ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
-+    if (split_k > 1) {
-+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
-+
-+        if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
-+            // Resize buffer
-+            if (ctx->prealloc_split_k != nullptr) {
-+                ggml_vk_destroy_buffer(ctx->prealloc_split_k);
-+            }
-+            ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
-+        }
-+    }
-+
-+    ggml_pipeline_allocate_descriptor_sets(ctx->device);
-+
-+    vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
-+    vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
-+    vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
-+
-+    X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
-+    Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
-+    float* d = (float *) malloc(sizeof(float) * d_ne);
-+
-+    for (size_t i = 0; i < x_ne; i++) {
-+        if (std::is_same<float, X_TYPE>()) {
-+            x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
-+            // x[i] = 1.0f;
-+            // x[i] = i + 1;
-+            // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
-+        } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
-+            x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
-+            // x[i] = ggml_fp32_to_fp16(1.0f);
-+            // x[i] = ggml_fp32_to_fp16(i + 1);
-+            // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
-+        } else {
-+            GGML_ABORT("fatal error");
-+        }
-+    }
-+    for (size_t i = 0; i < y_ne; i++) {
-+        if (std::is_same<float, Y_TYPE>()) {
-+            y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
-+            // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
-+            // y[i] = i + 1;
-+        } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+            y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
-+            // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
-+            // y[i] = ggml_fp32_to_fp16(i + 1);
-+        } else {
-+            GGML_ABORT("fatal error");
-+        }
-+    }
-+
-+    ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
-+    ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
-+
-+    vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
-+    ggml_vk_ctx_begin(ctx->device, subctx);
-+    for (size_t i = 0; i < num_it; i++) {
-+        ggml_vk_matmul(
-+            ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
-+            m, n, k,
-+            k, k, m, k*m, k*n, m*n,
-+            split_k, batch, batch, batch, 1, 1
-+        );
-+    }
-+    ggml_vk_ctx_end(subctx);
-+
-+    auto begin = std::chrono::high_resolution_clock::now();
-+    ggml_vk_submit(subctx, ctx->fence);
-+    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
-+    ctx->device->device.resetFences({ ctx->fence });
-+
-+    auto end = std::chrono::high_resolution_clock::now();
-+    double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
-+
-+    // copy dst to host
-+    ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne);
-+
-+    float * d_chk = (float *) malloc(sizeof(float) * d_ne);
-+
-+    ggml_init_params iparams = {
-+        /*.mem_size   =*/ 1024*1024*1024,
-+        /*.mem_buffer =*/ NULL,
-+        /*.no_alloc   =*/ true,
-+    };
-+
-+    ggml_context * ggml_ctx = ggml_init(iparams);
-+
-+    ggml_type src0_type;
-+    ggml_type src1_type;
-+
-+    if (std::is_same<float, X_TYPE>()) {
-+        src0_type = GGML_TYPE_F32;
-+    } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
-+        src0_type = GGML_TYPE_F16;
-+    } else {
-+        GGML_ABORT("fatal error");
-+    }
-+    if (std::is_same<float, Y_TYPE>()) {
-+        src1_type = GGML_TYPE_F32;
-+    } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
-+        src1_type = GGML_TYPE_F16;
-+    } else {
-+        GGML_ABORT("fatal error");
-+    }
-+
-+    ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch);
-+    ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch);
-+    ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
-+
-+    src0_ggml->data = x;
-+    src1_ggml->data = y;
-+    tensor_ggml->data = d_chk;
-+
-+    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
-+    ggml_build_forward_expand(cgraph, tensor_ggml);
-+
-+    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
-+
-+    ggml_free(ggml_ctx);
-+
-+    double avg_err = 0.0;
-+    int first_err_n = -1;
-+    int first_err_m = -1;
-+    int first_err_b = -1;
-+
-+    for (size_t i = 0; i < m*n*batch; i++) {
-+        double err = std::fabs(d[i] - d_chk[i]);
-+        avg_err += err;
-+
-+        if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
-+            first_err_b = i / (m * n);
-+            first_err_n = (i % (m * n)) / m;
-+            first_err_m = (i % (m * n)) % m;
-+        }
-+    }
-+
-+    avg_err /= m * n;
-+
-+    double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
-+
-+    std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
-+
-+    if (avg_err > 0.1 || std::isnan(avg_err)) {
-+        std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
-+        std::cerr << "Actual result: " << std::endl << std::endl;
-+        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-+        std::cerr << "Expected result: " << std::endl << std::endl;
-+        ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-+
-+        if (split_k > 1) {
-+            float *
split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); -+ ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); -+ -+ std::cerr << "d_buf0: " << std::endl << std::endl; -+ ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); -+ -+ std::cerr << "d_buf1: " << std::endl << std::endl; -+ ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); -+ -+ std::cerr << "d_buf2: " << std::endl << std::endl; -+ ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); -+ -+ std::cerr << "d_buf3: " << std::endl << std::endl; -+ ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); -+ -+ free(split_k_buf); -+ } -+ } -+ -+ free(d_chk); -+ -+ ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); -+ ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); -+ -+ ggml_vk_destroy_buffer(d_X); -+ ggml_vk_destroy_buffer(d_Y); -+ ggml_vk_destroy_buffer(d_D); -+ -+ ggml_pipeline_cleanup(p); -+ ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce); -+ -+ free(x); -+ free(y); -+ free(d); -+} -+ -+static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) { -+ if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) { -+ return; -+ } -+ i0 = std::max(i0, 5); -+ i1 = std::max(i1, 5); -+ i2 = std::max(i2, 0); -+ i3 = std::max(i3, 0); -+ fprintf(stderr, " "); -+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { -+ fprintf(stderr, "%7d ", idx1); -+ } -+ fprintf(stderr, "\n"); -+ for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { -+ fprintf(stderr, "%7d: ", idx0); -+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { -+ if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { -+ float val; -+ if (tensor->type == GGML_TYPE_F32) { -+ val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); -+ } else if (tensor->type == GGML_TYPE_F16) { -+ val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); -+ } else { -+ GGML_ABORT("fatal error"); -+ } -+ fprintf(stderr, "% 7.2f ", val); -+ } else { -+ fprintf(stderr, " "); -+ } -+ } -+ fprintf(stderr, "\n"); -+ } -+} -+ -+static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) { -+ ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr); -+} -+ -+static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) { -+ if (quant == GGML_TYPE_F32) { -+ memcpy(to, from, sizeof(float) * ne); -+ return; -+ } -+ -+ const auto * tt = ggml_get_type_traits(quant); -+ -+ ggml_to_float_t dequant_fn = tt->to_float; -+ -+ dequant_fn(from, to, ne); -+} -+ -+static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { -+ VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")"); -+ const size_t x_sz = sizeof(float) * ne; -+ const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne; -+ const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); -+ float * x = (float *) malloc(x_sz); -+ void * qx = malloc(qx_sz); -+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); -+ vk_buffer x_buf = 
ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal); -+ float * x_ref = (float *) malloc(x_sz); -+ ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16); -+ -+ for (size_t i = 0; i < ne; i++) { -+ x[i] = rand() / (float)RAND_MAX; -+ } -+ -+ vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant); -+ -+ ggml_vk_quantize_data(x, qx, ne, quant); -+ ggml_vk_dequantize_data(qx, x_ref, ne, quant); -+ -+ ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); -+ -+ ggml_pipeline_allocate_descriptor_sets(ctx->device); -+ -+ ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); -+ -+ vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); -+ ggml_vk_ctx_begin(ctx->device, subctx); -+ const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; -+ ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); -+ ggml_vk_ctx_end(subctx); -+ -+ auto begin = std::chrono::high_resolution_clock::now(); -+ -+ ggml_vk_submit(subctx, ctx->fence); -+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); -+ ctx->device->device.resetFences({ ctx->fence }); -+ -+ auto end = std::chrono::high_resolution_clock::now(); -+ -+ double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0; -+ ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16); -+ -+ int first_err = -1; -+ -+ double avg_err = 0.0; -+ for (size_t i = 0; i < ne; i++) { -+ double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i])); -+ avg_err += error; -+ -+ if (first_err < 0 && error > 0.05) { -+ first_err = i; -+ } -+ } -+ -+ avg_err /= ne; -+ -+ std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl; -+ -+ if (avg_err > 0.1) { -+ std::cerr << "first_error = " << first_err << std::endl; -+ std::cerr << "Actual result: " << std::endl << std::endl; -+ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { -+ std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", "; -+ } -+ std::cerr << std::endl << "Expected result: " << std::endl << std::endl; -+ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { -+ std::cerr << x_ref[i] << ", "; -+ } -+ std::cerr << std::endl; -+ } -+ -+ ggml_vk_destroy_buffer(x_buf); -+ ggml_vk_destroy_buffer(qx_buf); -+ -+ free(x); -+ free(qx); -+ free(x_ref); -+ free(x_chk); -+} -+ -+static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) { -+ VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); -+ const size_t x_ne = m * k * batch; -+ const size_t y_ne = k * n * batch; -+ const size_t d_ne = m * n * batch; -+ -+ vk_pipeline p; -+ std::string shname; -+ if (shader_size == 0) { -+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s; -+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; -+ } else if (shader_size == 1) { -+ p = ctx->device->fp16 ?
ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m; -+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; -+ } else if (shader_size == 2) { -+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l; -+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; -+ } else { -+ GGML_ASSERT(0); -+ } -+ -+ const size_t kpad = ggml_vk_align_size(k, p->align); -+ -+ if (k != kpad) { -+ if (shader_size == 0) { -+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s; -+ shname = std::string(ggml_type_name(quant)) + "_S"; -+ } else if (shader_size == 1) { -+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m; -+ shname = std::string(ggml_type_name(quant)) + "_M"; -+ } else if (shader_size == 2) { -+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l; -+ shname = std::string(ggml_type_name(quant)) + "_L"; -+ } else { -+ GGML_ASSERT(0); -+ } -+ } -+ -+ const size_t x_sz = sizeof(float) * x_ne; -+ const size_t y_sz = sizeof(float) * y_ne; -+ const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); -+ const size_t d_sz = sizeof(float) * d_ne; -+ float * x = (float *) malloc(x_sz); -+ float * y = (float *) malloc(y_sz); -+ void * qx = malloc(qx_sz); -+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); -+ vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); -+ vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); -+ float * d = (float *) malloc(d_sz); -+ float * d_chk = (float *) malloc(d_sz); -+ -+ for (size_t i = 0; i < x_ne; i++) { -+ x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; -+ } -+ -+ ggml_vk_quantize_data(x, qx, x_ne, quant); -+ -+ for (size_t i = 0; i < y_ne; i++) { -+ // y[i] = rand() / (float)RAND_MAX; -+ y[i] = (i % k == i / k) ? 
1.0f : 0.0f; -+ } -+ -+ ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); -+ if (split_k > 1) { -+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); -+ -+ if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { -+ // Resize buffer -+ if (ctx->prealloc_split_k != nullptr) { -+ ggml_vk_destroy_buffer(ctx->prealloc_split_k); -+ } -+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); -+ } -+ } -+ -+ ggml_pipeline_allocate_descriptor_sets(ctx->device); -+ -+ ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); -+ ggml_vk_buffer_write(y_buf, 0, y, y_sz); -+ -+ vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); -+ ggml_vk_ctx_begin(ctx->device, subctx); -+ for (size_t i = 0; i < num_it; i++) { -+ ggml_vk_matmul( -+ ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), -+ m, n, k, -+ k, k, m, k*m, k*n, m*n, -+ split_k, batch, batch, batch, 1, 1 -+ ); -+ } -+ ggml_vk_ctx_end(subctx); -+ -+ auto begin = std::chrono::high_resolution_clock::now(); -+ -+ ggml_vk_submit(subctx, ctx->fence); -+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); -+ ctx->device->device.resetFences({ ctx->fence }); -+ -+ auto end = std::chrono::high_resolution_clock::now(); -+ -+ double time_ms = std::chrono::duration_cast(end-begin).count() / 1000.0; -+ ggml_vk_buffer_read(d_buf, 0, d, d_sz); -+ -+ ggml_init_params iparams = { -+ /*.mem_size =*/ 1024*1024*1024, -+ /*.mem_buffer =*/ NULL, -+ /*.no_alloc =*/ true, -+ }; -+ -+ ggml_context * ggml_ctx = ggml_init(iparams); -+ -+ ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch); -+ ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch); -+ ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); -+ -+ src0_ggml->data = qx; -+ src1_ggml->data = y; -+ tensor_ggml->data = d_chk; -+ -+ ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); -+ ggml_build_forward_expand(cgraph, tensor_ggml); -+ -+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); -+ -+ ggml_free(ggml_ctx); -+ -+ double avg_err = 0.0; -+ int first_err_n = -1; -+ int first_err_m = -1; -+ int first_err_b = -1; -+ -+ for (size_t i = 0; i < m*n*batch; i++) { -+ double err = std::fabs(d[i] - d_chk[i]); -+ avg_err += err; -+ -+ if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { -+ first_err_b = i / (m * n); -+ first_err_n = (i % (m * n)) / m; -+ first_err_m = (i % (m * n)) % m; -+ } -+ } -+ -+ avg_err /= m * n; -+ -+ double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); -+ -+ std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; -+ -+ if (avg_err > 0.01 || std::isnan(avg_err)) { -+ std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; -+ std::cerr << "Actual result: " << std::endl << std::endl; -+ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); -+ std::cerr << std::endl; -+ std::cerr << "Expected result: " << std::endl << std::endl; -+ ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, 
first_err_m, first_err_n, first_err_b); -+ -+ if (split_k > 1) { -+ float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); -+ ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); -+ -+ std::cerr << "d_buf0: " << std::endl << std::endl; -+ ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); -+ -+ std::cerr << "d_buf1: " << std::endl << std::endl; -+ ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); -+ -+ std::cerr << "d_buf2: " << std::endl << std::endl; -+ ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); -+ -+ std::cerr << "d_buf3: " << std::endl << std::endl; -+ ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); -+ -+ free(split_k_buf); -+ } -+ } -+ -+ ggml_vk_destroy_buffer(qx_buf); -+ ggml_vk_destroy_buffer(y_buf); -+ ggml_vk_destroy_buffer(d_buf); -+ -+ free(x); -+ free(qx); -+ free(y); -+ free(d); -+ free(d_chk); -+} -+#endif -+ -+static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { -+#if defined(GGML_VULKAN_RUN_TESTS) -+ const std::vector vals { -+ 512, 512, 128, -+ 128, 512, 512, -+ 4096, 512, 4096, -+ 11008, 512, 4096, -+ 4096, 512, 11008, -+ 32000, 512, 4096, -+ 8, 8, 8, -+ 100, 46, 576, -+ 623, 111, 128, -+ 100, 46, 558, -+ 512, 1, 256, -+ 128, 110, 622, -+ 511, 511, 127, -+ 511, 511, 7, -+ 511, 511, 17, -+ 49, 49, 128, -+ 128, 49, 49, -+ 4096, 49, 4096, -+ }; -+ const size_t num_it = 100; -+ -+ for (size_t i = 0; i < vals.size(); i += 3) { -+ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); -+ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); -+ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2); -+ std::cerr << '\n'; -+ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0); -+ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1); -+ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2); -+ std::cerr << '\n'; -+ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0); -+ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1); -+ ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2); -+ std::cerr << '\n' << std::endl; -+ -+ if (vals[i + 2] % 32 == 0) { -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0); -+ std::cerr << '\n'; -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0); -+ std::cerr << '\n'; -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0); -+ std::cerr << '\n' << std::endl; -+ } -+ -+ 
if (vals[i + 2] % 256 == 0) { -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K); -+ std::cerr << '\n'; -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K); -+ std::cerr << '\n'; -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K); -+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K); -+ std::cerr << '\n' << std::endl; -+ } -+ } -+ -+ GGML_ABORT("fatal error"); -+#endif -+ -+ if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { -+ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")"); -+ // Resize buffer -+ if (ctx->prealloc_x != nullptr) { -+ ggml_vk_destroy_buffer(ctx->prealloc_x); -+ } -+ ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x); -+ } -+ if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) { -+ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")"); -+ // Resize buffer -+ if (ctx->prealloc_y != nullptr) { -+ ggml_vk_destroy_buffer(ctx->prealloc_y); -+ } -+ ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); -+ } -+ if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { -+ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); -+ // Resize buffer -+ if (ctx->prealloc_split_k != nullptr) { -+ ggml_vk_destroy_buffer(ctx->prealloc_split_k); -+ } -+ ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); -+ } -+} -+ -+static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence); -+ -+// Returns true if node has enqueued work into the queue, false otherwise -+// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. 
-+static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){ -+ if (ggml_is_empty(node) || !node->buffer) { -+ return false; -+ } -+ -+ VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); -+ ctx->semaphore_idx = 0; -+ -+ const ggml_tensor * src0 = node->src[0]; -+ const ggml_tensor * src1 = node->src[1]; -+ const ggml_tensor * src2 = node->src[2]; -+ const ggml_tensor * src3 = node->src[3]; -+ -+ switch (node->op) { -+ // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor -+ case GGML_OP_RESHAPE: -+ case GGML_OP_VIEW: -+ case GGML_OP_PERMUTE: -+ case GGML_OP_TRANSPOSE: -+ case GGML_OP_NONE: -+ return false; -+ case GGML_OP_UNARY: -+ switch (ggml_get_unary_op(node)) { -+ case GGML_UNARY_OP_SILU: -+ case GGML_UNARY_OP_GELU: -+ case GGML_UNARY_OP_GELU_QUICK: -+ case GGML_UNARY_OP_RELU: -+ case GGML_UNARY_OP_TANH: -+ break; -+ default: -+ return false; -+ } -+ break; -+ case GGML_OP_REPEAT: -+ case GGML_OP_GET_ROWS: -+ case GGML_OP_ADD: -+ case GGML_OP_ACC: -+ case GGML_OP_MUL: -+ case GGML_OP_DIV: -+ case GGML_OP_CONCAT: -+ case GGML_OP_UPSCALE: -+ case GGML_OP_SCALE: -+ case GGML_OP_SQR: -+ case GGML_OP_SIN: -+ case GGML_OP_COS: -+ case GGML_OP_CLAMP: -+ case GGML_OP_PAD: -+ case GGML_OP_CPY: -+ case GGML_OP_CONT: -+ case GGML_OP_DUP: -+ case GGML_OP_NORM: -+ case GGML_OP_GROUP_NORM: -+ case GGML_OP_RMS_NORM: -+ case GGML_OP_DIAG_MASK_INF: -+ case GGML_OP_SOFT_MAX: -+ case GGML_OP_ROPE: -+ case GGML_OP_MUL_MAT: -+ case GGML_OP_MUL_MAT_ID: -+ case GGML_OP_ARGSORT: -+ case GGML_OP_SUM_ROWS: -+ case GGML_OP_IM2COL: -+ case GGML_OP_TIMESTEP_EMBEDDING: -+ case GGML_OP_POOL_2D: -+ case GGML_OP_RWKV_WKV6: -+ case GGML_OP_LEAKY_RELU: -+ case GGML_OP_FLASH_ATTN_EXT: -+ break; -+ default: -+ std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; -+ GGML_ABORT("fatal error"); -+ return false; -+ } -+ -+ vk_context compute_ctx; -+ -+ if (!dryrun) { -+ if (ctx->compute_ctx.expired()) { -+ compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); -+ ctx->compute_ctx = compute_ctx; -+ ggml_vk_ctx_begin(ctx->device, compute_ctx); -+ } else { -+ compute_ctx = ctx->compute_ctx.lock(); -+ } -+ } else { -+ switch (node->op) { -+ case GGML_OP_REPEAT: -+ case GGML_OP_ACC: -+ case GGML_OP_GET_ROWS: -+ case GGML_OP_ADD: -+ case GGML_OP_MUL: -+ case GGML_OP_DIV: -+ case GGML_OP_CONCAT: -+ case GGML_OP_UPSCALE: -+ case GGML_OP_SCALE: -+ case GGML_OP_SQR: -+ case GGML_OP_SIN: -+ case GGML_OP_COS: -+ case GGML_OP_CLAMP: -+ case GGML_OP_PAD: -+ case GGML_OP_CPY: -+ case GGML_OP_CONT: -+ case GGML_OP_DUP: -+ case GGML_OP_NORM: -+ case GGML_OP_GROUP_NORM: -+ case GGML_OP_RMS_NORM: -+ case GGML_OP_UNARY: -+ case GGML_OP_DIAG_MASK_INF: -+ case GGML_OP_SOFT_MAX: -+ case GGML_OP_ROPE: -+ case GGML_OP_ARGSORT: -+ case GGML_OP_SUM_ROWS: -+ case GGML_OP_IM2COL: -+ case GGML_OP_TIMESTEP_EMBEDDING: -+ case GGML_OP_POOL_2D: -+ case GGML_OP_LEAKY_RELU: -+ { -+ // These operations all go through ggml_vk_op_f32, so short-circuit and -+ // do the only thing needed for the dryrun. 
-+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); -+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); -+ return false; -+ } -+ default: -+ break; -+ } -+ } -+ -+ switch (node->op) { -+ case GGML_OP_REPEAT: -+ ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_ACC: -+ ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); -+ -+ break; -+ case GGML_OP_GET_ROWS: -+ ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); -+ -+ break; -+ case GGML_OP_ADD: -+ ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); -+ -+ break; -+ case GGML_OP_MUL: -+ ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); -+ -+ break; -+ case GGML_OP_DIV: -+ ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun); -+ -+ break; -+ case GGML_OP_CONCAT: -+ ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun); -+ -+ break; -+ case GGML_OP_UPSCALE: -+ ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_SCALE: -+ ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_SQR: -+ ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_SIN: -+ ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_COS: -+ ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_CLAMP: -+ ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_PAD: -+ ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_CPY: -+ case GGML_OP_CONT: -+ case GGML_OP_DUP: -+ ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_NORM: -+ ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_GROUP_NORM: -+ ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_RMS_NORM: -+ ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_UNARY: -+ switch (ggml_get_unary_op(node)) { -+ case GGML_UNARY_OP_SILU: -+ case GGML_UNARY_OP_GELU: -+ case GGML_UNARY_OP_GELU_QUICK: -+ case GGML_UNARY_OP_RELU: -+ case GGML_UNARY_OP_TANH: -+ ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); -+ break; -+ default: -+ return false; -+ } -+ break; -+ case GGML_OP_DIAG_MASK_INF: -+ ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_SOFT_MAX: -+ ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); -+ -+ break; -+ case GGML_OP_ROPE: -+ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun); -+ -+ break; -+ case GGML_OP_ARGSORT: -+ ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_SUM_ROWS: -+ ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_IM2COL: -+ ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); -+ -+ break; -+ case GGML_OP_TIMESTEP_EMBEDDING: -+ ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_POOL_2D: -+ ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_LEAKY_RELU: -+ ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); -+ -+ break; -+ case GGML_OP_MUL_MAT: -+ ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); -+ -+ break; -+ case GGML_OP_MUL_MAT_ID: -+ ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); -+ -+ break; -+ -+ case GGML_OP_FLASH_ATTN_EXT: -+ ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); -+ -+ break; -+ -+ 
case GGML_OP_RWKV_WKV6: -+ ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); -+ -+ break; -+ default: -+ return false; -+ } -+ -+ if (dryrun) { -+ return false; -+ } -+ -+ ctx->tensor_ctxs[node_idx] = compute_ctx; -+ -+#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF) -+ // Force context reset on each node so that each tensor ends up in its own context -+ // and can be run and compared to its CPU equivalent separately -+ last_node = true; -+#endif -+ -+ if (submit || last_node) { -+ ggml_vk_ctx_end(compute_ctx); -+ -+ // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward -+ if (last_node) { -+ compute_ctx->exit_tensor_idx = node_idx_begin; -+ } -+ else { -+ compute_ctx->exit_tensor_idx = -1; -+ } -+ -+ ctx->compute_ctx.reset(); -+ -+ bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false); -+ if (!ok) { -+ if (node->op == GGML_OP_UNARY) { -+ std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; -+ } -+ else { -+ std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; -+ } -+ } -+ -+ } -+ return true; -+} -+ -+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){ -+ ggml_backend_buffer * buf = nullptr; -+ -+ switch (tensor->op) { -+ case GGML_OP_ADD: -+ case GGML_OP_ACC: -+ case GGML_OP_GET_ROWS: -+ case GGML_OP_MUL: -+ case GGML_OP_DIV: -+ case GGML_OP_CONCAT: -+ case GGML_OP_UPSCALE: -+ case GGML_OP_SCALE: -+ case GGML_OP_SQR: -+ case GGML_OP_SIN: -+ case GGML_OP_COS: -+ case GGML_OP_CLAMP: -+ case GGML_OP_PAD: -+ case GGML_OP_CPY: -+ case GGML_OP_CONT: -+ case GGML_OP_DUP: -+ case GGML_OP_NORM: -+ case GGML_OP_GROUP_NORM: -+ case GGML_OP_RMS_NORM: -+ case GGML_OP_DIAG_MASK_INF: -+ case GGML_OP_SOFT_MAX: -+ case GGML_OP_ROPE: -+ case GGML_OP_RESHAPE: -+ case GGML_OP_VIEW: -+ case GGML_OP_PERMUTE: -+ case GGML_OP_TRANSPOSE: -+ case GGML_OP_NONE: -+ case GGML_OP_ARGSORT: -+ case GGML_OP_SUM_ROWS: -+ case GGML_OP_IM2COL: -+ case GGML_OP_TIMESTEP_EMBEDDING: -+ case GGML_OP_POOL_2D: -+ case GGML_OP_RWKV_WKV6: -+ case GGML_OP_LEAKY_RELU: -+ case GGML_OP_REPEAT: -+ buf = tensor->buffer; -+ -+ break; -+ case GGML_OP_UNARY: -+ switch (ggml_get_unary_op(tensor)) { -+ case GGML_UNARY_OP_SILU: -+ case GGML_UNARY_OP_GELU: -+ case GGML_UNARY_OP_GELU_QUICK: -+ case GGML_UNARY_OP_RELU: -+ case GGML_UNARY_OP_TANH: -+ buf = tensor->buffer; -+ break; -+ default: -+ return false; -+ } -+ break; -+ case GGML_OP_MUL_MAT: -+ case GGML_OP_MUL_MAT_ID: -+ case GGML_OP_FLASH_ATTN_EXT: -+ buf = tensor->buffer; -+ -+ break; -+ default: -+ return false; -+ } -+ -+ if (buf == nullptr) { -+ return false; -+ } -+ -+ VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); -+ -+ vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); -+ -+ // always wait for the GPU work to be done for the last submit -+ if (tensor_idx == subctx->exit_tensor_idx) { -+ use_fence = true; -+ } -+ -+ // Only run if ctx hasn't been 
submitted yet -+ if (!subctx->seqs.empty()) { -+#ifdef GGML_VULKAN_CHECK_RESULTS -+ ggml_vk_check_results_0(tensor); -+ use_fence = true; -+#endif -+ -+ // Do staging buffer copies -+ for (auto& cpy : subctx->in_memcpys) { -+ memcpy(cpy.dst, cpy.src, cpy.n); -+ } -+ -+ ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); -+ -+ if (use_fence) { -+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); -+ -+ ctx->device->device.resetFences({ ctx->fence }); -+ } -+#ifdef GGML_VULKAN_CHECK_RESULTS -+ ggml_vk_check_results_1(tensor); -+#endif -+ } -+ -+ if (tensor_idx == subctx->exit_tensor_idx) { -+ // Do staging buffer copies -+ for (auto& cpy : subctx->out_memcpys) { -+ memcpy(cpy.dst, cpy.src, cpy.n); -+ } -+ subctx->in_memcpys.clear(); -+ subctx->out_memcpys.clear(); -+ } -+ -+ return true; -+} -+ -+// Clean up after graph processing is done -+static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { -+ VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); -+ for (auto& buffer : ctx->gc.temp_buffers) { -+ ggml_vk_pool_free(ctx, buffer); -+ } -+ ctx->gc.temp_buffers.clear(); -+ -+ for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) { -+ vk_pipeline_ref plr = ctx->device->pipelines[dsr.first]; -+ -+ if (plr.expired()) { -+ continue; -+ } -+ -+ vk_pipeline pl = plr.lock(); -+ ggml_pipeline_cleanup(pl); -+ } -+ -+ ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); -+ ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); -+ -+ for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { -+ ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); -+ } -+ ctx->gc.semaphores.clear(); -+ -+ for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) { -+ ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s }); -+ } -+ ctx->gc.tl_semaphores.clear(); -+ ctx->semaphore_idx = 0; -+ -+ ctx->event_idx = 0; -+ -+ for (auto& event : ctx->gc.events) { -+ ctx->device->device.resetEvent(event); -+ } -+ -+ ctx->tensor_ctxs.clear(); -+ ctx->gc.contexts.clear(); -+ ctx->device->pipeline_descriptor_set_requirements.clear(); -+} -+ -+// Clean up on backend free -+static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { -+ VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); -+ ggml_vk_graph_cleanup(ctx); -+ -+ ggml_vk_destroy_buffer(ctx->prealloc_x); -+ ggml_vk_destroy_buffer(ctx->prealloc_y); -+ ggml_vk_destroy_buffer(ctx->prealloc_split_k); -+ -+ for (auto& buffer : ctx->buffer_pool) { -+ ggml_vk_destroy_buffer(buffer); -+ } -+ -+ ctx->prealloc_size_x = 0; -+ ctx->prealloc_size_y = 0; -+ ctx->prealloc_size_split_k = 0; -+ -+ for (auto& event : ctx->gc.events) { -+ ctx->device->device.destroyEvent(event); -+ } -+ ctx->gc.events.clear(); -+ -+ ctx->device->device.destroyFence(ctx->fence); -+} -+ -+static int ggml_vk_get_device_count() { -+ ggml_vk_instance_init(); -+ -+ return vk_instance.device_indices.size(); -+} -+ -+static void ggml_vk_get_device_description(int device, char * description, size_t description_size) { -+ ggml_vk_instance_init(); -+ -+ std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); -+ -+ vk::PhysicalDeviceProperties props; -+ devices[device].getProperties(&props); -+ -+ snprintf(description, description_size, "%s", props.deviceName.data()); -+} -+ -+// backend interface -+ -+#define UNUSED GGML_UNUSED -+ -+// device backend -+ -+static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) { -+ return buffer->buft->iface.get_name == 
ggml_backend_vk_buffer_type_name; -+} -+ -+static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { -+ VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()"); -+ ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; -+ ggml_vk_destroy_buffer(ctx->dev_buffer); -+ delete ctx; -+} -+ -+static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { -+ return vk_ptr_base; -+ -+ UNUSED(buffer); -+} -+ -+static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { -+ VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); -+ if (tensor->view_src != nullptr) { -+ GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); -+ } -+} -+ -+static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { -+ VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); -+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; -+ vk_buffer buf = buf_ctx->dev_buffer; -+ -+ ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); -+} -+ -+static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { -+ VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); -+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; -+ -+ vk_buffer buf = buf_ctx->dev_buffer; -+ -+ ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); -+} -+ -+static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { -+ if (ggml_backend_buffer_is_vk(src->buffer)) { -+ ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; -+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; -+ -+ vk_buffer src_buf = src_buf_ctx->dev_buffer; -+ vk_buffer dst_buf = dst_buf_ctx->dev_buffer; -+ -+ ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); -+ -+ return true; -+ } -+ return false; -+ -+ UNUSED(buffer); -+} -+ -+static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { -+ ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; -+ -+ ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size); -+} -+ -+static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { -+ /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, -+ /* .get_base = */ ggml_backend_vk_buffer_get_base, -+ /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, -+ /* .memset_tensor = */ NULL, -+ /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, -+ /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, -+ /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, -+ /* .clear = */ ggml_backend_vk_buffer_clear, -+ /* .reset = */ NULL, -+}; -+ -+// vk buffer type -+static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) { -+ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context 
*)buft->context; -+ -+ return ctx->name.c_str(); -+} -+ -+static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { -+ VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")"); -+ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; -+ -+ vk_buffer dev_buffer = nullptr; -+ try { -+ dev_buffer = ggml_vk_create_buffer_device(ctx->device, size); -+ } catch (const vk::SystemError& e) { -+ return nullptr; -+ } -+ -+ ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name); -+ -+ return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size); -+} -+ -+static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { -+ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; -+ return ctx->device->properties.limits.minStorageBufferOffsetAlignment; -+} -+ -+static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { -+ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; -+ return ctx->device->max_memory_allocation_size; -+} -+ -+static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { -+ return ggml_nbytes(tensor); -+ -+ UNUSED(buft); -+} -+ -+ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { -+ ggml_vk_instance_init(); -+ -+ VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")"); -+ -+ vk_device dev = ggml_vk_get_device(dev_num); -+ -+ return &dev->buffer_type; -+} -+ -+// host buffer type -+ -+static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) { -+ return GGML_VK_NAME "_Host"; -+ -+ UNUSED(buft); -+} -+ -+static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { -+ return GGML_VK_NAME "_Host"; -+ -+ UNUSED(buffer); -+} -+ -+static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { -+ VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); -+ ggml_vk_host_free(vk_instance.devices[0], buffer->context); -+} -+ -+static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { -+ VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")"); -+ -+ size += 32; // Behave like the CPU buffer type -+ void * ptr = nullptr; -+ try { -+ ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); -+ } catch (vk::SystemError& e) { -+ std::cerr << "ggml_vulkan: Failed to allocate pinned memory." 
<< std::endl; -+ std::cerr << "ggml_vulkan: " << e.what() << std::endl; -+ // fallback to cpu buffer -+ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); -+ } -+ -+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); -+ buffer->buft = buft; -+ buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer; -+ -+ return buffer; -+ -+ UNUSED(buft); -+} -+ -+static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { -+ return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment; -+ -+ UNUSED(buft); -+} -+ -+// Should be changed to return device-specific host buffer type -+// but that probably requires changes in llama.cpp -+ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { -+ static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = { -+ /* .iface = */ { -+ /* .get_name = */ ggml_backend_vk_host_buffer_type_name, -+ /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer, -+ /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment, -+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX -+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, -+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, -+ }, -+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0), -+ /* .context = */ nullptr, -+ }; -+ -+ // Make sure device 0 is initialized -+ ggml_vk_instance_init(); -+ ggml_vk_get_device(0); -+ -+ return &ggml_backend_vk_buffer_type_host; -+} -+ -+ -+// backend -+ -+static const char * ggml_backend_vk_name(ggml_backend_t backend) { -+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -+ -+ return ctx->name.c_str(); -+} -+ -+static void ggml_backend_vk_free(ggml_backend_t backend) { -+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -+ VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")"); -+ -+ ggml_vk_cleanup(ctx); -+ -+ delete ctx; -+ delete backend; -+} -+ -+static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { -+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -+ -+ return &ctx->device->buffer_type; -+} -+ -+static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { -+ VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); -+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -+ GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); -+ -+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; -+ -+ vk_context transfer_ctx; -+ -+ if (ctx->transfer_ctx.expired()) { -+ // Initialize new transfer context -+ transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); -+ ctx->transfer_ctx = transfer_ctx; -+ ggml_vk_ctx_begin(ctx->device, transfer_ctx); -+ } else { -+ transfer_ctx = ctx->transfer_ctx.lock(); -+ } -+ -+ vk_buffer buf = buf_ctx->dev_buffer; -+ -+ ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); -+} -+ -+static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { -+ 
VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); -+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -+ GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); -+ -+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; -+ -+ vk_context transfer_ctx; -+ -+ if (ctx->transfer_ctx.expired()) { -+ // Initialize new transfer context -+ transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); -+ ctx->transfer_ctx = transfer_ctx; -+ ggml_vk_ctx_begin(ctx->device, transfer_ctx); -+ } else { -+ transfer_ctx = ctx->transfer_ctx.lock(); -+ } -+ -+ vk_buffer buf = buf_ctx->dev_buffer; -+ -+ ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); -+} -+ -+static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { -+ VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); -+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -+ if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { -+ ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; -+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; -+ -+ vk_context transfer_ctx; -+ -+ if (ctx->transfer_ctx.expired()) { -+ // Initialize new transfer context -+ transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); -+ ctx->transfer_ctx = transfer_ctx; -+ ggml_vk_ctx_begin(ctx->device, transfer_ctx); -+ } else { -+ transfer_ctx = ctx->transfer_ctx.lock(); -+ } -+ -+ vk_buffer src_buf = src_buf_ctx->dev_buffer; -+ vk_buffer dst_buf = dst_buf_ctx->dev_buffer; -+ -+ ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); -+ return true; -+ } -+ -+ return false; -+} -+ -+static void ggml_backend_vk_synchronize(ggml_backend_t backend) { -+ VK_LOG_DEBUG("ggml_backend_vk_synchronize()"); -+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -+ if(ctx->transfer_ctx.expired()) { -+ return; -+ } -+ -+ vk_context transfer_ctx = ctx->transfer_ctx.lock(); -+ -+ ggml_vk_ctx_end(transfer_ctx); -+ -+ for (auto& cpy : transfer_ctx->in_memcpys) { -+ memcpy(cpy.dst, cpy.src, cpy.n); -+ } -+ -+ ggml_vk_submit(transfer_ctx, ctx->fence); -+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences"); -+ ctx->device->device.resetFences({ ctx->fence }); -+ -+ for (auto& cpy : transfer_ctx->out_memcpys) { -+ memcpy(cpy.dst, cpy.src, cpy.n); -+ } -+ -+ ctx->transfer_ctx.reset(); -+} -+ -+static bool ggml_vk_is_empty(ggml_tensor * node) { -+ return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; -+} -+ -+static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { -+ VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); -+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -+ -+ for (int i = 0; i 
< cgraph->n_nodes; i++) { -+ ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); -+ } -+ ggml_vk_preallocate_buffers(ctx); -+ ggml_pipeline_allocate_descriptor_sets(ctx->device); -+ -+ int last_node = cgraph->n_nodes - 1; -+ -+ // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly -+ while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { -+ last_node -= 1; -+ } -+ -+ // Reserve tensor context space for all nodes -+ ctx->tensor_ctxs.resize(cgraph->n_nodes); -+ -+ bool first_node_in_batch = true; // true if next node will be first node in a batch -+ int submit_node_idx = 0; // index to first node in a batch -+ -+ // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution. -+ // Start with a smaller count to get work submitted right away, and increase it after each submit. -+ int nodes_per_submit = 20; -+ int submitted_nodes = 0; -+ int submit_count = 0; -+ for (int i = 0; i < cgraph->n_nodes; i++) { -+ if (first_node_in_batch) { -+ submit_node_idx = i; -+ } -+ -+ bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node); -+ -+ bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); -+ -+ if (enqueued) { -+ ++submitted_nodes; -+ -+#ifndef GGML_VULKAN_CHECK_RESULTS -+ if (first_node_in_batch) { -+ first_node_in_batch = false; -+ } -+#endif -+ } -+ -+ if (submit) { -+ first_node_in_batch = true; -+ submitted_nodes = 0; -+ switch (submit_count) { -+ case 0: -+ nodes_per_submit = 50; -+ break; -+ default: -+ nodes_per_submit = 100; -+ break; -+ } -+ submit_count++; -+ } -+ } -+ -+#ifdef GGML_VULKAN_PERF -+ ctx->device->perf_logger->print_timings(); -+#endif -+ -+ ggml_vk_graph_cleanup(ctx); -+ -+ return GGML_STATUS_SUCCESS; -+ -+ UNUSED(backend); -+} -+ -+// TODO: enable async and synchronize -+static ggml_backend_i ggml_backend_vk_interface = { -+ /* .get_name = */ ggml_backend_vk_name, -+ /* .free = */ ggml_backend_vk_free, -+ /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, -+ /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async, -+ /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, -+ /* .synchronize = */ NULL, // ggml_backend_vk_synchronize, -+ /* .graph_plan_create = */ NULL, -+ /* .graph_plan_free = */ NULL, -+ /* .graph_plan_update = */ NULL, -+ /* .graph_plan_compute = */ NULL, -+ /* .graph_compute = */ ggml_backend_vk_graph_compute, -+ /* .event_record = */ NULL, -+ /* .event_wait = */ NULL, -+}; -+ -+static ggml_guid_t ggml_backend_vk_guid() { -+ static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; -+ return &guid; -+} -+ -+ggml_backend_t ggml_backend_vk_init(size_t dev_num) { -+ VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")"); -+ -+ ggml_backend_vk_context * ctx = new ggml_backend_vk_context; -+ ggml_vk_init(ctx, dev_num); -+ -+ ggml_backend_t vk_backend = new ggml_backend { -+ /* .guid = */ ggml_backend_vk_guid(), -+ /* .interface = */ ggml_backend_vk_interface, -+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), -+ /* .context = */ ctx, -+ }; -+ -+ return vk_backend; -+} -+ -+bool ggml_backend_is_vk(ggml_backend_t backend) { -+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid()); -+} -+ -+int ggml_backend_vk_get_device_count() { -+ return ggml_vk_get_device_count(); -+} -+ 
-+void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) { -+ GGML_ASSERT(device < (int) vk_instance.device_indices.size()); -+ int dev_idx = vk_instance.device_indices[device]; -+ ggml_vk_get_device_description(dev_idx, description, description_size); -+} -+ -+void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { -+ GGML_ASSERT(device < (int) vk_instance.device_indices.size()); -+ -+ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; -+ -+ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); -+ -+ for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { -+ if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { -+ *total = heap.size; -+ *free = heap.size; -+ break; -+ } -+ } -+} -+ -+////////////////////////// -+ -+struct ggml_backend_vk_device_context { -+ size_t device; -+ std::string name; -+ std::string description; -+}; -+ -+static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { -+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -+ return ctx->name.c_str(); -+} -+ -+static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) { -+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -+ return ctx->description.c_str(); -+} -+ -+static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { -+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; -+ ggml_backend_vk_get_device_memory(ctx->device, free, total); -+} -+ -+static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { -+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -+ return ggml_backend_vk_buffer_type(ctx->device); -+} -+ -+static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) { -+ UNUSED(dev); -+ return ggml_backend_vk_host_buffer_type(); -+} -+ -+static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { -+ UNUSED(dev); -+ return GGML_BACKEND_DEVICE_TYPE_GPU; -+} -+ -+static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { -+ props->name = ggml_backend_vk_device_get_name(dev); -+ props->description = ggml_backend_vk_device_get_description(dev); -+ props->type = ggml_backend_vk_device_get_type(dev); -+ ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); -+ props->caps = { -+ /* .async = */ false, -+ /* .host_buffer = */ true, -+ /* .buffer_from_host_ptr = */ false, -+ /* .events = */ false, -+ }; -+} -+ -+static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { -+ UNUSED(params); -+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -+ return ggml_backend_vk_init(ctx->device); -+} -+ -+static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { -+ switch (op->op) { -+ case GGML_OP_UNARY: -+ switch (ggml_get_unary_op(op)) { -+ case GGML_UNARY_OP_GELU: -+ case GGML_UNARY_OP_GELU_QUICK: -+ case GGML_UNARY_OP_SILU: -+ case GGML_UNARY_OP_RELU: -+ case GGML_UNARY_OP_TANH: -+ return ggml_is_contiguous(op->src[0]); -+ default: -+ return false; -+ } -+ break; -+ case GGML_OP_MUL_MAT: -+ case GGML_OP_MUL_MAT_ID: -+ { -+ 
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -+ const vk_device& device = ggml_vk_get_device(ctx->device); -+ if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) { -+ // If there's not enough shared memory for row_ids and the result tile, fallback to CPU -+ return false; -+ } -+ switch (op->src[0]->type) { -+ case GGML_TYPE_F32: -+ case GGML_TYPE_F16: -+ case GGML_TYPE_Q4_0: -+ case GGML_TYPE_Q4_1: -+ case GGML_TYPE_Q5_0: -+ case GGML_TYPE_Q5_1: -+ case GGML_TYPE_Q8_0: -+ case GGML_TYPE_Q2_K: -+ case GGML_TYPE_Q3_K: -+ case GGML_TYPE_Q4_K: -+ case GGML_TYPE_Q5_K: -+ case GGML_TYPE_Q6_K: -+ case GGML_TYPE_IQ4_NL: -+ break; -+ default: -+ return false; -+ } -+ struct ggml_tensor * a; -+ struct ggml_tensor * b; -+ if (op->op == GGML_OP_MUL_MAT) { -+ a = op->src[0]; -+ b = op->src[1]; -+ } else { -+ a = op->src[2]; -+ b = op->src[1]; -+ } -+ if (a->ne[3] != b->ne[3]) { -+ return false; -+ } -+ if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) || -+ !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { -+ return false; -+ } -+ -+ return true; -+ } break; -+ case GGML_OP_FLASH_ATTN_EXT: -+ { -+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -+ if (!ggml_vk_get_device(ctx->device)->coopmat2) { -+ return false; -+ } -+ switch (op->src[0]->ne[0]) { -+ case 64: -+ case 80: -+ case 96: -+ case 112: -+ case 128: -+ case 256: -+ break; -+ default: -+ return false; -+ } -+ if (op->src[0]->type != GGML_TYPE_F32) { -+ return false; -+ } -+ if (op->type != GGML_TYPE_F32) { -+ return false; -+ } -+ if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { -+ return false; -+ } -+ // It's straightforward to support different K/V dequant, but would -+ // significantly increase the number of pipelines -+ if (op->src[1]->type != op->src[2]->type) { -+ return false; -+ } -+ switch (op->src[1]->type) { -+ case GGML_TYPE_F16: -+ case GGML_TYPE_Q4_0: -+ case GGML_TYPE_Q4_1: -+ case GGML_TYPE_Q5_0: -+ case GGML_TYPE_Q5_1: -+ case GGML_TYPE_Q8_0: -+ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently -+ //case GGML_TYPE_Q2_K: -+ //case GGML_TYPE_Q3_K: -+ //case GGML_TYPE_Q4_K: -+ //case GGML_TYPE_Q5_K: -+ //case GGML_TYPE_Q6_K: -+ case GGML_TYPE_IQ4_NL: -+ break; -+ default: -+ return false; -+ } -+ return true; -+ } -+ case GGML_OP_GET_ROWS: -+ { -+ switch (op->src[0]->type) { -+ case GGML_TYPE_F32: -+ case GGML_TYPE_F16: -+ case GGML_TYPE_Q4_0: -+ case GGML_TYPE_Q4_1: -+ case GGML_TYPE_Q5_0: -+ case GGML_TYPE_Q5_1: -+ case GGML_TYPE_Q8_0: -+ case GGML_TYPE_IQ4_NL: -+ return true; -+ default: -+ return false; -+ } -+ } break; -+ case GGML_OP_CONT: -+ case GGML_OP_CPY: -+ case GGML_OP_DUP: -+ { -+ ggml_type src0_type = op->src[0]->type; -+ ggml_type src1_type = op->src[1] != nullptr ? 
op->src[1]->type : src0_type; -+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { -+ return true; -+ } -+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { -+ return true; -+ } -+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { -+ return true; -+ } -+ return false; -+ } break; -+ case GGML_OP_REPEAT: -+ return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); -+ case GGML_OP_ROPE: -+ { -+ const int mode = ((const int32_t *) op->op_params)[2]; -+ if (mode & GGML_ROPE_TYPE_MROPE) { -+ return false; -+ } -+ if (mode & GGML_ROPE_TYPE_VISION) { -+ return false; -+ } -+ return ggml_is_contiguous(op->src[0]); -+ } -+ case GGML_OP_NONE: -+ case GGML_OP_RESHAPE: -+ case GGML_OP_VIEW: -+ case GGML_OP_PERMUTE: -+ case GGML_OP_TRANSPOSE: -+ case GGML_OP_NORM: -+ case GGML_OP_GROUP_NORM: -+ case GGML_OP_RMS_NORM: -+ case GGML_OP_ADD: -+ case GGML_OP_ACC: -+ case GGML_OP_MUL: -+ case GGML_OP_DIV: -+ case GGML_OP_CONCAT: -+ case GGML_OP_UPSCALE: -+ case GGML_OP_SCALE: -+ case GGML_OP_SQR: -+ case GGML_OP_SIN: -+ case GGML_OP_COS: -+ case GGML_OP_CLAMP: -+ case GGML_OP_PAD: -+ case GGML_OP_DIAG_MASK_INF: -+ case GGML_OP_SOFT_MAX: -+ case GGML_OP_ARGSORT: -+ case GGML_OP_SUM_ROWS: -+ case GGML_OP_IM2COL: -+ case GGML_OP_TIMESTEP_EMBEDDING: -+ case GGML_OP_POOL_2D: -+ case GGML_OP_RWKV_WKV6: -+ case GGML_OP_LEAKY_RELU: -+ return true; -+ default: -+ return false; -+ } -+ -+ UNUSED(dev); -+} -+ -+static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { -+ if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) { -+ return false; -+ } -+ -+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -+ ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context; -+ -+ return buft_ctx->device->idx == ctx->device; -+} -+ -+static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { -+ const int min_batch_size = 32; -+ -+ return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || -+ (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); -+ -+ UNUSED(dev); -+} -+ -+static const struct ggml_backend_device_i ggml_backend_vk_device_i = { -+ /* .get_name = */ ggml_backend_vk_device_get_name, -+ /* .get_description = */ ggml_backend_vk_device_get_description, -+ /* .get_memory = */ ggml_backend_vk_device_get_memory, -+ /* .get_type = */ ggml_backend_vk_device_get_type, -+ /* .get_props = */ ggml_backend_vk_device_get_props, -+ /* .init_backend = */ ggml_backend_vk_device_init, -+ /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, -+ /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, -+ /* .buffer_from_host_ptr = */ NULL, -+ /* .supports_op = */ ggml_backend_vk_device_supports_op, -+ /* .supports_buft = */ ggml_backend_vk_device_supports_buft, -+ /* .offload_op = */ ggml_backend_vk_device_offload_op, -+ /* .event_new = */ NULL, -+ /* .event_free = */ NULL, -+ /* .event_synchronize = */ NULL, -+}; -+ -+static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) { -+ UNUSED(reg); -+ return GGML_VK_NAME; -+} -+ -+static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) { -+ UNUSED(reg); -+ return ggml_backend_vk_get_device_count(); -+} -+ -+static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) { -+ static std::vector devices; -+ -+ static bool 
initialized = false; -+ -+ { -+ static std::mutex mutex; -+ std::lock_guard<std::mutex> lock(mutex); -+ if (!initialized) { -+ for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { -+ ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; -+ char desc[256]; -+ ggml_backend_vk_get_device_description(i, desc, sizeof(desc)); -+ ctx->device = i; -+ ctx->name = GGML_VK_NAME + std::to_string(i); -+ ctx->description = desc; -+ devices.push_back(new ggml_backend_device { -+ /* .iface = */ ggml_backend_vk_device_i, -+ /* .reg = */ reg, -+ /* .context = */ ctx, -+ }); -+ } -+ initialized = true; -+ } -+ } -+ -+ GGML_ASSERT(device < devices.size()); -+ return devices[device]; -+} -+ -+static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = { -+ /* .get_name = */ ggml_backend_vk_reg_get_name, -+ /* .get_device_count = */ ggml_backend_vk_reg_get_device_count, -+ /* .get_device = */ ggml_backend_vk_reg_get_device, -+ /* .get_proc_address = */ NULL, -+}; -+ -+ggml_backend_reg_t ggml_backend_vk_reg() { -+ static ggml_backend_reg reg = { -+ /* .api_version = */ GGML_BACKEND_API_VERSION, -+ /* .iface = */ ggml_backend_vk_reg_i, -+ /* .context = */ nullptr, -+ }; -+ -+ return &reg; -+} -+ -+// Extension availability -+static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) { -+#ifdef GGML_VULKAN_VALIDATE -+ bool portability_enumeration_ext = false; -+ // Check for portability enumeration extension for MoltenVK support -+ for (const auto& properties : instance_extensions) { -+ if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { -+ return true; -+ } -+ } -+ if (!portability_enumeration_ext) { -+ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; -+ } -+#endif -+ return false; -+ -+ UNUSED(instance_extensions); -+} -+static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) { -+#ifdef __APPLE__ -+ bool portability_enumeration_ext = false; -+ // Check for portability enumeration extension for MoltenVK support -+ for (const auto& properties : instance_extensions) { -+ if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { -+ return true; -+ } -+ } -+ if (!portability_enumeration_ext) { -+ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." 
<< std::endl; -+ } -+#endif -+ return false; -+ -+ UNUSED(instance_extensions); -+} -+ -+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) { -+ switch (props.vendorID) { -+ case VK_VENDOR_ID_INTEL: -+ // Intel drivers don't support coopmat properly yet -+ return false; -+ case VK_VENDOR_ID_AMD: -+ if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { -+ // Workaround for AMD proprietary driver reporting support on all GPUs -+ const std::string name = props.deviceName; -+ return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs -+ name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs -+ name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs -+ } -+ return true; -+ default: -+ return true; -+ } -+} -+ -+// checks -+ -+#ifdef GGML_VULKAN_CHECK_RESULTS -+static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector& done, int level = 0) { -+ if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) { -+ return; -+ } -+ for (int j = 0; j < level; j++) { -+ std::cerr << " "; -+ } -+ std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl; -+ -+ done.push_back(tensor); -+ -+ for (int i = 0; i < GGML_MAX_SRC; i++) { -+ if (tensor->src[i] != nullptr) { -+ ggml_vk_print_graph_origin(tensor->src[i], done, level + 1); -+ } -+ } -+} -+ -+static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) { -+ if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) { -+ return; -+ } -+ i0 = std::max(i0, 5); -+ i1 = std::max(i1, 5); -+ i2 = std::max(i2, 0); -+ i3 = std::max(i3, 0); -+ fprintf(stderr, " "); -+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { -+ fprintf(stderr, "%7d ", idx1); -+ } -+ fprintf(stderr, "\n"); -+ for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { -+ fprintf(stderr, "%7d: ", idx0); -+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { -+ if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { -+ float val; -+ if (tensor->type == GGML_TYPE_F32) { -+ val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); -+ } else if (tensor->type == GGML_TYPE_F16) { -+ val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); -+ } else if (tensor->type == GGML_TYPE_I32) { -+ val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); -+ } else { -+ GGML_ABORT("fatal error"); -+ } -+ fprintf(stderr, "% 7.2f ", val); -+ } else { -+ fprintf(stderr, " "); -+ } -+ } -+ fprintf(stderr, "\n"); -+ } -+} -+ -+static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) { -+ void * tensor_data = tensor->data; -+ -+ const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer); -+ -+ if (is_gpu) { -+ const size_t tensor_size = ggml_nbytes(tensor); -+ tensor_data = malloc(tensor_size); -+ -+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context 
*)tensor->buffer->context; -+ -+ vk_buffer buffer_gpu = buf_ctx->dev_buffer; -+ ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size); -+ } -+ -+ std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; -+ std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl; -+ if (tensor->src[0] != nullptr) { -+ std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl; -+ } -+ if (tensor->src[1] != nullptr) { -+ std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl; -+ } -+ std::cerr << std::endl << "Result:" << std::endl; -+ ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); -+ std::cerr << std::endl; -+ std::vector done; -+ ggml_vk_print_graph_origin(tensor, done); -+ -+ if (is_gpu) { -+ free(tensor_data); -+ } -+} -+ -+void * comp_result; -+size_t comp_size; -+size_t comp_nb[GGML_MAX_DIMS]; -+size_t check_counter = 0; -+static void ggml_vk_check_results_0(ggml_tensor * tensor) { -+ if (tensor->op == GGML_OP_TRANSPOSE) { -+ return; -+ } -+ -+ check_counter++; -+ if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { -+ return; -+ } -+ -+ VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")"); -+ -+ ggml_tensor * src0 = tensor->src[0]; -+ ggml_tensor * src1 = tensor->src[1]; -+ ggml_tensor * src2 = tensor->src[2]; -+ ggml_tensor * src3 = tensor->src[3]; -+ -+ struct ggml_init_params iparams = { -+ /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, -+ /*.mem_buffer =*/ NULL, -+ /*.no_alloc =*/ false, -+ }; -+ -+ struct ggml_context * ggml_ctx = ggml_init(iparams); -+ -+ struct ggml_tensor * src0_clone = nullptr; -+ struct ggml_tensor * src1_clone = nullptr; -+ struct ggml_tensor * src2_clone = nullptr; -+ struct ggml_tensor * src3_clone = nullptr; -+ struct ggml_tensor * tensor_clone = nullptr; -+ -+ size_t src0_size; -+ size_t src1_size; -+ size_t src2_size; -+ size_t src3_size; -+ -+ void * src0_buffer = nullptr; -+ void * src1_buffer = nullptr; -+ void * src2_buffer = nullptr; -+ void * src3_buffer = nullptr; -+ -+ if (src0 != nullptr) { -+ src0_clone = ggml_dup_tensor(ggml_ctx, src0); -+ -+ src0_size = ggml_nbytes(src0); -+ -+ src0_buffer = malloc(src0_size); -+ src0_clone->data = src0_buffer; -+ if (ggml_backend_buffer_is_host(src0->buffer)) { -+ memcpy(src0_clone->data, src0->data, src0_size); -+ memcpy(src0_clone->nb, src0->nb, 
sizeof(size_t) * GGML_MAX_DIMS); -+ } else if (ggml_backend_buffer_is_vk(src0->buffer)) { -+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; -+ vk_buffer& buffer_gpu = buf_ctx->dev_buffer; -+ uint64_t offset = vk_tensor_offset(src0) + src0->view_offs; -+ if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { -+ for (int i3 = 0; i3 < src0->ne[3]; i3++) { -+ for (int i2 = 0; i2 < src0->ne[2]; i2++) { -+ const int idx = i3*src0->ne[2] + i2; -+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]); -+ } -+ } -+ -+ src0_clone->nb[0] = src0->nb[0]; -+ src0_clone->nb[1] = src0->nb[1]; -+ for (int i = 2; i < GGML_MAX_DIMS; i++) { -+ src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1]; -+ } -+ } else { -+ if (offset + src0_size >= buffer_gpu->size) { -+ src0_size = buffer_gpu->size - offset; -+ } -+ ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size); -+ memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); -+ } -+ } else { -+ GGML_ABORT("fatal error"); -+ } -+ -+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { -+ ggml_vk_print_tensor(src0, "src0"); -+ } -+ } -+ if (src1 != nullptr) { -+ src1_clone = ggml_dup_tensor(ggml_ctx, src1); -+ -+ src1_size = ggml_nbytes(src1); -+ -+ src1_buffer = malloc(src1_size); -+ src1_clone->data = src1_buffer; -+ if (ggml_backend_buffer_is_host(src1->buffer)) { -+ memcpy(src1_clone->data, src1->data, src1_size); -+ memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); -+ } else if (ggml_backend_buffer_is_vk(src1->buffer)) { -+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; -+ vk_buffer& buffer_gpu = buf_ctx->dev_buffer; -+ uint64_t offset = vk_tensor_offset(src1) + src1->view_offs; -+ if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { -+ for (int i3 = 0; i3 < src1->ne[3]; i3++) { -+ for (int i2 = 0; i2 < src1->ne[2]; i2++) { -+ const int idx = i3*src1->ne[2] + i2; -+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]); -+ } -+ } -+ -+ src1_clone->nb[0] = src1->nb[0]; -+ src1_clone->nb[1] = src1->nb[1]; -+ for (int i = 2; i < GGML_MAX_DIMS; i++) { -+ src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1]; -+ } -+ } else { -+ if (offset + src1_size >= buffer_gpu->size) { -+ src1_size = buffer_gpu->size - offset; -+ } -+ ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size); -+ memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); -+ } -+ } else { -+ GGML_ABORT("fatal error"); -+ } -+ -+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { -+ ggml_vk_print_tensor(src1, "src1"); -+ } -+ } -+ if (src2 != nullptr) { -+ src2_clone = ggml_dup_tensor(ggml_ctx, src2); -+ -+ src2_size = ggml_nbytes(src2); -+ -+ src2_buffer = malloc(src2_size); -+ src2_clone->data = src2_buffer; -+ if (ggml_backend_buffer_is_host(src2->buffer)) { -+ memcpy(src2_clone->data, src2->data, src2_size); -+ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); -+ } else if (ggml_backend_buffer_is_vk(src2->buffer)) { -+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context; -+ vk_buffer& buffer_gpu = buf_ctx->dev_buffer; -+ uint64_t offset = vk_tensor_offset(src2) + src2->view_offs; -+ if (!ggml_is_contiguous(src2) && 
ggml_vk_dim01_contiguous(src2)) { -+ for (int i3 = 0; i3 < src2->ne[3]; i3++) { -+ for (int i2 = 0; i2 < src2->ne[2]; i2++) { -+ const int idx = i3*src2->ne[2] + i2; -+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]); -+ } -+ } -+ -+ src2_clone->nb[0] = src2->nb[0]; -+ src2_clone->nb[1] = src2->nb[1]; -+ for (int i = 2; i < GGML_MAX_DIMS; i++) { -+ src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1]; -+ } -+ } else { -+ if (offset + src2_size >= buffer_gpu->size) { -+ src2_size = buffer_gpu->size - offset; -+ } -+ ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size); -+ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); -+ } -+ } else { -+ GGML_ABORT("fatal error"); -+ } -+ -+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { -+ ggml_vk_print_tensor(src2, "src2"); -+ } -+ } -+ if (src3 != nullptr) { -+ src3_clone = ggml_dup_tensor(ggml_ctx, src3); -+ -+ src3_size = ggml_nbytes(src3); -+ -+ src3_buffer = malloc(src3_size); -+ src3_clone->data = src3_buffer; -+ if (ggml_backend_buffer_is_host(src3->buffer)) { -+ memcpy(src3_clone->data, src3->data, src3_size); -+ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); -+ } else if (ggml_backend_buffer_is_vk(src3->buffer)) { -+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context; -+ vk_buffer& buffer_gpu = buf_ctx->dev_buffer; -+ uint64_t offset = vk_tensor_offset(src3) + src3->view_offs; -+ if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) { -+ for (int i3 = 0; i3 < src3->ne[3]; i3++) { -+ for (int i2 = 0; i2 < src3->ne[2]; i2++) { -+ const int idx = i3*src3->ne[2] + i2; -+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]); -+ } -+ } -+ -+ src3_clone->nb[0] = src3->nb[0]; -+ src3_clone->nb[1] = src3->nb[1]; -+ for (int i = 2; i < GGML_MAX_DIMS; i++) { -+ src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1]; -+ } -+ } else { -+ if (offset + src3_size >= buffer_gpu->size) { -+ src3_size = buffer_gpu->size - offset; -+ } -+ ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size); -+ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); -+ } -+ } else { -+ GGML_ABORT("fatal error"); -+ } -+ -+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { -+ ggml_vk_print_tensor(src3, "src3"); -+ } -+ } -+ -+ if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { -+ const float *params = (const float *)tensor->op_params; -+ tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]); -+ } else if (tensor->op == GGML_OP_MUL_MAT) { -+ tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone); -+ } else if (tensor->op == GGML_OP_MUL_MAT_ID) { -+ tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone); -+ } else if (tensor->op == GGML_OP_MUL) { -+ tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone); -+ } else if (tensor->op == GGML_OP_DIV) { -+ tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone); -+ } else if (tensor->op == GGML_OP_CONCAT) { -+ tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params); -+ } else if (tensor->op == GGML_OP_UPSCALE) { -+ tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); -+ } else if 
(tensor->op == GGML_OP_SCALE) { -+ tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]); -+ } else if (tensor->op == GGML_OP_SQR) { -+ tensor_clone = ggml_sqr(ggml_ctx, src0_clone); -+ } else if (tensor->op == GGML_OP_SIN) { -+ tensor_clone = ggml_sin(ggml_ctx, src0_clone); -+ } else if (tensor->op == GGML_OP_COS) { -+ tensor_clone = ggml_cos(ggml_ctx, src0_clone); -+ } else if (tensor->op == GGML_OP_CLAMP) { -+ tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); -+ } else if (tensor->op == GGML_OP_PAD) { -+ tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]); -+ } else if (tensor->op == GGML_OP_REPEAT) { -+ tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor); -+ } else if (tensor->op == GGML_OP_ADD) { -+ tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); -+ } else if (tensor->op == GGML_OP_ACC) { -+ tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); -+ } else if (tensor->op == GGML_OP_NORM) { -+ tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); -+ } else if (tensor->op == GGML_OP_GROUP_NORM) { -+ tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); -+ } else if (tensor->op == GGML_OP_RMS_NORM) { -+ tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); -+ } else if (tensor->op == GGML_OP_SOFT_MAX) { -+ if (src1 != nullptr) { -+ tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); -+ } else { -+ tensor_clone = ggml_soft_max(ggml_ctx, src0_clone); -+ } -+ } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { -+ tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params); -+ } else if (tensor->op == GGML_OP_ROPE) { -+ const int n_dims = ((int32_t *) tensor->op_params)[1]; -+ const int mode = ((int32_t *) tensor->op_params)[2]; -+ //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; -+ const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; -+ const float freq_base = ((float *) tensor->op_params)[5]; -+ const float freq_scale = ((float *) tensor->op_params)[6]; -+ const float ext_factor = ((float *) tensor->op_params)[7]; -+ const float attn_factor = ((float *) tensor->op_params)[8]; -+ const float beta_fast = ((float *) tensor->op_params)[9]; -+ const float beta_slow = ((float *) tensor->op_params)[10]; -+ tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); -+ } else if (tensor->op == GGML_OP_UNARY) { -+ switch (ggml_get_unary_op(tensor)) { -+ case GGML_UNARY_OP_SILU: -+ tensor_clone = ggml_silu(ggml_ctx, src0_clone); -+ break; -+ case GGML_UNARY_OP_GELU: -+ tensor_clone = ggml_gelu(ggml_ctx, src0_clone); -+ break; -+ case GGML_UNARY_OP_GELU_QUICK: -+ tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone); -+ break; -+ case GGML_UNARY_OP_RELU: -+ tensor_clone = ggml_relu(ggml_ctx, src0_clone); -+ break; -+ case GGML_UNARY_OP_TANH: -+ tensor_clone = ggml_tanh(ggml_ctx, src0_clone); -+ break; -+ default: -+ std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; -+ GGML_ABORT("fatal 
error"); -+ } -+ } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { -+ if (src1 == nullptr) { -+ tensor_clone = ggml_dup(ggml_ctx, src0_clone); -+ tensor_clone->type = tensor->type; -+ } else { -+ tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone); -+ } -+ } else if (tensor->op == GGML_OP_CONT) { -+ tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); -+ } else if (tensor->op == GGML_OP_RESHAPE) { -+ tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); -+ } else if (tensor->op == GGML_OP_VIEW) { -+ tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); -+ } else if (tensor->op == GGML_OP_PERMUTE) { -+ int32_t * params = (int32_t *)tensor->op_params; -+ tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]); -+ } else if (tensor->op == GGML_OP_TRANSPOSE) { -+ tensor_clone = ggml_transpose(ggml_ctx, src0_clone); -+ } else if (tensor->op == GGML_OP_GET_ROWS) { -+ tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone); -+ } else if (tensor->op == GGML_OP_ARGSORT) { -+ tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params); -+ } else if (tensor->op == GGML_OP_SUM_ROWS) { -+ tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone); -+ } else if (tensor->op == GGML_OP_IM2COL) { -+ const int32_t s0 = tensor->op_params[0]; -+ const int32_t s1 = tensor->op_params[1]; -+ const int32_t p0 = tensor->op_params[2]; -+ const int32_t p1 = tensor->op_params[3]; -+ const int32_t d0 = tensor->op_params[4]; -+ const int32_t d1 = tensor->op_params[5]; -+ -+ const bool is_2D = tensor->op_params[6] == 1; -+ tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type); -+ } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { -+ const int32_t dim = tensor->op_params[0]; -+ const int32_t max_period = tensor->op_params[1]; -+ tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); -+ } else if (tensor->op == GGML_OP_POOL_2D) { -+ enum ggml_op_pool op = static_cast(tensor->op_params[0]); -+ const int32_t k0 = tensor->op_params[1]; -+ const int32_t k1 = tensor->op_params[2]; -+ const int32_t s0 = tensor->op_params[3]; -+ const int32_t s1 = tensor->op_params[4]; -+ const int32_t p0 = tensor->op_params[5]; -+ const int32_t p1 = tensor->op_params[6]; -+ -+ tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1); -+ } else if (tensor->op == GGML_OP_LEAKY_RELU) { -+ const float * op_params = (const float *)tensor->op_params; -+ tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); -+ } else if (tensor->op == GGML_OP_RWKV_WKV6) { -+ tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], -+ tensor->src[4], tensor->src[5]); -+ } -+ else { -+ std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; -+ GGML_ABORT("fatal error"); -+ } -+ -+ ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); -+ ggml_build_forward_expand(cgraph, tensor_clone); -+ -+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8); -+ -+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { -+ ggml_vk_print_tensor(tensor_clone, "tensor_clone"); -+ } -+ -+ comp_size = ggml_nbytes(tensor_clone); -+ -+ comp_result = 
malloc(comp_size); -+ memcpy(comp_result, tensor_clone->data, comp_size); -+ memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); -+ -+ if (src0 != nullptr) { -+ free(src0_buffer); -+ } -+ if (src1 != nullptr) { -+ free(src1_buffer); -+ } -+ -+ ggml_free(ggml_ctx); -+ -+ VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); -+} -+ -+static void ggml_vk_check_results_1(ggml_tensor * tensor) { -+ if (tensor->op == GGML_OP_TRANSPOSE) { -+ return; -+ } -+ if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { -+ return; -+ } -+ -+ VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")"); -+ -+ ggml_tensor * src0 = tensor->src[0]; -+ ggml_tensor * src1 = tensor->src[1]; -+ ggml_tensor * src2 = tensor->src[2]; -+ -+ void * tensor_data = tensor->data; -+ -+ if (ggml_backend_buffer_is_vk(tensor->buffer)) { -+ size_t tensor_size = ggml_nbytes(tensor); -+ tensor_data = malloc(tensor_size); -+ -+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; -+ -+ vk_buffer& buffer_gpu = buf_ctx->dev_buffer; -+ uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs; -+ if (offset + tensor_size >= buffer_gpu->size) { -+ tensor_size = buffer_gpu->size - offset; -+ } -+ -+ ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size); -+ } -+ -+ float first_error_result = -1.0f; -+ float first_error_correct = -1.0f; -+ std::array<int, 4> first_error = { -1, -1, -1, -1 }; -+ double avg_err = 0.0; -+ size_t counter = 0; -+ -+ for (int i3 = 0; i3 < tensor->ne[3]; i3++) { -+ for (int i2 = 0; i2 < tensor->ne[2]; i2++) { -+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) { -+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) { -+ const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size; -+ float correct = 0.0f; -+ float result = 0.0f; -+ -+ if (buffer_size_fit) { -+ if (tensor->type == GGML_TYPE_F32) { -+ correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); -+ result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); -+ } else if (tensor->type == GGML_TYPE_F16) { -+ correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); -+ result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); -+ } else if (tensor->type == GGML_TYPE_I32) { -+ correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); -+ result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); -+ } else { -+ std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; -+ } -+ } else { -+ std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl; -+ GGML_ABORT("fatal error"); -+ } -+ -+ if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) { -+ std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl; -+ std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << 
ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; -+ if (src0 != nullptr) { -+ std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; -+ } -+ if (src1 != nullptr) { -+ std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; -+ } -+ if (src2 != nullptr) { -+ std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; -+ } -+ std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; -+ std::cerr << std::endl << "Result:" << std::endl; -+ ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); -+ std::cerr << std::endl << "Correct:" << std::endl; -+ ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3); -+ std::cerr << std::endl; -+ std::vector done; -+ ggml_vk_print_graph_origin(tensor, done); -+ GGML_ABORT("fatal error"); -+ } -+ if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) { -+ first_error[0] = i0; -+ first_error[1] = i1; -+ first_error[2] = i2; -+ first_error[3] = i3; -+ first_error_result = result; -+ first_error_correct = correct; -+ } -+ -+ // Special case, value is infinite, avoid NaN result in avg_err -+ // NaN also appears in results, if both are nan error is 0 -+ if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) { -+ avg_err += std::fabs(correct - result); -+ } -+ counter++; -+ } -+ } -+ } -+ } -+ -+ avg_err /= counter; -+ -+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { -+ std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; -+ std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; -+ if (src0 != nullptr) { -+ std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << 
src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; -+ } -+ if (src1 != nullptr) { -+ std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; -+ } -+ if (src2 != nullptr) { -+ std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; -+ } -+ std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; -+ std::cerr << std::endl << "Result:" << std::endl; -+ ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); -+ std::cerr << std::endl << "Correct:" << std::endl; -+ ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0); -+ std::cerr << std::endl; -+ std::vector done; -+ ggml_vk_print_graph_origin(tensor, done); -+ } -+ -+ if (avg_err > 0.05 || std::isnan(avg_err)) { -+ std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; -+ std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; -+ if (src0 != nullptr) { -+ std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; -+ } -+ if (src1 != nullptr) { -+ std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; -+ } -+ if (src2 != nullptr) { -+ std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; -+ } -+ std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << 
std::endl; -+ std::cerr << std::endl << "Result:" << std::endl; -+ ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); -+ std::cerr << std::endl << "Correct:" << std::endl; -+ ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]); -+ std::cerr << std::endl; -+ std::vector done; -+ ggml_vk_print_graph_origin(tensor, done); -+ GGML_ABORT("fatal error"); -+ } else { -+ std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl; -+ } -+ -+ free(comp_result); -+ comp_result = nullptr; -+ comp_size = 0; -+ -+ if (ggml_backend_buffer_is_vk(tensor->buffer)) { -+ free(tensor_data); -+ } -+ -+ VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")"); -+} -+#endif -+ -+GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg) -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt -new file mode 100644 -index 00000000..bd0c74cb ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt -@@ -0,0 +1,9 @@ -+find_package (Threads REQUIRED) -+find_package(Vulkan COMPONENTS glslc REQUIRED) -+ -+set(TARGET vulkan-shaders-gen) -+add_executable(${TARGET} vulkan-shaders-gen.cpp) -+install(TARGETS ${TARGET} RUNTIME) -+target_compile_features(${TARGET} PRIVATE cxx_std_17) -+target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) -+target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan) -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp -new file mode 100644 -index 00000000..d896f1ef ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp -@@ -0,0 +1,29 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_binary_head.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ const uint idx = gl_GlobalInvocationID.x; -+ if (idx >= p.ne) { -+ return; -+ } -+ -+ const uint offset = p.param3; -+ const uint src1_i = idx - offset; -+ const uint oz = src1_i / p.nb02; -+ const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; -+ const uint ox = src1_i % p.nb01; -+ -+ uint i00, i01, i02, i03; -+ get_indices(idx, i00, i01, i02, i03); -+ -+ if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { -+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); -+ } else { -+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); -+ } -+} -+ -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp -new file mode 100644 -index 00000000..2b4085c4 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp -@@ -0,0 +1,29 @@ -+#version 450 -+ -+#extension GL_EXT_shader_16bit_storage : require -+ -+#include "types.comp" -+#include "generic_binary_head.comp" -+ -+const uint num_threads = 256; -+ -+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ uint idx = get_idx(); -+ -+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation -+ const uint num_iter = 2; -+ -+ 
[[unroll]] for (uint i = 0; i < num_iter; ++i) { -+ if (idx >= p.ne) { -+ continue; -+ } -+ uint i00, i01, i02, i03; -+ get_indices(idx, i00, i01, i02, i03); -+ -+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); -+ -+ idx += num_threads; -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp -new file mode 100644 -index 00000000..d4fa45b1 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp -@@ -0,0 +1,69 @@ -+#version 450 -+ -+#include "types.comp" -+ -+#define BLOCK_SIZE 1024 -+#define ASC 0 -+ -+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) buffer D {int data_d[];}; -+ -+layout (push_constant) uniform parameter { -+ uint ncols; -+ uint ncols_pad; -+ uint order; -+} p; -+ -+shared int dst_row[BLOCK_SIZE]; -+ -+void swap(uint idx0, uint idx1) { -+ int tmp = dst_row[idx0]; -+ dst_row[idx0] = dst_row[idx1]; -+ dst_row[idx1] = tmp; -+} -+ -+void main() { -+ // bitonic sort -+ const int col = int(gl_LocalInvocationID.x); -+ const uint row = gl_WorkGroupID.y; -+ -+ const uint row_offset = row * p.ncols; -+ -+ // initialize indices -+ if (col < p.ncols_pad) { -+ dst_row[col] = col; -+ } -+ barrier(); -+ -+ for (uint k = 2; k <= p.ncols_pad; k *= 2) { -+ for (uint j = k / 2; j > 0; j /= 2) { -+ const uint ixj = col ^ j; -+ if (col < p.ncols_pad && ixj > col) { -+ if ((col & k) == 0) { -+ if (dst_row[col] >= p.ncols || -+ (dst_row[ixj] < p.ncols && (p.order == ASC ? -+ data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] : -+ data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]])) -+ ) { -+ swap(col, ixj); -+ } -+ } else { -+ if (dst_row[ixj] >= p.ncols || -+ (dst_row[col] < p.ncols && (p.order == ASC ? -+ data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] : -+ data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]])) -+ ) { -+ swap(col, ixj); -+ } -+ } -+ } -+ barrier(); -+ } -+ } -+ -+ if (col < p.ncols) { -+ data_d[row_offset + col] = dst_row[col]; -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp -new file mode 100644 -index 00000000..1e5cb8da ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp -@@ -0,0 +1,17 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_unary_head.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ const uint idx = get_idx(); -+ -+ if (idx >= p.ne) { -+ return; -+ } -+ -+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); -+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? 
p.param2 : val)); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp -new file mode 100644 -index 00000000..9ee2f1fa ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp -@@ -0,0 +1,41 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_binary_head.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+ const int dim = p.param3; -+ -+ if (idx >= p.ne) { -+ return; -+ } -+ -+ const uint i3 = idx / (p.ne22*p.ne21*p.ne20); -+ const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20; -+ const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20); -+ const uint i2_offset = i2*p.ne21*p.ne20; -+ const uint i1 = (idx - i3_offset - i2_offset) / p.ne20; -+ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20; -+ -+ uint o[4] = {0, 0, 0, 0}; -+ o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03)); -+ -+ const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; -+ const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10; -+ const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20; -+ -+ const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; -+ -+#ifndef OPTIMIZATION_ERROR_WORKAROUND -+ data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]); -+#else -+ if (is_src0) { -+ data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx]; -+ } else { -+ data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx]; -+ } -+#endif -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp -new file mode 100644 -index 00000000..dd828c23 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp -@@ -0,0 +1,42 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_unary_head.comp" -+ -+#extension GL_EXT_control_flow_attributes : require -+ -+const uint num_threads = 128; -+ -+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ uint idx = get_idx(); -+ -+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation -+ const uint num_iter = 4; -+ -+ // fast path for when all four iterations are in-bounds -+ if (idx + (num_iter-1)*num_threads < p.ne) { -+ [[unroll]] for (uint i = 0; i < num_iter; ++i) { -+#ifndef OPTIMIZATION_ERROR_WORKAROUND -+ data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); -+#else -+ data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; -+#endif -+ idx += num_threads; -+ } -+ } else { -+ [[unroll]] for (uint i = 0; i < num_iter; ++i) { -+ if (idx >= p.ne) { -+ continue; -+ } -+ -+#ifndef OPTIMIZATION_ERROR_WORKAROUND -+ data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); -+#else -+ data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; -+#endif -+ idx += num_threads; -+ } -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp -new file mode 100644 -index 00000000..29c90649 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp -@@ -0,0 +1,20 @@ 
-+#version 450 -+ -+#include "types.comp" -+#include "generic_unary_head.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ const uint idx = get_idx(); -+ -+ if (idx >= p.ne) { -+ return; -+ } -+ -+#ifndef OPTIMIZATION_ERROR_WORKAROUND -+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); -+#else -+ data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; -+#endif -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp -new file mode 100644 -index 00000000..0b8d02f5 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp -@@ -0,0 +1,17 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_unary_head.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ const uint idx = get_idx(); -+ -+ if (idx >= p.ne) { -+ return; -+ } -+ -+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); -+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp -new file mode 100644 -index 00000000..a4d3fca5 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp -@@ -0,0 +1,20 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {float data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ const uint i = gl_GlobalInvocationID.x * 16; -+ -+ if (i >= p.nel) { -+ return; -+ } -+ -+ [[unroll]] for (uint l = 0; l < 16; l++) { -+ data_b[i + l] = D_TYPE(data_a[i + l]); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp -new file mode 100644 -index 00000000..91bb8f8d ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp -@@ -0,0 +1,118 @@ -+#if !defined(DATA_A_F32) && !defined(DATA_A_F16) -+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require -+#endif -+ -+#include "types.comp" -+ -+#if defined(A_TYPE_PACKED16) -+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; -+#endif -+#if defined(A_TYPE_PACKED32) -+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; -+#endif -+ -+#if defined(DATA_A_F32) -+vec2 dequantize(uint ib, uint iqs, uint a_offset) { -+ return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); -+} -+#endif -+ -+#if defined(DATA_A_F16) -+vec2 dequantize(uint ib, uint iqs, uint a_offset) { -+ return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); -+} -+#endif -+ -+#if defined(DATA_A_Q4_0) -+vec2 dequantize(uint ib, uint iqs, uint a_offset) { -+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]); -+ return (vec2(vui & 0xF, vui >> 4) - 8.0f); -+} -+vec4 dequantize4(uint ib, uint iqs, uint a_offset) { -+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); -+ return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); -+} -+#endif -+ -+#if defined(DATA_A_Q4_1) -+vec2 dequantize(uint ib, uint iqs, uint a_offset) { -+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]); -+ return vec2(vui & 0xF, vui >> 
4); -+} -+vec4 dequantize4(uint ib, uint iqs, uint a_offset) { -+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); -+ return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); -+} -+#endif -+ -+#if defined(DATA_A_Q5_0) -+vec2 dequantize(uint ib, uint iqs, uint a_offset) { -+ const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; -+ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); -+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]); -+ return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f); -+} -+vec4 dequantize4(uint ib, uint iqs, uint a_offset) { -+ const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0]; -+ const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); -+ const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); -+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); -+ return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f); -+} -+#endif -+ -+#if defined(DATA_A_Q5_1) -+vec2 dequantize(uint ib, uint iqs, uint a_offset) { -+ const uint uint_qh = data_a[a_offset + ib].qh; -+ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); -+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]); -+ return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); -+} -+vec4 dequantize4(uint ib, uint iqs, uint a_offset) { -+ const uint uint_qh = data_a_packed16[a_offset + ib].qh; -+ const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); -+ const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); -+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); -+ return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y); -+} -+#endif -+ -+#if defined(DATA_A_Q8_0) -+vec2 dequantize(uint ib, uint iqs, uint a_offset) { -+ return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); -+} -+vec4 dequantize4(uint ib, uint iqs, uint a_offset) { -+ uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2]; -+ uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1]; -+ return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8)); -+} -+#endif -+ -+#if defined(DATA_A_IQ4_NL) -+vec2 dequantize(uint ib, uint iqs, uint a_offset) { -+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]); -+ return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]); -+} -+vec4 dequantize4(uint ib, uint iqs, uint a_offset) { -+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); -+ return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]); -+} -+#endif -+ -+#if defined(DATA_A_F32) || defined(DATA_A_F16) -+vec2 get_dm(uint ib, uint a_offset) { -+ return vec2(0, 0); -+} -+#endif -+ -+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) -+vec2 get_dm(uint ib, uint a_offset) { -+ return vec2(float(data_a[a_offset + ib].d), 0); -+} -+#endif -+ -+#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) -+vec2 get_dm(uint ib, uint a_offset) { -+ return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); -+} -+#endif -diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp -new file mode 100644 -index 00000000..94b78598 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp -@@ -0,0 +1,325 @@ -+ -+#include "types.comp" -+ -+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { -+ block_q4_0_packed16 block; -+}; -+ -+float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ const float16_t d = bl.block.d; -+ const uint idx = coordInBlock[1]; -+ const uint shift = (idx & 0x10) >> 2; -+ uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]); -+ qs >>= shift; -+ qs &= 0x0F0F; -+ qs = unpack8(qs)[idx & 1]; -+ float16_t ret = (float16_t(qs) - float16_t(8)) * d; -+ return ret; -+} -+ -+layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { -+ block_q4_1 block; -+}; -+ -+float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ const float16_t d = bl.block.d; -+ const float16_t m = bl.block.m; -+ const uint idx = coordInBlock[1]; -+ const uint iqs = idx & 0xF; -+ const uint shift = (idx & 0x10) >> 2; -+ uint32_t qs = bl.block.qs[iqs]; -+ qs >>= shift; -+ qs &= 0xF; -+ float16_t ret = float16_t(qs) * d + m; -+ return ret; -+} -+ -+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { -+ block_q5_0 block; -+}; -+ -+float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ const float16_t d = bl.block.d; -+ const uint idx = coordInBlock[1]; -+ const uint iqs = idx & 0xF; -+ -+ const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0]; -+ const uint qh = ((uint_qh >> idx) << 4) & 0x10; -+ -+ const uint shift = (idx & 0x10) >> 2; -+ uint32_t qs = bl.block.qs[iqs]; -+ qs >>= shift; -+ qs &= 0xF; -+ -+ float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d; -+ return ret; -+} -+ -+layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { -+ block_q5_1 block; -+}; -+ -+float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ const float16_t d = bl.block.d; -+ const float16_t m = bl.block.m; -+ const uint idx = coordInBlock[1]; -+ const uint iqs = idx & 0xF; -+ -+ const uint uint_qh = bl.block.qh; -+ const uint qh = ((uint_qh >> idx) << 4) & 0x10; -+ -+ const uint shift = (idx & 0x10) >> 2; -+ uint32_t qs = bl.block.qs[iqs]; -+ qs >>= shift; -+ qs &= 0xF; -+ -+ float16_t ret = float16_t(qs | qh) * d + m; -+ return ret; -+} -+ -+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { -+ block_q8_0_packed16 block; -+}; -+ -+float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ const float16_t d = bl.block.d; -+ const uint idx = coordInBlock[1]; -+ const uint iqs = idx; -+ -+ // Load 16b and select the byte for this element -+ int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1]; -+ float16_t ret = float16_t(qs) * d; -+ return ret; -+} -+ -+layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { -+ block_q2_K block; -+}; -+ -+float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ const f16vec2 d = bl.block.d; -+ 
const uint idx = coordInBlock[1]; -+ const uint iqs = idx; -+ -+ const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31 -+ const uint scalesi = iqs / 16; // 0..15 -+ const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6 -+ -+ uint32_t qs = bl.block.qs[qsi]; -+ const uint scales = bl.block.scales[scalesi]; -+ float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4); -+ return ret; -+} -+ -+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { -+ block_q3_K block; -+}; -+ -+float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ const uint idx = coordInBlock[1]; -+ const uint iqs = idx; -+ -+ const uint n = iqs / 128; // 0,1 -+ const uint qsi = n * 32 + (iqs % 32); // 0..63 -+ const uint hmi = (iqs % 32); // 0..31 -+ const uint j = (iqs % 128) / 8; // 0..15 -+ const uint is = iqs / 16; // 0..15 -+ const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3 -+ const uint qsshift = halfsplit * 2; // 0,2,4,6 -+ const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 -+ -+ uint32_t scaleidx0 = (is < 8) ? is : (is-8); -+ uint32_t scaleidx0shift = (is < 8) ? 0 : 4; -+ uint32_t scaleidx1 = is + 8 - (is/4)*4; -+ uint32_t scaleidx1shift = (is/4)*2; -+ -+ const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); -+ -+ const float16_t dl = bl.block.d * float16_t(us - 32); -+ -+ float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4)); -+ -+ return ret; -+} -+ -+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { -+ block_q4_K block; -+}; -+ -+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 { -+ block_q4_K_packed16 block; -+}; -+ -+float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); -+ const uint idx = coordInBlock[1]; -+ -+ const uint b = (idx & 0x20) >> 5; // 0,1 -+ const uint is = (idx & 0xE0) >> 5; // 0..7 -+ -+ const f16vec2 loadd = bl.block.d; -+ -+ uint32_t sc; -+ uint32_t mbyte; -+ -+ uint32_t scidx0 = (is < 4) ? is : (is + 4); -+ uint32_t scidx1 = (is < 4) ? is : (is - 4); -+ uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ uint32_t scidxshift1 = (is < 4) ? 0 : 2; -+ uint32_t mbidx0 = is + 4; -+ uint32_t mbidx1 = (is < 4) ? is + 4 : is; -+ uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; -+ uint32_t mbidxshift0 = (is < 4) ? 0 : 4; -+ uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ uint32_t mbidxshift1 = (is < 4) ? 
0 : 2; -+ -+ sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); -+ mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); -+ -+ const float16_t d = loadd.x * float16_t(sc); -+ const float16_t m = loadd.y * float16_t(mbyte); -+ -+ uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); -+ qs = (qs >> (b * 4)) & 0x0F0F; -+ qs = unpack8(qs)[idx & 1]; -+ -+ float16_t ret = d * float16_t(qs) - m; -+ -+ return ret; -+} -+ -+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { -+ block_q5_K block; -+}; -+ -+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 { -+ block_q5_K_packed16 block; -+}; -+ -+float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); -+ const uint idx = coordInBlock[1]; -+ -+ const uint b = (idx & 0x20) >> 5; // 0,1 -+ const uint is = (idx & 0xE0) >> 5; // 0..7 -+ -+ const uint32_t hm = 0x0101 << is; -+ -+ const f16vec2 loadd = bl.block.d; -+ -+ uint32_t sc; -+ uint32_t mbyte; -+ -+ uint32_t scidx0 = (is < 4) ? is : (is + 4); -+ uint32_t scidx1 = (is < 4) ? is : (is - 4); -+ uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ uint32_t scidxshift1 = (is < 4) ? 0 : 2; -+ uint32_t mbidx0 = is + 4; -+ uint32_t mbidx1 = (is < 4) ? is + 4 : is; -+ uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; -+ uint32_t mbidxshift0 = (is < 4) ? 0 : 4; -+ uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ uint32_t mbidxshift1 = (is < 4) ? 0 : 2; -+ -+ sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); -+ mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); -+ -+ const float16_t d = loadd.x * float16_t(sc); -+ const float16_t m = loadd.y * float16_t(mbyte); -+ -+ uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); -+ qh = qh & hm; -+ qh = unpack8(qh)[idx & 1]; -+ -+ uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); -+ qs = (qs >> (b * 4)) & 0x0F0F; -+ qs = unpack8(qs)[idx & 1]; -+ -+ float16_t ret = d * (float16_t(qs) + (qh != 0 ? 
float16_t(16) : float16_t(0))) - m; -+ -+ return ret; -+} -+ -+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { -+ block_q6_K block; -+}; -+ -+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 { -+ block_q6_K_packed16 block; -+}; -+ -+float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); -+ const uint idx = coordInBlock[1]; -+ -+ const uint b = (idx & 0x40) >> 6; // 0,1 -+ const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 -+ const uint is = (idx & 0xF0) >> 4; // 0..15 -+ -+ const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); -+ -+ uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]); -+ ql = (ql >> (b * 4)) & 0x0F0F; -+ -+ uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); -+ qh = ((qh >> qhshift) & 0x0303) << 4; -+ -+ int q = unpack8(ql | qh)[idx & 1]; -+ -+ float16_t ret = dscale * float16_t(q - 32); -+ -+ return ret; -+} -+ -+#if defined(DATA_A_IQ4_NL) -+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { -+ block_iq4_nl block; -+}; -+ -+float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ const float16_t d = bl.block.d; -+ const uint idx = coordInBlock[1]; -+ const uint iqs = idx & 0xF; -+ const uint shift = (idx & 0x10) >> 2; -+ uint32_t qs = bl.block.qs[iqs]; -+ qs >>= shift; -+ qs &= 0xF; -+ float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; -+ return ret; -+} -+#endif -+ -+#if defined(DATA_A_Q4_0) -+#define dequantFuncA dequantFuncQ4_0 -+#elif defined(DATA_A_Q4_1) -+#define dequantFuncA dequantFuncQ4_1 -+#elif defined(DATA_A_Q5_0) -+#define dequantFuncA dequantFuncQ5_0 -+#elif defined(DATA_A_Q5_1) -+#define dequantFuncA dequantFuncQ5_1 -+#elif defined(DATA_A_Q8_0) -+#define dequantFuncA dequantFuncQ8_0 -+#elif defined(DATA_A_Q2_K) -+#define dequantFuncA dequantFuncQ2_K -+#elif defined(DATA_A_Q3_K) -+#define dequantFuncA dequantFuncQ3_K -+#elif defined(DATA_A_Q4_K) -+#define dequantFuncA dequantFuncQ4_K -+#elif defined(DATA_A_Q5_K) -+#define dequantFuncA dequantFuncQ5_K -+#elif defined(DATA_A_Q6_K) -+#define dequantFuncA dequantFuncQ6_K -+#elif defined(DATA_A_IQ4_NL) -+#define dequantFuncA dequantFuncIQ4_NL -+#endif -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp -new file mode 100644 -index 00000000..8d806435 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp -@@ -0,0 +1,13 @@ -+#extension GL_EXT_control_flow_attributes : require -+#extension GL_EXT_shader_16bit_storage : require -+ -+layout (push_constant) uniform parameter -+{ -+ uint M; -+ uint K; -+ uint stride_a; -+ uint stride_b; -+ uint nel; -+} p; -+ -+#include "types.comp" -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp -new file mode 100644 -index 00000000..8de14fc0 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp -@@ -0,0 +1,32 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];}; -+layout (binding = 1) writeonly 
buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; -+ -+ init_iq4nl_shmem(); -+ -+ const uint tid = gl_LocalInvocationID.x % 64; -+ const uint il = tid/32; -+ const uint ir = tid%32; -+ const uint ib = 32*i + ir; -+ if (ib >= p.nel / 32) { -+ return; -+ } -+ -+ const uint q_idx = 8*il; -+ const uint b_idx = 1024*i + 32*ir + q_idx; -+ -+ const float d = float(data_a[ib].d); -+ -+ [[unroll]] for (uint l = 0; l < 8; ++l) { -+ data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); -+ data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp -new file mode 100644 -index 00000000..157154af ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp -@@ -0,0 +1,34 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { -+ const uint i = gl_WorkGroupID.x * 256 + wgy; -+ if (i >= p.M * p.K / QUANT_K) { -+ return; -+ } -+ -+ const uint tid = gl_LocalInvocationID.x; -+ const uint ip = tid / 32; -+ const uint il = tid - 32 * ip; -+ const uint is = 8 * ip + il / 16; -+ -+ const uint y_idx = i * QUANT_K + 128 * ip + il; -+ -+ const uint ql_idx = 32 * ip + il; -+ const uint8_t qs = data_a[i].qs[32 * ip + il]; -+ -+ FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); -+ FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); -+ data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); -+ data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); -+ data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); -+ data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4)); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp -new file mode 100644 -index 00000000..c17dd0d9 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp -@@ -0,0 +1,42 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { -+ const uint i = uint(gl_WorkGroupID.x * 256 + wgy); -+ if (i >= p.M * p.K / QUANT_K) { -+ return; -+ } -+ -+ const uint r = gl_LocalInvocationID.x / 4; -+ const uint tid = r / 2; -+ const uint is0 = r % 2; -+ const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4); -+ const uint n = tid / 4; -+ const uint j = tid - 4*n; -+ -+ const uint8_t m = uint8_t(1 << (4*n + j)); -+ const uint is = 8*n + 2*j + is0; -+ const uint shift = 2*j; -+ -+ const int8_t us = int8_t(is < 4 ? 
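For reference, the Q2_K traversal in dequant_q2_k.comp above reduces to a small scalar loop: each super-block of 256 values is split into two 128-value halves, every byte of qs holds four 2-bit quants that land 32 outputs apart, and each 16-value sub-block has a 4-bit scale and a 4-bit min packed into one scales byte. A minimal CPU-side sketch of the same indexing (not part of the patch; the helper name is made up, and d/dmin are taken as floats although the real block stores them as fp16):

```c
#include <stdint.h>

// Dequantize one Q2_K super-block of 256 values, mirroring the indexing in
// dequant_q2_k.comp: value = d * (scale & 0xF) * q - dmin * (scale >> 4).
static void dequant_q2_k_block(const uint8_t scales[16], const uint8_t qs[64],
                               float d, float dmin, float out[256]) {
    for (int half = 0; half < 2; ++half) {               // two 128-value halves
        for (int grp = 0; grp < 4; ++grp) {              // bit pairs at shifts 0,2,4,6
            for (int l = 0; l < 32; ++l) {
                const int sc = scales[8 * half + 2 * grp + l / 16]; // per 16-value sub-block
                const int q  = (qs[32 * half + l] >> (2 * grp)) & 3;
                out[128 * half + 32 * grp + l] = d * (sc & 0xF) * q - dmin * (sc >> 4);
            }
        }
    }
}
```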
(data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : -+ is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : -+ is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) : -+ (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4)); -+ const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); -+ const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); -+ -+ const uint y_idx = i * QUANT_K + 128 * n + 32 * j; -+ const uint qs_idx = 32*n; -+ -+ for (uint l = l0; l < l0 + 4; ++l) { -+ data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4))); -+ } -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp -new file mode 100644 -index 00000000..40818532 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp -@@ -0,0 +1,30 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {block_q4_0 data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; -+ -+ const uint tid = gl_LocalInvocationID.x % 64; -+ const uint il = tid/32; -+ const uint ir = tid%32; -+ const uint ib = 32*i + ir; -+ if (ib >= p.nel / 32) { -+ return; -+ } -+ -+ const uint q_idx = 8*il; -+ const uint b_idx = 1024*i + 32*ir + q_idx; -+ -+ const float d = float(data_a[ib].d); -+ -+ [[unroll]] for (uint l = 0; l < 8; ++l) { -+ data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f)); -+ data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f)); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp -new file mode 100644 -index 00000000..2f27eee6 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp -@@ -0,0 +1,32 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {block_q4_1 data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; -+ -+ const uint tid = gl_LocalInvocationID.x % 64; -+ const uint il = tid/32; -+ const uint ir = tid%32; -+ const uint ib = 32*i + ir; -+ if (ib >= p.nel / 32) { -+ return; -+ } -+ -+ const uint b_idx = 1024*i + 32*ir + 8*il; -+ -+ const float d = float(data_a[ib].d); -+ const float m = float(data_a[ib].m); -+ -+ const uint q_idx = 8*il; -+ -+ [[unroll]] for (uint l = 0; l < 8; ++l) { -+ data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m); -+ data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp -new file mode 100644 -index 00000000..987f113a ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp -@@ -0,0 +1,68 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -+ -+layout 
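The two elementwise shaders above (dequant_q4_0.comp and dequant_q4_1.comp) boil down to one multiply-add per nibble: low nibbles fill outputs 0..15 of a block, high nibbles fill 16..31, Q4_0 recenters by a fixed -8 while Q4_1 adds a stored per-block minimum. A minimal CPU sketch of that mapping (illustrative helper names; the fp16 scale and minimum are assumed to be converted to float already):

```c
#include <stdint.h>

// One Q4_0 block: 32 values packed as 16 bytes of nibbles plus a scale d.
static void dequant_q4_0_block(float d, const uint8_t qs[16], float out[32]) {
    for (int j = 0; j < 16; ++j) {
        out[j]      = d * (float)((qs[j] & 0x0F) - 8);  // low nibble, centered at 8
        out[j + 16] = d * (float)((qs[j] >> 4)   - 8);  // high nibble
    }
}

// One Q4_1 block: same packing, but y = d*q + m with a stored minimum m.
static void dequant_q4_1_block(float d, float m, const uint8_t qs[16], float out[32]) {
    for (int j = 0; j < 16; ++j) {
        out[j]      = d * (float)(qs[j] & 0x0F) + m;
        out[j + 16] = d * (float)(qs[j] >> 4)   + m;
    }
}
```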
(binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { -+ const uint ib = gl_WorkGroupID.x * 256 + wgy; -+ if (ib >= p.M * p.K / QUANT_K) { -+ return; -+ } -+ -+ const uint tid = gl_LocalInvocationID.x; -+ const uint il = tid / 8; -+ const uint ir = tid % 8; -+ const uint is = 2 * il; -+ const uint n = 4; -+ -+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); -+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); -+ -+ const uint y_idx = ib * QUANT_K + 64 * il + n * ir; -+ const uint qs_idx = 32*il + n * ir; -+ -+ uint scidx0 = (is < 4) ? is : (is + 4); -+ uint scidx1 = (is < 4) ? is : (is - 4); -+ uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ uint scidxshift1 = (is < 4) ? 0 : 2; -+ uint mbidx0 = is + 4; -+ uint mbidx1 = (is < 4) ? is + 4 : is; -+ uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; -+ uint mbidxshift0 = (is < 4) ? 0 : 4; -+ uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ uint mbidxshift1 = (is < 4) ? 0 : 2; -+ -+ uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); -+ uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); -+ -+ const FLOAT_TYPE d1 = dall * sc; -+ const FLOAT_TYPE m1 = dmin * mbyte; -+ -+ scidx0 = (is < 4) ? is + 1 : (is + 5); -+ scidx1 = (is < 4) ? is + 1 : (is - 3); -+ scidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ scidxshift1 = (is < 4) ? 0 : 2; -+ mbidx0 = is + 5; -+ mbidx1 = (is < 4) ? is + 5 : is + 1; -+ mbidxmask0 = (is < 4) ? 0xF : 0xF0; -+ mbidxshift0 = (is < 4) ? 0 : 4; -+ mbidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ mbidxshift1 = (is < 4) ? 0 : 2; -+ -+ sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); -+ mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); -+ -+ const FLOAT_TYPE d2 = dall * sc; -+ const FLOAT_TYPE m2 = dmin * mbyte; -+ -+ [[unroll]] for (uint l = 0; l < n; ++l) { -+ data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1); -+ data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2); -+ } -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp -new file mode 100644 -index 00000000..b20b8052 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp -@@ -0,0 +1,34 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {block_q5_0 data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; -+ -+ const uint tid = gl_LocalInvocationID.x % 64; -+ const uint il = tid/32; -+ const uint ir = tid%32; -+ const uint ib = 32*i + ir; -+ if (ib >= p.nel / 32) { -+ return; -+ } -+ -+ const uint b_idx = 1024*i + 32*ir + 8*il; -+ -+ const float d = float(data_a[ib].d); -+ const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; -+ -+ const uint q_idx = 8*il; -+ -+ [[unroll]] for (uint l = 0; l < 8; ++l) { -+ const uint iqs = q_idx + l; -+ const uint vui = uint(data_a[ib].qs[iqs]); -+ data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> 
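The scidx*/mbidx* mask tables used by the Q4_K and Q5_K shaders in this patch are a branchless form of the usual K-quant scale/min unpacking: the 12-byte scales field holds eight 6-bit scales and eight 6-bit mins, with the last four of each split across the low nibbles of bytes 8..11 and the top two bits of bytes 0..7. An equivalent plain-C unpacking for sub-block j in 0..7 (reference sketch only):

```c
#include <stdint.h>

// Unpack the 6-bit scale and 6-bit min of sub-block j (0..7) from the packed
// 12-byte scales array of a Q4_K/Q5_K super-block; a value is then
// reconstructed as y = (d * sc) * quant - (dmin * mn).
static void get_scale_min_k4(int j, const uint8_t q[12], uint8_t *sc, uint8_t *mn) {
    if (j < 4) {
        *sc = q[j] & 63;
        *mn = q[j + 4] & 63;
    } else {
        *sc = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *mn = (q[j + 4] >>  4) | ((q[j]     >> 6) << 4);
    }
}
```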
iqs) << 4) & 0x10)) - 16.0f)); -+ data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f)); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp -new file mode 100644 -index 00000000..dc59fe3b ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp -@@ -0,0 +1,35 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {block_q5_1 data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; -+ -+ const uint tid = gl_LocalInvocationID.x % 64; -+ const uint il = tid/32; -+ const uint ir = tid%32; -+ const uint ib = 32*i + ir; -+ if (ib >= p.nel / 32) { -+ return; -+ } -+ -+ const uint b_idx = 1024*i + 32*ir + 8*il; -+ -+ const float d = float(data_a[ib].d); -+ const float m = float(data_a[ib].m); -+ const uint qh = data_a[ib].qh; -+ -+ const uint q_idx = 8*il; -+ -+ [[unroll]] for (uint l = 0; l < 8; ++l) { -+ const uint iqs = q_idx + l; -+ const uint vui = uint(data_a[ib].qs[iqs]); -+ data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m); -+ data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp -new file mode 100644 -index 00000000..6db5403b ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp -@@ -0,0 +1,70 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { -+ const uint ib = gl_WorkGroupID.x * 256 + wgy; -+ if (ib >= p.M * p.K / QUANT_K) { -+ return; -+ } -+ -+ const uint tid = gl_LocalInvocationID.x; -+ const uint il = tid / 16; -+ const uint ir = tid % 16; -+ const uint is = 2 * il; -+ -+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); -+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); -+ -+ const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; -+ const uint qs_idx = 32*il + 2 * ir; -+ const uint qh_idx = 2 * ir; -+ -+ uint scidx0 = (is < 4) ? is : (is + 4); -+ uint scidx1 = (is < 4) ? is : (is - 4); -+ uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ uint scidxshift1 = (is < 4) ? 0 : 2; -+ uint mbidx0 = is + 4; -+ uint mbidx1 = (is < 4) ? is + 4 : is; -+ uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; -+ uint mbidxshift0 = (is < 4) ? 0 : 4; -+ uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ uint mbidxshift1 = (is < 4) ? 0 : 2; -+ -+ uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); -+ uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); -+ -+ const FLOAT_TYPE d1 = dall * sc; -+ const FLOAT_TYPE m1 = dmin * mbyte; -+ -+ scidx0 = (is < 4) ? is + 1 : (is + 5); -+ scidx1 = (is < 4) ? is + 1 : (is - 3); -+ scidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ scidxshift1 = (is < 4) ? 0 : 2; -+ mbidx0 = is + 5; -+ mbidx1 = (is < 4) ? 
is + 5 : is + 1; -+ mbidxmask0 = (is < 4) ? 0xF : 0xF0; -+ mbidxshift0 = (is < 4) ? 0 : 4; -+ mbidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ mbidxshift1 = (is < 4) ? 0 : 2; -+ -+ sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); -+ mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); -+ -+ const FLOAT_TYPE d2 = dall * sc; -+ const FLOAT_TYPE m2 = dmin * mbyte; -+ -+ const uint8_t hm1 = uint8_t(1 << (2 * il )); -+ const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); -+ data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); -+ data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); -+ data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); -+ data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp -new file mode 100644 -index 00000000..0b913175 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp -@@ -0,0 +1,33 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { -+ const uint i = gl_WorkGroupID.x * 256 + wgy; -+ if (i >= p.M * p.K / QUANT_K) { -+ return; -+ } -+ const uint tid = gl_LocalInvocationID.x; -+ const uint ip = tid / 32; -+ const uint il = tid - 32 * ip; -+ const uint is = 8 * ip + il / 16; -+ -+ const uint y_idx = i * QUANT_K + 128 * ip + il; -+ -+ const uint ql_idx = 64 * ip + il; -+ const uint8_t qh = data_a[i].qh[32 * ip + il]; -+ -+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); -+ -+ data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); -+ data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); -+ data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); -+ data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp -new file mode 100644 -index 00000000..bd1344a8 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp -@@ -0,0 +1,31 @@ -+#version 450 -+ -+#include "dequant_head.comp" -+ -+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {block_q8_0 data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; -+ -+void main() { -+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; -+ -+ const uint tid = gl_LocalInvocationID.x % 64; -+ const 
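dequant_q6_k.comp above reassembles each 6-bit quant from a low nibble in ql[] plus two high bits in qh[], subtracts 32, and multiplies by the signed 8-bit scale of its 16-value sub-block. The same indexing as a CPU loop (illustrative helper; d is fp16 in the real block and taken as float here):

```c
#include <stdint.h>

// Dequantize one Q6_K super-block of 256 values with the same index math as
// dequant_q6_k.comp: two 128-value halves, four outputs per (ql, qh) byte pair.
static void dequant_q6_k_block(const uint8_t ql[128], const uint8_t qh[64],
                               const int8_t scales[16], float d, float out[256]) {
    for (int half = 0; half < 2; ++half) {
        for (int l = 0; l < 32; ++l) {
            const int is = 8 * half + l / 16;        // base scale index for this column
            const uint8_t h = qh[32 * half + l];     // four 2-bit high parts per byte
            const int q0 = ((ql[64 * half + l     ] & 0xF) | (((h >> 0) & 3) << 4)) - 32;
            const int q1 = ((ql[64 * half + l + 32] & 0xF) | (((h >> 2) & 3) << 4)) - 32;
            const int q2 = ((ql[64 * half + l     ] >>  4) | (((h >> 4) & 3) << 4)) - 32;
            const int q3 = ((ql[64 * half + l + 32] >>  4) | (((h >> 6) & 3) << 4)) - 32;
            out[128 * half + l +  0] = d * scales[is + 0] * q0;
            out[128 * half + l + 32] = d * scales[is + 2] * q1;
            out[128 * half + l + 64] = d * scales[is + 4] * q2;
            out[128 * half + l + 96] = d * scales[is + 6] * q3;
        }
    }
}
```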
uint il = tid/32; -+ const uint ir = tid%32; -+ const uint ib = 32*i + ir; -+ if (ib >= p.nel / 32) { -+ return; -+ } -+ -+ const uint b_idx = 1024*i + 32*ir + 16*il; -+ -+ const float d = float(data_a[ib].d); -+ -+ const uint q_idx = 16*il; -+ -+ [[unroll]] for (uint l = 0; l < 16; l += 2) { -+ data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]); -+ data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp -new file mode 100644 -index 00000000..4e68742b ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp -@@ -0,0 +1,34 @@ -+#version 450 -+ -+#extension GL_EXT_shader_16bit_storage : require -+#extension GL_EXT_control_flow_attributes : enable -+ -+layout (push_constant) uniform parameter -+{ -+ uint ncols; -+ uint rows_per_channel; -+ uint n_past; -+} p; -+ -+#include "types.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const uint col = gl_GlobalInvocationID.y; -+ const uint row = gl_GlobalInvocationID.x; -+ -+ if (col >= p.ncols) { -+ return; -+ } -+ -+ const uint i = row*p.ncols + col; -+ if (col > p.n_past + row % p.rows_per_channel) { -+ data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000)); -+ } else { -+ data_d[i] = D_TYPE(data_a[i]); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp -new file mode 100644 -index 00000000..9fb69c6c ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp -@@ -0,0 +1,27 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_binary_head.comp" -+ -+const uint num_threads = 256; -+ -+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ uint idx = get_idx(); -+ -+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation -+ const uint num_iter = 2; -+ -+ [[unroll]] for (uint i = 0; i < num_iter; ++i) { -+ if (idx >= p.ne) { -+ continue; -+ } -+ uint i00, i01, i02, i03; -+ get_indices(idx, i00, i01, i02, i03); -+ -+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); -+ -+ idx += num_threads; -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp -new file mode 100644 -index 00000000..c5be8131 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp -@@ -0,0 +1,289 @@ -+#version 450 -+ -+#extension GL_EXT_control_flow_attributes : enable -+#extension GL_EXT_shader_16bit_storage : require -+ -+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require -+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -+ -+#extension GL_KHR_memory_scope_semantics : enable -+#extension GL_KHR_cooperative_matrix : enable -+#extension GL_NV_cooperative_matrix2 : enable -+#extension GL_EXT_buffer_reference : enable -+#extension 
GL_KHR_shader_subgroup_ballot : enable -+#extension GL_KHR_shader_subgroup_vote : enable -+#extension GL_EXT_null_initializer : enable -+ -+#include "types.comp" -+#include "dequant_funcs_cm2.comp" -+ -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+layout (constant_id = 1) const uint32_t Br = 32; -+layout (constant_id = 2) const uint32_t Bc = 32; -+layout (constant_id = 3) const uint32_t D = 32; -+layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; -+ -+layout (push_constant) uniform parameter { -+ uint32_t N; -+ uint32_t KV; -+ -+ uint32_t ne1; -+ uint32_t ne2; -+ uint32_t ne3; -+ -+ uint32_t neq2; -+ uint32_t neq3; -+ uint32_t nek2; -+ uint32_t nek3; -+ uint32_t nev2; -+ uint32_t nev3; -+ uint32_t nem1; -+ -+ uint32_t nb02; -+ uint32_t nb03; -+ uint32_t nb12; -+ uint32_t nb13; -+ uint32_t nb22; -+ uint32_t nb23; -+ uint32_t nb31; -+ -+ float scale; -+ float max_bias; -+ float logit_softcap; -+ -+ uint32_t mask; -+ uint32_t n_head_log2; -+ float m0; -+ float m1; -+} p; -+ -+layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; -+layout (binding = 1) readonly buffer K {uint8_t data_k[];}; -+layout (binding = 2) readonly buffer V {uint8_t data_v[];}; -+layout (binding = 3) readonly buffer M {uint8_t data_m[];}; -+layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; -+ -+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) -+ -+ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { -+ return max(x, y); -+} -+ -+ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { -+ return x; -+} -+ -+// Replace matrix elements >= numRows or numCols with 'replace' -+ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) { -+ if (row >= numRows || col >= numCols) { -+ return replace; -+ } -+ return elem; -+} -+ -+ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) -+{ -+ return exp(elem); -+} -+ -+ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) -+{ -+ return max(elem0, elem1); -+} -+ -+#if defined(BLOCK_SIZE) -+#define DECODEFUNC , DEQUANTFUNC -+#else -+#define DECODEFUNC -+#endif -+ -+void main() { -+#if defined(DATA_A_IQ4_NL) -+ init_iq4nl_shmem(); -+#endif -+ -+ const uint32_t N = p.N; -+ const uint32_t KV = p.KV; -+ -+ const uint32_t Tr = CEIL_DIV(N, Br); -+ const uint32_t Tc = CEIL_DIV(KV, Bc); -+ -+ const uint32_t i = gl_WorkGroupID.x; -+ -+ const uint32_t iq2 = gl_WorkGroupID.y; -+ const uint32_t iq3 = gl_WorkGroupID.z; -+ -+ // broadcast factors -+ const uint32_t rk2 = p.neq2/p.nek2; -+ const uint32_t rk3 = p.neq3/p.nek3; -+ -+ const uint32_t rv2 = p.neq2/p.nev2; -+ const uint32_t rv3 = p.neq3/p.nev3; -+ -+ // k indices -+ const uint32_t ik3 = iq3 / rk3; -+ const uint32_t ik2 = iq2 / rk2; -+ -+ // v indices -+ const uint32_t iv3 = iq3 / rv3; -+ const uint32_t iv2 = iq2 / rv2; -+ -+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); -+ tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); -+ tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp); -+ -+ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); -+ -+#if defined(BLOCK_SIZE) -+ tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); -+ tensorLayoutV = 
setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); -+#endif -+ -+ tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); -+ tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); -+ tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); -+ -+ coopmat Q; -+ coopmat Qf16; -+ -+ uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; -+ coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); -+ -+ Qf16 = coopmat(Q); -+ Qf16 *= float16_t(p.scale); -+ -+ coopmat O = coopmat(0); -+ -+ coopmat L, M; -+ -+ L = coopmat(0); -+ M = coopmat(-1.0/0.0); -+ -+ ACC_TYPE slope = ACC_TYPE(1.0); -+ -+ // ALiBi -+ if (p.max_bias > 0.0f) { -+ const uint32_t h = iq2; -+ -+ const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); -+ const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); -+ -+ slope = pow(base, ACC_TYPE(exph)); -+ } -+ -+ [[dont_unroll]] -+ for (uint32_t j = 0; j < Tc; ++j) { -+ -+ coopmat S = coopmat(0); -+ -+ coopmat K_T; -+ -+ uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; -+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); -+ S = coopMatMulAdd(Qf16, K_T, S); -+ -+ if (p.logit_softcap != 0.0f) { -+ [[unroll]] -+ for (int k = 0; k < S.length(); ++k) { -+ S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); -+ } -+ } -+ -+ if (p.mask != 0) { -+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); -+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); -+ -+ coopmat mv; -+ -+ coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); -+ -+ S += slope*coopmat(mv); -+ } -+ -+ // Clear padding elements to -inf, so they don't contribute to rowmax -+ if (Clamp != 0 && -+ ((j + 1) * Bc > KV || -+ (i + 1) * Br > N)) { -+ -+ uint R = ((i + 1) * Br > N) ? (N % Br) : Br; -+ uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; -+ -+ coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); -+ } -+ -+ coopmat rowmax, P, rowsum, eM; -+ -+ coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); -+ -+ coopmat Mold = M; -+ -+ // M = max(rowmax, Mold) -+ // P = e^(S - M) -+ // eM = e^(Mold - M) -+ coopMatPerElementNV(M, rowmax, Max, Mold); -+ coopMatPerElementNV(P, S - M, Exp); -+ coopMatPerElementNV(eM, Mold - M, Exp); -+ -+ // Clear padding elements to 0, so they don't contribute to rowsum -+ if (Clamp != 0 && -+ ((j + 1) * Bc > KV || -+ (i + 1) * Br > N)) { -+ -+ uint R = ((i + 1) * Br > N) ? (N % Br) : Br; -+ uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; -+ -+ coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); -+ } -+ -+ coopmat P_A = coopmat(P); -+ -+ // compute rowsum by multiplying by matrix of all ones. 
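The M/P/eM updates above and the L and O updates that follow are the streaming-softmax recurrence used by flash attention: keep the running row maximum M and denominator L, and whenever a new KV tile raises the maximum, rescale the previously accumulated output by exp(M_old - M_new) so the full softmax never has to be materialized. A scalar sketch of the same recurrence for a single query row (illustrative only; no masking, ALiBi or logit softcap, and one key at a time instead of Bc-wide tiles):

```c
#include <math.h>
#include <stddef.h>

// One query row of attention with the online-softmax recurrence:
//   m_new = max(m, s);  l = e^(m-m_new)*l + e^(s-m_new);
//   o = e^(m-m_new)*o + e^(s-m_new)*v
// followed by a final division by l (the Ldiag step in the shader).
static void attention_row(const float *q, const float *k, const float *v,
                          size_t n_kv, size_t head_dim, float scale, float *out) {
    float m = -INFINITY, l = 0.0f;
    for (size_t d = 0; d < head_dim; ++d) out[d] = 0.0f;

    for (size_t j = 0; j < n_kv; ++j) {
        float s = 0.0f;                                   // s = scale * dot(q, k_j)
        for (size_t d = 0; d < head_dim; ++d) s += q[d] * k[j * head_dim + d];
        s *= scale;

        const float m_new = s > m ? s : m;
        const float eM = expf(m - m_new);                 // rescales the old l and out
        const float p  = expf(s - m_new);                 // weight of the new key

        l = eM * l + p;
        for (size_t d = 0; d < head_dim; ++d)
            out[d] = eM * out[d] + p * v[j * head_dim + d];
        m = m_new;
    }
    for (size_t d = 0; d < head_dim; ++d) out[d] /= l;    // final 1/L normalization
}
```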
-+ coopmat One = coopmat(1.0); -+ -+ rowsum = coopmat(0.0); -+ rowsum = coopMatMulAdd(P_A, One, rowsum); -+ -+ coopmat V; -+ uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; -+ coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); -+ -+ L = eM*L + rowsum; -+ -+ // This is the "diagonal" matrix in the paper, but since we do componentwise -+ // multiply rather than matrix multiply it has the diagonal element smeared -+ // across the row -+ coopmat eMdiag; -+ -+ // resize eM by using smear/reduce -+ coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); -+ -+ O = eMdiag * O; -+ -+ O = coopMatMulAdd(P_A, V, O); -+ } -+ -+ coopmat Ldiag; -+ -+ // resize L by using smear/reduce -+ coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); -+ -+ [[unroll]] -+ for (int k = 0; k < Ldiag.length(); ++k) { -+ Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; -+ } -+ -+ O = Ldiag*O; -+ -+ tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); -+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); -+ -+ // permute dimensions -+ tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); -+ uint32_t o_offset = iq3*p.ne2*p.ne1; -+ -+ coopmat O_D = coopmat(O); -+ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp -new file mode 100644 -index 00000000..4cc7a68c ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp -@@ -0,0 +1,25 @@ -+#version 450 -+ -+#include "generic_head.comp" -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const float GELU_COEF_A = 0.044715f; -+ const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; -+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+ -+ if (i >= p.KX) { -+ return; -+ } -+ -+ const float xi = float(data_a[i]); -+ const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); -+ data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1))); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp -new file mode 100644 -index 00000000..e6e6fcfd ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp -@@ -0,0 +1,23 @@ -+#version 450 -+ -+#include "generic_head.comp" -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const float GELU_QUICK_COEF = -1.702f; -+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+ -+ if (i >= p.KX) { -+ return; -+ } -+ -+ const float x = float(data_a[i]); -+ data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x)))); -+} -diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp -new file mode 100644 -index 00000000..062e2a4c ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp -@@ -0,0 +1,64 @@ -+#extension GL_EXT_shader_16bit_storage : require -+#extension GL_EXT_control_flow_attributes : require -+ -+layout (push_constant) uniform parameter -+{ -+ uint ne; -+ uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; -+ uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; -+ uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; -+ uint misalign_offsets; -+ float param1; float param2; int param3; -+} p; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; -+ -+// true if src0/src1 are the same shape and the indices can be reused without additional modulus -+layout(constant_id = 0) const bool norepeat = false; -+ -+uint get_idx() { -+ return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+} -+ -+uint get_aoffset() { return p.misalign_offsets >> 16; } -+uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; } -+uint get_doffset() { return p.misalign_offsets & 0xFF; } -+ -+// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1 -+uint fastmod(uint a, uint b) { -+ if ((b & (b-1)) == 0) { -+ return a & (b-1); -+ } -+ return a % b; -+} -+ -+uint fastdiv(uint a, uint b) { -+ return (a < b) ? 0 : (a / b); -+} -+ -+void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) { -+ i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00)); -+ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; -+ i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00)); -+ const uint i02_offset = i02*p.ne01*p.ne00; -+ i01 = (idx - i03_offset - i02_offset) / p.ne00; -+ i00 = idx - i03_offset - i02_offset - i01*p.ne00; -+} -+ -+uint src0_idx(uint i00, uint i01, uint i02, uint i03) { -+ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; -+} -+ -+uint src1_idx(uint i00, uint i01, uint i02, uint i03) { -+ if (norepeat) { -+ return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10; -+ } else { -+ return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10; -+ } -+} -+ -+uint dst_idx(uint i00, uint i01, uint i02, uint i03) { -+ return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20; -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp -new file mode 100644 -index 00000000..66e46ae6 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp -@@ -0,0 +1,9 @@ -+#extension GL_EXT_shader_16bit_storage : require -+ -+layout (push_constant) uniform parameter -+{ -+ uint KX; -+ uint KY; -+ float param1; -+ float param2; -+} p; -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp -new file mode 100644 -index 00000000..68d1bc9f ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp -@@ -0,0 +1,56 @@ -+#extension 
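generic_binary_head.comp above converts the flat invocation index into four coordinates (innermost extent first) and lets the second operand broadcast by wrapping its coordinates with fastmod. The same logic in plain C (illustrative names; ne00..ne02 are the three innermost extents, as in the push constants):

```c
#include <stdint.h>

// Split a flat element index into 4D coordinates, innermost dimension first,
// mirroring get_indices() in generic_binary_head.comp.
static void get_indices_4d(uint32_t idx, uint32_t ne00, uint32_t ne01, uint32_t ne02,
                           uint32_t *i00, uint32_t *i01, uint32_t *i02, uint32_t *i03) {
    *i03 = idx / (ne02 * ne01 * ne00);
    idx -= *i03 * ne02 * ne01 * ne00;
    *i02 = idx / (ne01 * ne00);
    idx -= *i02 * ne01 * ne00;
    *i01 = idx / ne00;
    *i00 = idx - *i01 * ne00;
}

// fastmod(): a % b collapses to a & (b - 1) when b is a power of two, which is
// the common case for tensor extents; used when broadcasting the src1 indices.
static uint32_t fastmod_u32(uint32_t a, uint32_t b) {
    return (b & (b - 1)) == 0 ? (a & (b - 1)) : (a % b);
}
```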
GL_EXT_shader_16bit_storage : require -+#extension GL_EXT_control_flow_attributes : require -+ -+layout (push_constant) uniform parameter -+{ -+ uint ne; -+ uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; -+ uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; -+ uint misalign_offsets; -+ float param1; float param2; -+ -+ uint ne0_012mp; uint ne0_012L; -+ uint ne0_01mp; uint ne0_01L; -+ uint ne0_0mp; uint ne0_0L; -+ uint ne1_012mp; uint ne1_012L; -+ uint ne1_01mp; uint ne1_01L; -+ uint ne1_0mp; uint ne1_0L; -+} p; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+uint get_idx() { -+ return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+} -+ -+uint get_aoffset() { return p.misalign_offsets >> 16; } -+uint get_doffset() { return p.misalign_offsets & 0xFFFF; } -+ -+// see init_fastdiv_values in ggml-vulkan.cpp -+uint fastdiv(uint n, uint mp, uint L) { -+ uint msbs, lsbs; -+ // msbs = mulhi(n, mp) -+ umulExtended(n, mp, msbs, lsbs); -+ return (msbs + n) >> L; -+} -+ -+uint src0_idx(uint idx) { -+ const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); -+ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; -+ const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); -+ const uint i02_offset = i02*p.ne01*p.ne00; -+ const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); -+ const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; -+ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; -+} -+ -+uint dst_idx(uint idx) { -+ const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); -+ const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; -+ const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); -+ const uint i12_offset = i12*p.ne11*p.ne10; -+ const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); -+ const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; -+ return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp -new file mode 100644 -index 00000000..e877ed77 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp -@@ -0,0 +1,28 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_binary_head.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ const uint i00 = gl_GlobalInvocationID.x; -+ const uint i10 = gl_GlobalInvocationID.y; -+ const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; -+ const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; -+ -+ if (i00 >= p.ne00) { -+ return; -+ } -+ -+ const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; -+ -+ const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; -+ const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; -+ -+#ifndef OPTIMIZATION_ERROR_WORKAROUND -+ data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]); -+#else -+ data_d[d_offset + i00] = data_a[a_offset + i00]; -+#endif -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp -new file mode 100644 -index 00000000..1426fde6 ---- /dev/null -+++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp -@@ -0,0 +1,39 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_binary_head.comp" -+#include "dequant_funcs.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ const uint i00 = (gl_GlobalInvocationID.x)*2; -+ const uint i10 = gl_GlobalInvocationID.y; -+ const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; -+ const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; -+ -+#if defined(DATA_A_IQ4_NL) -+ init_iq4nl_shmem(); -+#endif -+ -+ if (i00 >= p.ne00) { -+ return; -+ } -+ -+ const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; -+ -+ const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; -+ const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; -+ -+ const uint ib = a_offset + i00/QUANT_K; // block index -+ const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index -+ const uint iybs = i00 - i00%QUANT_K; // dst block start index -+ const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; -+ -+ vec2 v = dequantize(ib, iqs, 0); -+ const vec2 dm = get_dm(ib, 0); -+ v = v * dm.x + dm.y; -+ -+ data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); -+ data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp -new file mode 100644 -index 00000000..b6a0d564 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp -@@ -0,0 +1,66 @@ -+#version 450 -+ -+#include "generic_head.comp" -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+#define BLOCK_SIZE 512 -+ -+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+shared float tmp[BLOCK_SIZE]; -+ -+void main() { -+ const uint group_size = p.KX; -+ const float eps = p.param1; -+ -+ const uint tid = gl_LocalInvocationID.x; -+ const uint start = gl_WorkGroupID.x * group_size + tid; -+ const uint end = (gl_WorkGroupID.x + 1) * group_size; -+ -+ tmp[tid] = 0.0f; -+ -+ // Calculate mean -+ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { -+ tmp[tid] += float(data_a[col]); -+ } -+ -+ // tmp up partial tmps and write back result -+ barrier(); -+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { -+ if (tid < s) { -+ tmp[tid] += tmp[tid + s]; -+ } -+ barrier(); -+ } -+ -+ const float mean = tmp[0] / group_size; -+ barrier(); -+ tmp[tid] = 0.0f; -+ -+ // Calculate variance -+ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { -+ const float xi = float(data_a[col]) - mean; -+ data_d[col] = D_TYPE(xi); -+ tmp[tid] += xi * xi; -+ } -+ -+ // sum up partial sums and write back result -+ barrier(); -+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { -+ if (tid < s) { -+ tmp[tid] += tmp[tid + s]; -+ } -+ barrier(); -+ } -+ -+ const float variance = tmp[0] / group_size; -+ const float scale = inversesqrt(variance + eps); -+ -+ [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { -+ data_d[col] *= D_TYPE(scale); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp -new file mode 100644 -index 00000000..122b1e93 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp -@@ -0,0 +1,87 @@ 
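group_norm.comp above is a two-pass shared-memory reduction per group: first the mean, then the variance of the centered values, then a rescale by inversesqrt(variance + eps), with the centered values written out between the passes. A CPU reference of the same computation (illustrative helper name):

```c
#include <math.h>
#include <stddef.h>

// Normalize one group of `group_size` consecutive values the way
// group_norm.comp does: y = (x - mean) / sqrt(var + eps).
static void group_norm_ref(const float *x, float *y, size_t group_size, float eps) {
    float sum = 0.0f;
    for (size_t i = 0; i < group_size; ++i) sum += x[i];
    const float mean = sum / (float)group_size;

    float ssq = 0.0f;
    for (size_t i = 0; i < group_size; ++i) {
        const float xi = x[i] - mean;
        y[i] = xi;                                  // the shader stores the centered value first...
        ssq += xi * xi;
    }
    const float scale = 1.0f / sqrtf(ssq / (float)group_size + eps);
    for (size_t i = 0; i < group_size; ++i) y[i] *= scale;  // ...then rescales it in place
}
```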
-+#version 450 -+ -+#extension GL_EXT_shader_16bit_storage : require -+#extension GL_EXT_spirv_intrinsics: enable -+#extension GL_EXT_control_flow_attributes : require -+ -+#if RTE16 -+spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -+#endif -+ -+layout (push_constant) uniform parameter -+{ -+ uint batch_offset; uint offset_delta; -+ uint IC; -+ uint IW; uint IH; -+ uint OW; uint OH; -+ uint KW; uint KH; -+ uint pelements; -+ uint CHW; -+ int s0; int s1; -+ int p0; int p1; -+ int d0; int d1; -+} p; -+ -+#include "types.comp" -+ -+layout(constant_id = 0) const uint BLOCK_SIZE = 32; -+ -+const uint NUM_ITER = 512 / BLOCK_SIZE; -+ -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const uint gidx = gl_GlobalInvocationID.x; -+ -+ const uint oh = gl_GlobalInvocationID.y; -+ const uint batch = gl_GlobalInvocationID.z / p.IC; -+ const uint ic = gl_GlobalInvocationID.z % p.IC; -+ -+ A_TYPE values[NUM_ITER]; -+ uint offset_dst[NUM_ITER]; -+ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { -+ values[idx] = A_TYPE(0); -+ } -+ -+ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { -+ -+ const uint i = gidx * NUM_ITER + idx; -+ -+ const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); -+ const uint kx = i / ksize; -+ const uint kd = kx * ksize; -+ const uint ky = (i - kd) / p.OW; -+ const uint ix = i % p.OW; -+ -+ const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; -+ const uint iih = oh * p.s1 + ky * p.d1 - p.p1; -+ -+ offset_dst[idx] = -+ ((batch * p.OH + oh) * p.OW + ix) * p.CHW + -+ (ic * (p.KW * p.KH) + ky * p.KW + kx); -+ -+ if (i >= p.pelements) { -+ continue; -+ } -+ -+ if (iih < p.IH && iiw < p.IW) { -+ const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; -+ values[idx] = data_a[offset_src + iih * p.IW + iiw]; -+ } -+ } -+ -+ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { -+ -+ const uint i = gidx * NUM_ITER + idx; -+ -+ if (i >= p.pelements) { -+ continue; -+ } -+ -+ data_d[offset_dst[idx]] = D_TYPE(values[idx]); -+ } -+ -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp -new file mode 100644 -index 00000000..d90a99ae ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp -@@ -0,0 +1,22 @@ -+#version 450 -+ -+#include "generic_head.comp" -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+ -+ if (i >= p.KX) { -+ return; -+ } -+ -+ const float val = float(data_a[i]); -+ data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp -new file mode 100644 -index 00000000..43de19df ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp -@@ -0,0 +1,27 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_binary_head.comp" -+ -+const uint num_threads = 256; -+ -+layout(local_size_x = num_threads, local_size_y = 
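im2col.comp above copies each KW x KH input patch (zero-padded where iiw/iih fall outside the image) into one row of the destination so that the following convolution becomes a plain matrix multiply. A single-image CPU sketch of the same layout, assuming a densely packed IC x IH x IW source; the shader's thread mapping differs, but the per-element coordinate and offset formulas are the ones above:

```c
#include <stddef.h>

// dst is laid out as [oh][ow][ic][kh][kw] for one image, matching the
// ((oh*OW + ow)*CHW + ic*KW*KH + kh*KW + kw) offset used by the shader.
static void im2col_2d(const float *src, float *dst,
                      int IW, int IH, int OW, int OH, int KW, int KH, int IC,
                      int s0, int s1, int p0, int p1, int d0, int d1) {
    const int CHW = IC * KW * KH;
    for (int ic = 0; ic < IC; ++ic)
    for (int oh = 0; oh < OH; ++oh)
    for (int ow = 0; ow < OW; ++ow)
    for (int kh = 0; kh < KH; ++kh)
    for (int kw = 0; kw < KW; ++kw) {
        const int iw = ow * s0 + kw * d0 - p0;      // the shader's iiw
        const int ih = oh * s1 + kh * d1 - p1;      // the shader's iih
        const size_t di = (size_t)(oh * OW + ow) * CHW + ic * (KW * KH) + kh * KW + kw;
        dst[di] = (iw >= 0 && iw < IW && ih >= 0 && ih < IH)
                      ? src[(size_t)ic * IH * IW + ih * IW + iw]
                      : 0.0f;                       // zero padding
    }
}
```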
1, local_size_z = 1) in; -+ -+void main() { -+ uint idx = get_idx(); -+ -+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation -+ const uint num_iter = 2; -+ -+ [[unroll]] for (uint i = 0; i < num_iter; ++i) { -+ if (idx >= p.ne) { -+ continue; -+ } -+ uint i00, i01, i02, i03; -+ get_indices(idx, i00, i01, i02, i03); -+ -+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); -+ -+ idx += num_threads; -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp -new file mode 100644 -index 00000000..4c64fd47 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp -@@ -0,0 +1,48 @@ -+#version 450 -+ -+#extension GL_EXT_control_flow_attributes : enable -+ -+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {float data_a[];}; -+layout (binding = 0) readonly buffer A4 {vec4 data_a4[];}; -+layout (binding = 1) writeonly buffer D {float data_d[];}; -+layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];}; -+ -+layout (push_constant) uniform parameter { -+ uint ne; -+ uint k_num; -+} p; -+ -+void main() { -+ // Each invocation handles four consecutive components -+ const uint idx = gl_GlobalInvocationID.x * 4; -+ -+ if (idx >= p.ne) { -+ return; -+ } -+ -+ // Check if all four components are in bounds and aligned, -+ // then use vector loads -+ if (idx + 3 < p.ne && (p.ne % 4) == 0) { -+ vec4 result = vec4(0.0f); -+ -+ [[unroll]] for (uint i = 0; i < p.k_num; i++) { -+ result += data_a4[(i * p.ne + idx) / 4]; -+ } -+ -+ data_d4[idx / 4] = result; -+ } else { -+ [[unroll]] for (uint j = 0; j < 4; ++j) { -+ if (idx + j < p.ne) { -+ float result = 0.0f; -+ -+ [[unroll]] for (uint i = 0; i < p.k_num; i++) { -+ result += data_a[i * p.ne + idx + j]; -+ } -+ -+ data_d[idx + j] = result; -+ } -+ } -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp -new file mode 100644 -index 00000000..24875cdc ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp -@@ -0,0 +1,152 @@ -+#version 450 -+ -+#ifdef FLOAT16 -+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -+#endif -+#extension GL_EXT_shader_explicit_arithmetic_types : require -+ -+#include "mul_mat_vec_base.comp" -+ -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+#if !defined(DATA_A_F32) && !defined(DATA_A_F16) -+#define K_PER_ITER 8 -+#else -+#define K_PER_ITER 2 -+#endif -+ -+ -+uint a_offset, b_offset, d_offset, y_offset; -+ -+void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) -+{ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; -+ const uint iqs = (col%QUANT_K)/QUANT_R; // quant index -+ const uint iybs = col - col%QUANT_K; // y block start index -+ -+#if K_PER_ITER == 8 -+#if QUANT_R == 2 -+ const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; -+ const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]; -+ const vec4 bv0 = vec4(bv02.x, bv13.x, 
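mul_mat_split_k_reduce.comp above finishes a matrix multiply that was split along the K dimension: the k_num partial results sit back to back in one buffer, ne floats each, and are summed element-wise into the output (with a vec4 fast path when ne is a multiple of four). The scalar form of that reduction:

```c
#include <stddef.h>

// Sum k_num partial matmul results of ne elements each into out[],
// the scalar equivalent of mul_mat_split_k_reduce.comp.
static void split_k_reduce(const float *partials, float *out, size_t ne, size_t k_num) {
    for (size_t i = 0; i < ne; ++i) {
        float acc = 0.0f;
        for (size_t k = 0; k < k_num; ++k)
            acc += partials[k * ne + i];            // stride of ne between partial buffers
        out[i] = acc;
    }
}
```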
bv02.y, bv13.y); -+ const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); -+#else -+ const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); -+ const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); -+#endif -+#else -+ // Check if the second of the pair of elements is OOB, and don't fetch B or -+ // accumulate it. We still fetch a pair of elements for A, which is fine for -+ // quantized formats since they'll be within the same block. We should -+ // probably skip fetching the second element for F16/F32, but as of now we -+ // still do. -+ const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); -+ -+ FLOAT_TYPE b0 = 0, b1 = 0; -+ b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); -+ if (!OOB) { -+ b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); -+ } -+#endif -+ uint ibi = first_row*p.ncols; -+ [[unroll]] for (uint n = 0; n < num_rows; ++n) { -+ const uint ib = (ibi + col)/QUANT_K; // block index -+ ibi += p.ncols; -+ -+#if K_PER_ITER == 8 -+ vec4 v = dequantize4(ib, iqs, a_offset); -+ vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); -+ -+ const vec2 dm = get_dm(ib, a_offset); -+ if (dm.y != 0) { // quant has min component -+ v = v * dm.x + dm.y; -+ v2 = v2 * dm.x + dm.y; -+ } -+ -+ // matrix multiplication -+ FLOAT_TYPE rowtmp = dot(bv0, v); -+ rowtmp += dot(bv1, v2); -+ -+ if (dm.y == 0) -+ rowtmp *= dm.x; -+ -+ temp[j][n] += rowtmp; -+#else -+ const vec2 v = dequantize(ib, iqs, a_offset); -+ -+ // matrix multiplication -+ temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); -+ if (!OOB) { -+ temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); -+ } -+#endif -+ } -+ } -+} -+ -+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { -+ const uint tid = gl_LocalInvocationID.x; -+ -+ get_offsets(a_offset, b_offset, d_offset); -+ a_offset /= QUANT_K; -+ -+ y_offset = QUANT_R == 1 ? 
1 : QUANT_K/2; -+ -+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { -+ temp[j][i] = FLOAT_TYPE(0); -+ } -+ } -+ -+ uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); -+ if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { -+ num_iters++; -+ } -+ int unroll_count = 4; -+ uint unrolled_iters = num_iters & ~(unroll_count - 1); -+ -+ uint i = 0; -+ while (i < unrolled_iters) { -+ // Manually partially unroll the loop -+ [[unroll]] for (uint k = 0; k < unroll_count; ++k) { -+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); -+ i++; -+ } -+ } -+ unroll_count = 2; -+ unrolled_iters = num_iters & ~(unroll_count - 1); -+ while (i < unrolled_iters) { -+ // Manually partially unroll the loop -+ [[unroll]] for (uint k = 0; k < unroll_count; ++k) { -+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); -+ i++; -+ } -+ } -+ while (i < num_iters) { -+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); -+ i++; -+ } -+ -+ reduce_result(temp, d_offset, first_row, num_rows, tid); -+} -+ -+void main() { -+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -+ -+#if defined(DATA_A_IQ4_NL) -+ init_iq4nl_shmem(); -+#endif -+ -+ // do NUM_ROWS at a time, unless there aren't enough remaining rows -+ if (first_row + NUM_ROWS <= p.stride_d) { -+ compute_outputs(first_row, NUM_ROWS); -+ } else { -+ if (first_row >= p.stride_d) { -+ return; -+ } -+ compute_outputs(first_row, p.stride_d - first_row); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp -new file mode 100644 -index 00000000..903753c7 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp -@@ -0,0 +1,118 @@ -+#extension GL_EXT_control_flow_attributes : enable -+#extension GL_EXT_shader_16bit_storage : require -+#extension GL_EXT_shader_8bit_storage : require -+ -+#ifdef MUL_MAT_ID -+#define EXPERT_COUNT 8 -+#endif -+ -+#include "types.comp" -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -+layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; -+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; -+ -+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; -+#ifdef MUL_MAT_ID -+layout (binding = 3) readonly buffer IDS {int data_ids[];}; -+#endif -+ -+#include "dequant_funcs.comp" -+ -+layout (push_constant) uniform parameter -+{ -+ uint ncols; -+ uint stride_a; -+ uint stride_b; -+ uint stride_d; -+ -+ uint batch_stride_a; -+ uint batch_stride_b; -+ uint batch_stride_d; -+ -+#ifdef MUL_MAT_ID -+ uint nei0; -+ uint ne11; -+#else -+ uint ne02; -+ uint ne12; -+ uint broadcast2; -+ uint broadcast3; -+#endif -+} p; -+ -+void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { -+#ifdef MUL_MAT_ID -+ const uint expert_idx = gl_GlobalInvocationID.y; -+#else -+ const uint batch_idx = gl_GlobalInvocationID.y; -+#endif -+ -+#ifndef MUL_MAT_ID -+ uint batch_idx_a = 0; -+ if (batch_idx != 0) { -+ const uint i13 = batch_idx / p.ne12; -+ const uint i12 = batch_idx % p.ne12; -+ -+ const uint i03 = i13 / p.broadcast3; -+ const uint i02 = i12 / p.broadcast2; -+ -+ batch_idx_a = i03 * p.ne02 + i02; -+ } -+#else -+ const uint expert_id = data_ids[expert_idx]; -+#endif -+ -+ a_offset = -+#ifdef MUL_MAT_ID -+ 
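compute_outputs() above hands each workgroup a slab of NUM_ROWS rows and lets every thread accumulate a strided slice of each dot product; the partials are then combined by the shared-memory tree reduction in mul_mat_vec_base.comp. A serial stand-in for that work split on a single row (illustrative only; BLOCK_SIZE fixed at the shader's default of 32, single-column strides, no quantization):

```c
#include <stddef.h>

// y_row = dot(A_row, x) computed the way the shader splits it: BLOCK_SIZE
// "threads" accumulate strided partial sums, then a tree reduction folds them.
static float matvec_row(const float *A_row, const float *x, size_t ncols) {
    enum { BLOCK_SIZE = 32 };
    float partial[BLOCK_SIZE] = {0};

    for (size_t tid = 0; tid < BLOCK_SIZE; ++tid)             // per-thread accumulation
        for (size_t col = tid; col < ncols; col += BLOCK_SIZE)
            partial[tid] += A_row[col] * x[col];

    for (size_t s = BLOCK_SIZE / 2; s > 0; s >>= 1)           // tmpsh-style tree reduction
        for (size_t tid = 0; tid < s; ++tid)
            partial[tid] += partial[tid + s];

    return partial[0];
}
```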
expert_id * p.batch_stride_a; -+#else -+ batch_idx_a * p.batch_stride_a; -+#endif -+ b_offset = -+#ifdef MUL_MAT_ID -+ (expert_idx % p.ne11) * p.stride_b; -+#else -+ batch_idx * p.batch_stride_b; -+#endif -+ d_offset = -+#ifdef MUL_MAT_ID -+ expert_idx * p.stride_d; -+#else -+ batch_idx * p.batch_stride_d; -+#endif -+} -+ -+layout (constant_id = 0) const uint BLOCK_SIZE = 32; -+layout (constant_id = 1) const uint NUM_ROWS = 1; -+layout (constant_id = 2) const uint NUM_COLS = 1; -+ -+shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; -+ -+void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { -+ // sum up partial sums and write back result -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ [[unroll]] for (uint n = 0; n < num_rows; ++n) { -+ tmpsh[j][n][tid] = temp[j][n]; -+ } -+ } -+ barrier(); -+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { -+ if (tid < s) { -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ [[unroll]] for (uint n = 0; n < num_rows; ++n) { -+ tmpsh[j][n][tid] += tmpsh[j][n][tid + s]; -+ } -+ } -+ } -+ barrier(); -+ } -+ if (tid == 0) { -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ [[unroll]] for (uint n = 0; n < num_rows; ++n) { -+ data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); -+ } -+ } -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp -new file mode 100644 -index 00000000..1cc4996d ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp -@@ -0,0 +1,71 @@ -+#version 450 -+ -+#extension GL_EXT_control_flow_attributes : enable -+#extension GL_EXT_shader_16bit_storage : require -+ -+#define BLOCK_SIZE 32 -+#define FLOAT_TYPE float -+ -+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -+layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; -+ -+layout (push_constant) uniform parameter -+{ -+ uint ncols_x; -+ uint nrows_x; -+ uint row_stride_x; -+ uint channel_stride_x; -+ uint channel_x_divisor; -+ uint b_offset; -+ uint d_offset; -+} p; -+ -+shared FLOAT_TYPE tmp[BLOCK_SIZE]; -+ -+void main() { -+ const uint tid = gl_LocalInvocationID.x; -+ const uint row_x = gl_GlobalInvocationID.y; -+ const uint channel = gl_GlobalInvocationID.z; -+ const uint channel_x = channel / p.channel_x_divisor; -+ -+ const uint nrows_y = p.ncols_x; -+ const uint nrows_dst = p.nrows_x; -+ const uint row_dst = row_x; -+ -+ const uint idst = channel*nrows_dst + row_dst; -+ -+ tmp[tid] = 0.0f; -+ -+ for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { -+ const uint col_x = col_x0 + tid; -+ -+ if (col_x >= p.ncols_x) { -+ break; -+ } -+ -+ const uint row_y = col_x; -+ -+ const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; -+ const uint iy = channel*nrows_y + row_y; -+ -+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); -+ -+ tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); -+ } -+ -+ // sum up partial sums and write back result -+ barrier(); -+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { -+ if (tid < s) { -+ tmp[tid] += tmp[tid + s]; -+ } -+ barrier(); -+ } -+ -+ if (tid == 0) { -+ dst[idst] = tmp[0]; -+ } -+} -diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp -new file mode 100644 -index 00000000..9b443807 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp -@@ -0,0 +1,73 @@ -+#version 450 -+ -+#extension GL_EXT_control_flow_attributes : enable -+#extension GL_EXT_shader_16bit_storage : require -+ -+#define BLOCK_SIZE 32 -+#define FLOAT_TYPE float -+ -+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -+layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; -+ -+layout (push_constant) uniform parameter -+{ -+ uint ncols_x; -+ uint nrows_x; -+ uint nchannels_x; -+ uint nchannels_y; -+ uint b_offset; -+ uint d_offset; -+} p; -+ -+shared FLOAT_TYPE tmp[BLOCK_SIZE]; -+ -+void main() { -+ const uint tid = gl_LocalInvocationID.x; -+ const uint row_x = gl_GlobalInvocationID.y; -+ const uint channel = gl_GlobalInvocationID.z; -+ const uint channel_x = channel / (p.nchannels_y / p.nchannels_x); -+ -+ const uint nrows_y = p.ncols_x; -+ const uint nrows_dst = p.nrows_x; -+ const uint row_dst = row_x; -+ -+ tmp[tid] = FLOAT_TYPE(0.0f); -+ -+ for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { -+ const uint col_x = col_x0 + tid; -+ -+ if (col_x >= p.ncols_x) { -+ break; -+ } -+ -+ // x is transposed and permuted -+ const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; -+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); -+ -+ const uint row_y = col_x; -+ -+ // y is not transposed but permuted -+ const uint iy = channel*nrows_y + row_y; -+ -+ tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); -+ } -+ -+ // dst is not transposed and not permuted -+ const uint idst = channel*nrows_dst + row_dst; -+ -+ // sum up partial sums and write back result -+ barrier(); -+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { -+ if (tid < s) { -+ tmp[tid] += tmp[tid + s]; -+ } -+ barrier(); -+ } -+ -+ if (tid == 0) { -+ dst[idst] = tmp[0]; -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp -new file mode 100644 -index 00000000..93421344 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp -@@ -0,0 +1,115 @@ -+#version 450 -+#extension GL_EXT_shader_explicit_arithmetic_types : require -+ -+#include "mul_mat_vec_base.comp" -+ -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { -+ uint a_offset, b_offset, d_offset; -+ get_offsets(a_offset, b_offset, d_offset); -+ -+ const uint num_blocks_per_row = p.ncols / QUANT_K; -+ -+ // 16 threads are used to process each block -+ const uint it_size = gl_WorkGroupSize.x/16; -+ const uint tid = gl_LocalInvocationID.x; -+ const uint itid = tid%16; // 0...16 -+ const uint ix = tid/16; -+ -+ const uint step = 8; -+ -+ const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
-+ const uint v_in = itid - step*v_im; // 0...15 or 0...7 -+ -+ const uint l0 = 2*v_in; // 0...15 -+ const uint q_offset = 32*v_im + l0; -+ const uint s_offset = 8*v_im; -+ const uint y_offset = 128*v_im + l0; -+ -+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { -+ temp[j][i] = FLOAT_TYPE(0); -+ } -+ } -+ -+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { -+ const uint y_idx = i * QUANT_K + y_offset; -+ -+ [[unroll]] for (uint n = 0; n < num_rows; ++n) { -+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; -+ f16vec2 d = data_a[ib0 + i].d; -+ const FLOAT_TYPE dall = d.x; -+ const FLOAT_TYPE dmin = d.y; -+ -+ uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0]; -+ uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; -+ -+ uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; -+ uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F; -+ uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F; -+ uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F; -+ -+ uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32)); -+ uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32)); -+ uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); -+ uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); -+ -+ uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0]; -+ uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; -+ uvec2 qs0 = uvec2(unpack8(qs0_u16)); -+ uvec2 qs16 = uvec2(unpack8(qs16_u16)); -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; -+ B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; -+ B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; -+ B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; -+ B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; -+ B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; -+ B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; -+ B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; -+ -+ FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); -+ FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); -+ [[unroll]] for (int l = 0; l < 2; ++l) { -+ sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), -+ fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), -+ fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), -+ fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), -+ fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), -+ fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), -+ fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3), -+ fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); -+ sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), -+ fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), -+ fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), -+ fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), -+ fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), -+ fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), -+ fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), -+ fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); -+ } -+ temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); -+ } -+ } -+ } -+ -+ reduce_result(temp, 
d_offset, first_row, num_rows, tid); -+} -+ -+void main() { -+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -+ -+ // do NUM_ROWS at a time, unless there aren't enough remaining rows -+ if (first_row + NUM_ROWS <= p.stride_d) { -+ compute_outputs(first_row, NUM_ROWS); -+ } else { -+ if (first_row >= p.stride_d) { -+ return; -+ } -+ compute_outputs(first_row, p.stride_d - first_row); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp -new file mode 100644 -index 00000000..86b0159d ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp -@@ -0,0 +1,103 @@ -+#version 450 -+#extension GL_EXT_shader_explicit_arithmetic_types : require -+ -+#include "mul_mat_vec_base.comp" -+ -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { -+ uint a_offset, b_offset, d_offset; -+ get_offsets(a_offset, b_offset, d_offset); -+ -+ const uint num_blocks_per_row = p.ncols / QUANT_K; -+ -+ // 16 threads are used to process each block -+ const uint it_size = gl_WorkGroupSize.x/16; -+ const uint tid = gl_LocalInvocationID.x; -+ const uint itid = tid%16; // 0...16 -+ const uint ix = tid/16; -+ -+ const uint step = 8; -+ -+ const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... -+ const uint v_in = itid - step*v_im; // 0...15 or 0...7 -+ -+ const uint8_t m = uint8_t(1 << (4 * v_im)); -+ -+ const uint l0 = 2*v_in; // 0...15 -+ const uint q_offset = 32*v_im + l0; -+ const uint y_offset = 128*v_im + l0; -+ -+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { -+ temp[j][i] = FLOAT_TYPE(0); -+ } -+ } -+ -+ const uint s_shift = 4 * v_im; -+ -+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { -+ const uint y_idx = i * QUANT_K + y_offset; -+ -+ [[unroll]] for (uint n = 0; n < num_rows; ++n) { -+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; -+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); -+ -+ uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0]; -+ uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1]; -+ uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2]; -+ uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3]; -+ uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4]; -+ uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5]; -+ u8vec2 s0 = unpack8(s0_16); -+ u8vec2 s2 = unpack8(s2_16); -+ u8vec2 s4 = unpack8(s4_16); -+ u8vec2 s6 = unpack8(s6_16); -+ u8vec2 s8 = unpack8(s8_16); -+ u8vec2 s10 = unpack8(s10_16); -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ -+ B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; -+ B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; -+ B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; -+ B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; -+ B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; -+ B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; -+ B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; -+ B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; -+ -+ FLOAT_TYPE sum = FLOAT_TYPE(0.0); -+ [[unroll]] for 
(int l = 0; l < 2; ++l) { -+ sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), -+ fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), -+ fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), -+ fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), -+ fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), -+ fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), -+ fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), -+ fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 
0 : 4)), sum)))))))); -+ } -+ temp[j][n] = fma(d, sum, temp[j][n]); -+ } -+ } -+ } -+ -+ reduce_result(temp, d_offset, first_row, num_rows, tid); -+} -+ -+void main() { -+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -+ -+ // do NUM_ROWS at a time, unless there aren't enough remaining rows -+ if (first_row + NUM_ROWS <= p.stride_d) { -+ compute_outputs(first_row, NUM_ROWS); -+ } else { -+ if (first_row >= p.stride_d) { -+ return; -+ } -+ compute_outputs(first_row, p.stride_d - first_row); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp -new file mode 100644 -index 00000000..cd1dd8e8 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp -@@ -0,0 +1,133 @@ -+#version 450 -+ -+#extension GL_EXT_shader_explicit_arithmetic_types : require -+ -+#include "mul_mat_vec_base.comp" -+ -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { -+ uint a_offset, b_offset, d_offset; -+ get_offsets(a_offset, b_offset, d_offset); -+ -+ const uint num_blocks_per_row = p.ncols / QUANT_K; -+ -+ // 16 threads are used to process each block -+ const uint it_size = gl_WorkGroupSize.x/16; -+ const uint tid = gl_LocalInvocationID.x; -+ const uint itid = tid%16; // 0...16 -+ const uint ix = tid/16; -+ -+ const uint step = 4; -+ -+ const uint il = itid/step; // 0...3 -+ const uint ir = itid - step*il; // 0...7 or 0...3 -+ const uint n = 4; -+ -+ const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 -+ const uint v_in = il % 2; -+ -+ const uint l0 = n * (2 * ir + v_in); // 0...15 -+ const uint q_offset = 32*v_im + l0; -+ const uint y_offset = 64*v_im + l0; -+ -+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { -+ temp[j][i] = FLOAT_TYPE(0); -+ } -+ } -+ -+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { -+ const uint y1_idx = i * QUANT_K + y_offset; -+ const uint y2_idx = y1_idx + 128; -+ -+ [[unroll]] for (uint n = 0; n < num_rows; ++n) { -+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; -+ f16vec2 d = data_a[ib0 + i].d; -+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x); -+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); -+ -+ uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; -+ uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; -+ uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; -+ uvec4 scale0 = uvec4(unpack8(scale0_u32)); -+ uvec4 scale4 = uvec4(unpack8(scale4_u32)); -+ uvec4 scale8 = uvec4(unpack8(scale8_u32)); -+ -+ const uint32_t sc0 = ( scale0.x & 0x3f); -+ const uint32_t sc1 = ( scale0.y & 0x3f); -+ const uint32_t sc2 = ( scale4.x & 0x3f); -+ const uint32_t sc3 = ( scale4.y & 0x3f); -+ const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); -+ const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); -+ const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); -+ const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); -+ -+ uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; -+ uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; -+ -+ uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; -+ uint32_t qs0_u32_hi4 = (qs0_u32 >> 
4) & 0x0F0F0F0F; -+ uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; -+ uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; -+ -+ uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4)); -+ uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4)); -+ uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4)); -+ uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4)); -+ -+ const uint32_t q4_0 = qs0_lo4.x; -+ const uint32_t q4_1 = qs0_lo4.y; -+ const uint32_t q4_2 = qs0_lo4.z; -+ const uint32_t q4_3 = qs0_lo4.w; -+ const uint32_t q4_4 = qs0_hi4.x; -+ const uint32_t q4_5 = qs0_hi4.y; -+ const uint32_t q4_6 = qs0_hi4.z; -+ const uint32_t q4_7 = qs0_hi4.w; -+ const uint32_t q4_8 = qs64_lo4.x; -+ const uint32_t q4_9 = qs64_lo4.y; -+ const uint32_t q4_10 = qs64_lo4.z; -+ const uint32_t q4_11 = qs64_lo4.w; -+ const uint32_t q4_12 = qs64_hi4.x; -+ const uint32_t q4_13 = qs64_hi4.y; -+ const uint32_t q4_14 = qs64_hi4.z; -+ const uint32_t q4_15 = qs64_hi4.w; -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4]; -+ B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]; -+ B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4]; -+ B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]; -+ -+ const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); -+ const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); -+ const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); -+ const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); -+ const FLOAT_TYPE smin = -+ fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, -+ fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, -+ fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, -+ fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); -+ temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); -+ } -+ } -+ } -+ -+ reduce_result(temp, d_offset, first_row, num_rows, tid); -+} -+ -+void main() { -+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -+ -+ // do NUM_ROWS at a time, unless there aren't enough remaining rows -+ if (first_row + NUM_ROWS <= p.stride_d) { -+ compute_outputs(first_row, NUM_ROWS); -+ } else { -+ if (first_row >= p.stride_d) { -+ return; -+ } -+ compute_outputs(first_row, p.stride_d - first_row); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp -new file mode 100644 -index 00000000..0a68891c ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp -@@ -0,0 +1,162 @@ -+#version 450 -+ -+#extension GL_EXT_shader_explicit_arithmetic_types : require -+ -+#include "mul_mat_vec_base.comp" -+ -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+void 
compute_outputs(const uint32_t first_row, const uint32_t num_rows) { -+ uint a_offset, b_offset, d_offset; -+ get_offsets(a_offset, b_offset, d_offset); -+ -+ const uint num_blocks_per_row = p.ncols / QUANT_K; -+ -+ // 16 threads are used to process each block -+ const uint it_size = gl_WorkGroupSize.x/16; -+ const uint tid = gl_LocalInvocationID.x; -+ const uint itid = tid%16; // 0...16 -+ const uint ix = tid/16; -+ -+ const uint il = itid/4; // 0...3 -+ const uint ir = itid - 4*il; // 0...7 or 0...3 -+ -+ const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 -+ const uint v_in = il % 2; -+ -+ const uint l0 = 4*ir + 2*v_in; // 0...15 -+ const uint q_offset = 32*v_im + l0; -+ const uint y_offset = 64*v_im + l0; -+ -+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { -+ temp[j][i] = FLOAT_TYPE(0); -+ } -+ } -+ -+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { -+ const uint y1_idx = i * QUANT_K + y_offset; -+ const uint y2_idx = y1_idx + 128; -+ -+ [[unroll]] for (uint n = 0; n < num_rows; ++n) { -+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; -+ f16vec2 d = data_a[ib0 + i].d; -+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x); -+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); -+ -+ uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; -+ uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; -+ uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; -+ uvec4 scale0 = uvec4(unpack8(scale0_u32)); -+ uvec4 scale4 = uvec4(unpack8(scale4_u32)); -+ uvec4 scale8 = uvec4(unpack8(scale8_u32)); -+ -+ const uint32_t sc0 = ( scale0.x & 0x3f); -+ const uint32_t sc1 = ( scale0.y & 0x3f); -+ const uint32_t sc2 = ( scale4.x & 0x3f); -+ const uint32_t sc3 = ( scale4.y & 0x3f); -+ const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); -+ const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); -+ const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); -+ const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); -+ -+ uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); -+ uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); -+ -+ uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; -+ uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; -+ uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; -+ uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; -+ -+ uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); -+ -+ uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; -+ uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; -+ uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0; -+ uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; -+ -+ qs0_16_u32_lo4 += qs0_16_lo4_offset16; -+ qs0_16_u32_hi4 += qs0_16_hi4_offset16; -+ qs64_80_u32_lo4 += qs64_80_lo4_offset16; -+ qs64_80_u32_hi4 += qs64_80_hi4_offset16; -+ -+ uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); -+ uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); -+ uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); -+ uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4)); -+ -+ const uint32_t q4_0 = 
qs0_16_lo4.x; -+ const uint32_t q4_1 = qs0_16_lo4.y; -+ const uint32_t q4_2 = qs0_16_lo4.z; -+ const uint32_t q4_3 = qs0_16_lo4.w; -+ const uint32_t q4_4 = qs0_16_hi4.x; -+ const uint32_t q4_5 = qs0_16_hi4.y; -+ const uint32_t q4_6 = qs0_16_hi4.z; -+ const uint32_t q4_7 = qs0_16_hi4.w; -+ const uint32_t q4_8 = qs64_80_lo4.x; -+ const uint32_t q4_9 = qs64_80_lo4.y; -+ const uint32_t q4_10 = qs64_80_lo4.z; -+ const uint32_t q4_11 = qs64_80_lo4.w; -+ const uint32_t q4_12 = qs64_80_hi4.x; -+ const uint32_t q4_13 = qs64_80_hi4.y; -+ const uint32_t q4_14 = qs64_80_hi4.z; -+ const uint32_t q4_15 = qs64_80_hi4.w; -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2]; -+ B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]; -+ B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]; -+ B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]; -+ B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2]; -+ B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]; -+ B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]; -+ B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]; -+ -+ const FLOAT_TYPE sx = -+ fma(FLOAT_TYPE(by10.x), q4_0, -+ fma(FLOAT_TYPE(by10.y), q4_1, -+ fma(FLOAT_TYPE(by116.x), q4_2, -+ FLOAT_TYPE(by116.y) * q4_3))); -+ const FLOAT_TYPE sy = -+ fma(FLOAT_TYPE(by132.x), q4_4, -+ fma(FLOAT_TYPE(by132.y), q4_5, -+ fma(FLOAT_TYPE(by148.x), q4_6, -+ FLOAT_TYPE(by148.y) * q4_7))); -+ const FLOAT_TYPE sz = -+ fma(FLOAT_TYPE(by20.x), q4_8, -+ fma(FLOAT_TYPE(by20.y), q4_9, -+ fma(FLOAT_TYPE(by216.x), q4_10, -+ FLOAT_TYPE(by216.y) * q4_11))); -+ const FLOAT_TYPE sw = -+ fma(FLOAT_TYPE(by232.x), q4_12, -+ fma(FLOAT_TYPE(by232.y), q4_13, -+ fma(FLOAT_TYPE(by248.x), q4_14, -+ FLOAT_TYPE(by248.y) * q4_15))); -+ const FLOAT_TYPE smin = -+ fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, -+ fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, -+ fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, -+ (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); -+ temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); -+ } -+ } -+ } -+ -+ reduce_result(temp, d_offset, first_row, num_rows, tid); -+} -+ -+void main() { -+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -+ -+ // do NUM_ROWS at a time, unless there aren't enough remaining rows -+ if (first_row + NUM_ROWS <= p.stride_d) { -+ compute_outputs(first_row, NUM_ROWS); -+ } else { -+ if (first_row >= p.stride_d) { -+ return; -+ } -+ compute_outputs(first_row, p.stride_d - first_row); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp -new file mode 100644 -index 00000000..70e13a56 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp -@@ -0,0 +1,112 @@ -+#version 450 -+ -+#extension GL_EXT_shader_explicit_arithmetic_types : require -+ -+#include "mul_mat_vec_base.comp" -+ -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+void compute_outputs(const 
uint32_t first_row, const uint32_t num_rows) { -+ uint a_offset, b_offset, d_offset; -+ get_offsets(a_offset, b_offset, d_offset); -+ -+ const uint num_blocks_per_row = p.ncols / QUANT_K; -+ -+ // 16 threads are used to process each block -+ const uint it_size = gl_WorkGroupSize.x/16; -+ const uint tid = gl_LocalInvocationID.x; -+ const uint itid = tid%16; // 0...16 -+ const uint ix = tid/16; -+ -+ const uint step = 8; -+ -+ const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... -+ const uint v_in = itid - step*v_im; // 0...15 or 0...7 -+ -+ const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 -+ const uint is = v_in / 4; -+ -+ const uint ql_offset = 64*v_im + l0; -+ const uint qh_offset = 32*v_im + l0; -+ const uint s_offset = 8*v_im + is; -+ const uint y_offset = 128*v_im + l0; -+ -+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { -+ temp[j][i] = FLOAT_TYPE(0); -+ } -+ } -+ -+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { -+ const uint y_idx = i * QUANT_K + y_offset; -+ -+ [[unroll]] for (uint n = 0; n < num_rows; ++n) { -+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; -+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); -+ -+ FLOAT_TYPE scales[4]; -+ scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); -+ scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); -+ scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); -+ scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); -+ -+ uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); -+ uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); -+ -+ uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; -+ uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; -+ uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; -+ uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; -+ -+ uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); -+ uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; -+ uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; -+ uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; -+ uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; -+ -+ uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; -+ uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; -+ uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; -+ uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; -+ -+ uvec4 q0 = uvec4(unpack8(q0_u32)); -+ uvec4 q1 = uvec4(unpack8(q1_u32)); -+ uvec4 q2 = uvec4(unpack8(q2_u32)); -+ uvec4 q3 = uvec4(unpack8(q3_u32)); -+ -+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { -+ B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4]; -+ B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]; -+ B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]; -+ B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]; -+ -+ FLOAT_TYPE sum = FLOAT_TYPE(0.0); -+ [[unroll]] for (int l = 0; l < 4; ++l) { -+ sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), -+ fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), -+ fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), -+ fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); -+ } -+ 
temp[j][n] += sum * d; -+ } -+ } -+ } -+ -+ reduce_result(temp, d_offset, first_row, num_rows, tid); -+} -+ -+void main() { -+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -+ -+ // do NUM_ROWS at a time, unless there aren't enough remaining rows -+ if (first_row + NUM_ROWS <= p.stride_d) { -+ compute_outputs(first_row, NUM_ROWS); -+ } else { -+ if (first_row >= p.stride_d) { -+ return; -+ } -+ compute_outputs(first_row, p.stride_d - first_row); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp -new file mode 100644 -index 00000000..48122cbe ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp -@@ -0,0 +1,631 @@ -+#version 450 -+ -+#extension GL_EXT_control_flow_attributes : enable -+#extension GL_EXT_shader_16bit_storage : require -+ -+#ifdef FLOAT16 -+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -+#endif -+ -+#ifdef COOPMAT -+#extension GL_KHR_cooperative_matrix : enable -+#extension GL_KHR_memory_scope_semantics : enable -+#extension GL_KHR_shader_subgroup_basic : enable -+#endif -+ -+#ifdef MUL_MAT_ID -+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -+#endif -+ -+#include "types.comp" -+ -+#ifndef LOAD_VEC_A -+#define LOAD_VEC_A 1 -+#endif -+#ifndef LOAD_VEC_B -+#define LOAD_VEC_B 1 -+#endif -+ -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; -+ -+#ifdef MUL_MAT_ID -+layout (binding = 3) readonly buffer IDS {int data_ids[];}; -+#endif -+ -+layout (push_constant) uniform parameter -+{ -+ uint M; -+ uint N; -+ uint K; -+ uint stride_a; -+ uint stride_b; -+ uint stride_d; -+ -+ uint batch_stride_a; -+ uint batch_stride_b; -+ uint batch_stride_d; -+ -+#ifdef MUL_MAT_ID -+ uint nei0; -+ uint nei1; -+ uint nbi1; -+ uint ne11; -+#else -+ uint k_split; -+ uint ne02; -+ uint ne12; -+ uint broadcast2; -+ uint broadcast3; -+#endif -+} p; -+ -+layout (constant_id = 0) const uint BLOCK_SIZE = 64; -+layout (constant_id = 1) const uint BM = 64; -+layout (constant_id = 2) const uint BN = 64; -+layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant -+layout (constant_id = 4) const uint WM = 32; -+layout (constant_id = 5) const uint WN = 32; -+layout (constant_id = 6) const uint WMITER = 2; -+layout (constant_id = 7) const uint TM = 4; -+layout (constant_id = 8) const uint TN = 2; -+layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat -+layout (constant_id = 10) const uint WARP = 32; -+ -+#ifdef COOPMAT -+#define SHMEM_STRIDE (BK + 8) -+#else -+#define SHMEM_STRIDE (BK + 1) -+#endif -+ -+shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; -+shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; -+ -+#ifdef MUL_MAT_ID -+shared u16vec2 row_ids[3072]; -+#endif // MUL_MAT_ID -+ -+#define NUM_WARPS (BLOCK_SIZE / WARP) -+ -+#ifdef COOPMAT -+shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; -+#endif -+ -+void main() { -+#if defined(DATA_A_IQ4_NL) -+ init_iq4nl_shmem(); -+#endif -+ -+#ifdef MUL_MAT_ID -+ const uint expert_idx = gl_GlobalInvocationID.z; -+#else -+ const uint batch_idx = gl_GlobalInvocationID.z; -+ -+ const uint i13 = batch_idx / p.ne12; -+ const uint i12 = batch_idx % p.ne12; -+ -+ const uint i03 = i13 / p.broadcast3; -+ const uint 
i02 = i12 / p.broadcast2; -+ -+ const uint batch_idx_a = i03 * p.ne02 + i02; -+#endif -+ -+ const uint blocks_m = (p.M + BM - 1) / BM; -+ const uint ir = gl_WorkGroupID.x % blocks_m; -+ const uint ik = gl_WorkGroupID.x / blocks_m; -+ const uint ic = gl_WorkGroupID.y; -+ -+ const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); -+ const uint WSUBM = WM / WMITER; -+ const uint WSUBN = WN / WNITER; -+ -+#ifdef COOPMAT -+ const uint warp_i = gl_SubgroupID; -+ -+ const uint tiw = gl_SubgroupInvocationID; -+ -+ const uint cms_per_row = WM / TM; -+ const uint cms_per_col = WN / TN; -+ -+ const uint storestride = WARP / TM; -+ const uint store_r = tiw % TM; -+ const uint store_c = tiw / TM; -+#else -+ const uint warp_i = gl_LocalInvocationID.x / WARP; -+ -+ const uint tiw = gl_LocalInvocationID.x % WARP; -+ -+ const uint tiwr = tiw % (WSUBM / TM); -+ const uint tiwc = tiw / (WSUBM / TM); -+#endif -+ -+ const uint warp_r = warp_i % (BM / WM); -+ const uint warp_c = warp_i / (BM / WM); -+ -+ const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); -+ const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); -+ const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); -+ const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); -+ -+ const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK; -+ const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; -+ -+#ifdef MUL_MAT_ID -+ uint _ne1 = 0; -+ for (uint ii1 = 0; ii1 < p.nei1; ii1++) { -+ for (uint ii0 = 0; ii0 < p.nei0; ii0++) { -+ if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { -+ row_ids[_ne1] = u16vec2(ii0, ii1); -+ _ne1++; -+ } -+ } -+ } -+ -+ barrier(); -+ -+ // Workgroup has no work -+ if (ic * BN >= _ne1) return; -+#endif -+ -+#ifdef MUL_MAT_ID -+ const uint start_k = 0; -+ const uint end_k = p.K; -+#else -+ const uint start_k = ik * p.k_split; -+ const uint end_k = min(p.K, (ik + 1) * p.k_split); -+#endif -+ -+ uint pos_a = ( -+#ifdef MUL_MAT_ID -+ expert_idx * p.batch_stride_a + -+#else -+ batch_idx_a * p.batch_stride_a + -+#endif -+ ir * BM * p.stride_a + start_k) / LOAD_VEC_A; -+#ifdef MUL_MAT_ID -+ uint pos_b = 0; -+#else -+ uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; -+#endif -+ -+#ifdef COOPMAT -+ coopmat cache_a; -+ coopmat cache_b; -+ coopmat sums[cms_per_row * cms_per_col]; -+ -+ [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { -+ sums[i] = coopmat(0.0f); -+ } -+#else -+ ACC_TYPE sums[WMITER * TM * WNITER * TN]; -+ FLOAT_TYPE cache_a[WMITER * TM]; -+ FLOAT_TYPE cache_b[WNITER * TN]; -+ -+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { -+ sums[i] = ACC_TYPE(0.0f); -+ } -+#endif -+ -+ for (uint block = start_k; block < end_k; block += BK) { -+ [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { -+ -+#if defined(DATA_A_F32) || defined(DATA_A_F16) -+#if LOAD_VEC_A == 8 -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; -+ buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); -+ buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); -+ buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); -+ buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); -+ buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); -+ buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); -+ buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); -+ buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); -+#elif LOAD_VEC_A == 4 -+ const uint idx = pos_a + (loadc_a + l) * 
p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; -+ buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); -+ buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); -+ buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); -+ buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); -+#else -+ if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { -+ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); -+ } else { -+ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); -+ } -+#endif -+#elif defined(DATA_A_Q4_0) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; -+ -+ const uint ib = idx / 16; -+ const uint iqs = idx & 0xF; -+ -+ const float d = float(data_a[ib].d); -+ const uint vui = uint(data_a[ib].qs[iqs]); -+ const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(v.x); -+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); -+#elif defined(DATA_A_Q4_1) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; -+ -+ const uint ib = idx / 16; -+ const uint iqs = idx & 0xF; -+ -+ const float d = float(data_a[ib].d); -+ const float m = float(data_a[ib].m); -+ const uint vui = uint(data_a[ib].qs[iqs]); -+ const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(v.x); -+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); -+#elif defined(DATA_A_Q5_0) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; -+ -+ const uint ib = idx / 16; -+ const uint iqs = idx & 0xF; -+ -+ const float d = float(data_a[ib].d); -+ const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; -+ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); -+ const uint vui = uint(data_a[ib].qs[iqs]); -+ const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(v.x); -+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); -+#elif defined(DATA_A_Q5_1) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; -+ -+ const uint ib = idx / 16; -+ const uint iqs = idx & 0xF; -+ -+ const float d = float(data_a[ib].d); -+ const float m = float(data_a[ib].m); -+ const uint uint_qh = data_a[ib].qh; -+ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); -+ const uint vui = uint(data_a[ib].qs[iqs]); -+ const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(v.x); -+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); -+#elif defined(DATA_A_Q8_0) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; -+ -+ const uint ib = idx / 16; -+ const uint iqs = (idx & 0xF) * 2; -+ -+ const float d = float(data_a[ib].d); -+ const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(v.x); -+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); -+#elif defined(DATA_A_Q2_K) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; -+ -+ const uint ib = idx / 128; // 2 values per 
idx -+ const uint iqs = idx % 128; // 0..127 -+ -+ const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 -+ const uint scalesi = iqs / 8; // 0..15 -+ const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 -+ -+ const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); -+ const uint scales = data_a[ib].scales[scalesi]; -+ const vec2 d = vec2(data_a[ib].d); -+ -+ const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(v.x); -+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); -+#elif defined(DATA_A_Q3_K) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; -+ -+ const uint ib = idx / 128; // 2 values per idx -+ const uint iqs = idx % 128; // 0..127 -+ -+ const uint n = iqs / 64; // 0,1 -+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 -+ const uint hmi = (iqs % 16) * 2; // 0,2,4..30 -+ const uint j = (iqs % 64) / 4; // 0..3 -+ const uint is = iqs / 8; // 0..15 -+ const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 -+ const uint qsshift = halfsplit * 2; // 0,2,4,6 -+ const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 -+ -+ const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) : -+ is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) : -+ is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) : -+ (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4)); -+ const float dl = float(data_a[ib].d) * float(us - 32); -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); -+ buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); -+#elif defined(DATA_A_Q4_K) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; -+ -+ const uint ib = idx / 128; // 2 values per idx -+ const uint iqs = idx % 128; // 0..127 -+ -+ const uint n = iqs / 32; // 0,1,2,3 -+ const uint b = (iqs % 32) / 16; // 0,1 -+ const uint is = 2 * n + b; // 0..7 -+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 -+ -+ const vec2 loadd = vec2(data_a[ib].d); -+ -+ const uint scidx0 = (is < 4) ? is : (is + 4); -+ const uint scidx1 = (is < 4) ? is : (is - 4); -+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ const uint scidxshift1 = (is < 4) ? 0 : 2; -+ const uint mbidx0 = is + 4; -+ const uint mbidx1 = (is < 4) ? is + 4 : is; -+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; -+ const uint mbidxshift0 = (is < 4) ? 0 : 4; -+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ const uint mbidxshift1 = (is < 4) ? 
0 : 2; -+ -+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); -+ const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); -+ -+ const float d = loadd.x * sc; -+ const float m = -loadd.y * mbyte; -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); -+ buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); -+#elif defined(DATA_A_Q5_K) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; -+ -+ const uint ib = idx / 128; // 2 values per idx -+ const uint iqs = idx % 128; // 0..127 -+ -+ const uint n = iqs / 32; // 0,1,2,3 -+ const uint b = (iqs % 32) / 16; // 0,1 -+ const uint is = 2 * n + b; // 0..7 -+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 -+ const uint qhi = (iqs % 16) * 2; // 0,2,4..30 -+ -+ const uint8_t hm = uint8_t(1 << (iqs / 16)); -+ -+ const vec2 loadd = vec2(data_a[ib].d); -+ -+ const uint scidx0 = (is < 4) ? is : (is + 4); -+ const uint scidx1 = (is < 4) ? is : (is - 4); -+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ const uint scidxshift1 = (is < 4) ? 0 : 2; -+ const uint mbidx0 = is + 4; -+ const uint mbidx1 = (is < 4) ? is + 4 : is; -+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; -+ const uint mbidxshift0 = (is < 4) ? 0 : 4; -+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; -+ const uint mbidxshift1 = (is < 4) ? 0 : 2; -+ -+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); -+ const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); -+ -+ const float d = loadd.x * sc; -+ const float m = -loadd.y * mbyte; -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); -+ buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 
16 : 0), m)); -+#elif defined(DATA_A_Q6_K) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; -+ -+ const uint ib = idx / 128; // 2 values per idx -+ const uint iqs = idx % 128; // 0..127 -+ -+ const uint n = iqs / 64; // 0,1 -+ const uint b = (iqs % 64) / 32; // 0,1 -+ const uint is_b = (iqs % 16) / 8; // 0,1 -+ const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 -+ const uint is = 8 * n + qhshift + is_b; // 0..15 -+ const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 -+ const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 -+ -+ const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); -+ buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); -+#elif defined(DATA_A_IQ4_NL) -+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; -+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; -+ -+ const uint ib = idx / 16; -+ const uint iqs = idx & 0xF; -+ -+ const float d = float(data_a[ib].d); -+ const uint vui = uint(data_a[ib].qs[iqs]); -+ const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; -+ -+ buf_a[buf_idx ] = FLOAT_TYPE(v.x); -+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); -+#endif -+ } -+ [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { -+#if LOAD_VEC_B == 8 -+#ifdef MUL_MAT_ID -+ const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; -+ const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; -+#else -+ const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; -+#endif -+ const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; -+ buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); -+ buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); -+ buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); -+ buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); -+ buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); -+ buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); -+ buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); -+ buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); -+#elif LOAD_VEC_B == 4 -+#ifdef MUL_MAT_ID -+ const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; -+ const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; -+#else -+ const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; -+#endif -+ const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; -+ buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); -+ buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); -+ buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); -+ buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); -+#elif !MUL_MAT_ID -+ if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { -+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); -+ } else { -+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); -+ } -+#else -+ const uint row_i = ic * BN + loadc_b + l; -+ if (row_i < _ne1) { -+ const u16vec2 row_idx = row_ids[row_i]; -+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * 
p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); -+ } else { -+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); -+ } -+#endif -+ } -+ -+ barrier(); -+ -+ pos_a += BK / LOAD_VEC_A; -+ pos_b += BK / LOAD_VEC_B; -+ -+#ifdef COOPMAT -+ [[unroll]] for (uint i = 0; i < BK; i += TK) { -+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { -+ // Load from shared into cache -+ coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); -+ -+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { -+ coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); -+ -+ sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); -+ } -+ } -+ } -+#else -+ [[unroll]] for (uint i = 0; i < BK; i++) { -+ // Load from shared into cache -+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { -+ [[unroll]] for (uint j = 0; j < TM; j++) { -+ cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; -+ } -+ } -+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { -+ [[unroll]] for (uint j = 0; j < TN; j++) { -+ cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; -+ } -+ } -+ -+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { -+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { -+ [[unroll]] for (uint cc = 0; cc < TN; cc++) { -+ [[unroll]] for (uint cr = 0; cr < TM; cr++) { -+ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; -+ sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]); -+ } -+ } -+ } -+ } -+ } -+#endif -+ -+ barrier(); -+ } -+ -+ const uint dr = ir * BM + warp_r * WM; -+ const uint dc = ic * BN + warp_c * WN; -+ -+#ifndef MUL_MAT_ID -+ const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; -+#endif -+ -+#ifdef COOPMAT -+#ifdef MUL_MAT_ID -+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { -+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { -+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); -+ -+ [[unroll]] for (uint col = 0; col < BN; col += storestride) { -+ const uint row_i = dc + cm_col * TN + col + store_c; -+ if (row_i >= _ne1) break; -+ -+ const u16vec2 row_idx = row_ids[row_i]; -+ -+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); -+ } -+ } -+ } -+#else -+ const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float -+ -+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { -+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { -+ const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; -+ -+ if (is_aligned && is_in_bounds) { -+ // Full coopMat is within bounds and stride_d is aligned with 16B -+ coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); -+ coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); -+ } else if (is_in_bounds) { -+ // Full coopMat is within bounds, but stride_d is not aligned -+ coopMatStore(sums[cm_col * 
cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); -+ -+ [[unroll]] for (uint col = 0; col < TN; col += storestride) { -+ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); -+ } -+ } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { -+ // Partial coopMat is within bounds -+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); -+ -+ [[unroll]] for (uint col = 0; col < TN; col += storestride) { -+ if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { -+ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); -+ } -+ } -+ } -+ } -+ } -+#endif // MUL_MAT_ID -+#else -+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { -+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { -+ -+ const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; -+ const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; -+ [[unroll]] for (uint cc = 0; cc < TN; cc++) { -+#ifdef MUL_MAT_ID -+ const uint row_i = dc_warp + cc; -+ if (row_i >= _ne1) break; -+ -+ const u16vec2 row_idx = row_ids[row_i]; -+#endif // MUL_MAT_ID -+ [[unroll]] for (uint cr = 0; cr < TM; cr++) { -+#ifdef MUL_MAT_ID -+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); -+#else -+ if (dr_warp + cr < p.M && dc_warp + cc < p.N) { -+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); -+ } -+#endif // MUL_MAT_ID -+ } -+ } -+ } -+ } -+#endif // COOPMAT -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp -new file mode 100644 -index 00000000..cbfa5dce ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp -@@ -0,0 +1,328 @@ -+#version 450 -+ -+#extension GL_EXT_control_flow_attributes : enable -+#extension GL_EXT_shader_16bit_storage : require -+ -+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require -+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -+ -+#extension GL_KHR_memory_scope_semantics : enable -+#extension GL_KHR_cooperative_matrix : enable -+#extension GL_NV_cooperative_matrix2 : enable -+#extension GL_EXT_buffer_reference : enable -+#extension GL_KHR_shader_subgroup_ballot : enable -+#extension GL_KHR_shader_subgroup_vote : enable -+ -+#include "types.comp" -+ -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+layout (constant_id = 1) const uint BM = 64; -+layout (constant_id = 2) const uint BN = 64; -+layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant -+ -+layout (push_constant) uniform parameter -+{ -+ uint M; -+ uint N; -+ uint K; -+ uint stride_a; -+ uint stride_b; -+ uint stride_d; -+ -+ uint batch_stride_a; -+ uint batch_stride_b; -+ uint batch_stride_d; -+ -+#ifdef MUL_MAT_ID -+ uint nei0; -+ uint nei1; -+ uint nbi1; -+ uint ne11; -+#else -+ uint k_split; -+ uint ne02; -+ uint ne12; -+ uint broadcast2; -+ uint 
broadcast3; -+#endif -+} p; -+ -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; -+ -+#if QUANT_K > 1 -+#define DECODEFUNCA , dequantFuncA -+#define MAT_A_TYPE float16_t -+ -+#include "dequant_funcs_cm2.comp" -+ -+#else -+#define DECODEFUNCA -+#define MAT_A_TYPE A_TYPE -+#endif -+ -+#define MAT_B_TYPE B_TYPE -+ -+#ifdef MUL_MAT_ID -+layout (binding = 3) readonly buffer IDS {int data_ids[];}; -+ -+shared u16vec4 row_ids[3072]; -+ -+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { -+ B_TYPE b[]; -+}; -+ -+uint _ne1; -+shared uint _ne1_sh; -+ -+B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) -+{ -+ const uint row_i = blockCoords[0]; -+ -+ if (row_i >= _ne1) { -+ return B_TYPE(0.0); -+ } -+ -+ const u16vec4 row_idx = row_ids[row_i]; -+ B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; -+ -+ return ret; -+} -+ -+D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) -+{ -+ uint dr = ir * BM + r; -+ uint dc = ic * BN + c; -+ -+ if (dr < p.M && dc < _ne1) { -+ uint row_i = dc; -+ const u16vec4 row_idx = row_ids[row_i]; -+ data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; -+ } -+ return elem; -+} -+ -+#endif -+ -+void main() { -+#if defined(DATA_A_IQ4_NL) -+ init_iq4nl_shmem(); -+#endif -+ -+#ifdef MUL_MAT_ID -+ const uint expert_idx = gl_GlobalInvocationID.z; -+#else -+ const uint batch_idx = gl_GlobalInvocationID.z; -+ -+ const uint i13 = batch_idx / p.ne12; -+ const uint i12 = batch_idx % p.ne12; -+ -+ const uint i03 = i13 / p.broadcast3; -+ const uint i02 = i12 / p.broadcast2; -+ -+ const uint batch_idx_a = i03 * p.ne02 + i02; -+#endif -+ -+ const uint blocks_m = (p.M + BM - 1) / BM; -+ const uint ir = gl_WorkGroupID.x % blocks_m; -+ const uint ik = gl_WorkGroupID.x / blocks_m; -+ const uint ic = gl_WorkGroupID.y; -+ -+#ifdef MUL_MAT_ID -+ // Spread the search across all elements in the first subgroup -+ if (gl_SubgroupID == 0) { -+ _ne1 = 0; -+ uint num_elements = p.nei1 * p.nei0; -+ -+ for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { -+ bool in_range = i < num_elements; -+ uint ii0 = i % p.nei0; -+ uint ii1 = i / p.nei0; -+ uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; -+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx); -+ uint idx = subgroupBallotExclusiveBitCount(ballot); -+ if (in_range && id == expert_idx) { -+ row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); -+ } -+ _ne1 += subgroupBallotBitCount(ballot); -+ } -+ _ne1_sh = _ne1; -+ } -+ -+ barrier(); -+ -+ _ne1 = _ne1_sh; -+ -+ // Workgroup has no work -+ if (ic * BN >= _ne1) return; -+#endif -+ -+#ifdef MUL_MAT_ID -+ uint start_k = 0; -+ const uint end_k = p.K; -+#else -+ uint start_k = ik * p.k_split; -+ const uint end_k = min(p.K, (ik + 1) * p.k_split); -+#endif -+ -+ coopmat sum; -+ sum = coopmat(0.0); -+ -+#ifdef MUL_MAT_ID -+ uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; -+ uint pos_b = 0; -+#else -+ uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; -+ uint pos_b = batch_idx * p.batch_stride_b; -+#endif -+ -+ uint stride_a = p.stride_a / QUANT_K; -+ uint stride_b = p.stride_b; -+ -+ // Hint to the compiler that values are aligned (want 16B alignment). 
-+ // Quants are always block-aligned, no alignment needed. -+#if ALIGNED -+#if QUANT_K == 1 -+ stride_a &= ~7; -+#endif -+ stride_b &= ~7; -+#endif -+ -+ // Create layouts for both clamped and unclamped accesses -+ tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); -+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); -+ tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); -+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); -+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); -+ -+#if QUANT_K > 1 -+ tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); -+ tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); -+#endif -+ -+ // Use end_k rather than p.K as the dimension because that's what -+ // we need to bound check against when using split_k -+ tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); -+ tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k); -+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); -+ tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); -+ tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k); -+ -+ tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); -+ -+#if !defined(MUL_MAT_ID) -+ // Detect a fast path where all loads are entirely in bounds and no clamping is required -+ if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 && -+#if QUANT_K == 1 -+ (stride_a % 8) == 0 && -+#endif -+ (stride_b % 8) == 0 && (start_k % 8) == 0) { -+ // Hint to the compiler that values are aligned (want 16B alignment) -+ start_k &= ~7; -+ stride_b &= ~7; -+#if QUANT_K == 1 -+ stride_a &= ~7; -+#endif -+ -+ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); -+ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); -+ -+ uint k_iters = (end_k - start_k + BK - 1) / BK; -+ -+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { -+ -+ coopmat mat_a; -+ coopmat mat_b; -+ -+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); -+ coopmat mat_a_ft = coopmat(mat_a); -+ -+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); -+ coopmat mat_b_ft = coopmat(mat_b); -+ -+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); -+ } -+ } else -+#endif // !defined(MUL_MAT_ID) -+ { -+ tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); -+ -+ tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1); -+ -+ tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); -+ -+ tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); -+ -+ [[dont_unroll]] -+ for (uint block_k = start_k; block_k < end_k; block_k += BK) { -+ -+ coopmat mat_a; -+ coopmat mat_b; -+ coopmat mat_a_ft; -+ coopmat mat_b_ft; -+ -+ // Clamping is expensive, so detect different code paths for each combination -+ // of A and B needing clamping. 
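-+ // A is unclamped when the whole BM x BK tile lies inside [0, p.M) x [start_k, end_k) and block_k is 8-aligned; B is checked the same way against p.N.
-+ // With MUL_MAT_ID, B always goes through decodeFuncB, which zero-fills out-of-range rows, so it never needs the clamped layout.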
-+ bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0; -+#ifdef MUL_MAT_ID -+ bool unclampedB = true; -+#else -+ bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0; -+#endif -+ if (unclampedA && unclampedB) { -+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); -+#ifdef MUL_MAT_ID -+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); -+#else -+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); -+#endif -+ mat_a_ft = coopmat(mat_a); -+ mat_b_ft = coopmat(mat_b); -+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); -+ } else if (unclampedA && !unclampedB) { -+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); -+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); -+ -+ mat_a_ft = coopmat(mat_a); -+ mat_b_ft = coopmat(mat_b); -+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); -+ } else if (!unclampedA && unclampedB) { -+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); -+#ifdef MUL_MAT_ID -+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); -+#else -+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); -+#endif -+ mat_a_ft = coopmat(mat_a); -+ mat_b_ft = coopmat(mat_b); -+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); -+ } else if (!unclampedA && !unclampedB) { -+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); -+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); -+ -+ mat_a_ft = coopmat(mat_a); -+ mat_b_ft = coopmat(mat_b); -+ sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); -+ } -+ } -+ } -+ -+ // Convert from ACC_TYPE to D_TYPE -+ coopmat mat_d; -+ mat_d = coopmat(sum); -+ -+#ifdef MUL_MAT_ID -+ // Call callback to store each element, remapping row through shared memory -+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); -+#else -+ tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); -+ -+ uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; -+ coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); -+#endif -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp -new file mode 100644 -index 00000000..6627a50b ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp -@@ -0,0 +1,44 @@ -+#version 450 -+ -+#include "generic_head.comp" -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+#define BLOCK_SIZE 512 -+ -+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+shared vec2 sum[BLOCK_SIZE]; -+ -+void main() { -+ const uint row = 
gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; -+ const uint tid = gl_LocalInvocationID.x; -+ -+ sum[tid] = vec2(0.0f, 0.0f); -+ -+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { -+ const float xi = float(data_a[row*p.KX + col]); -+ sum[tid].x += xi; -+ sum[tid].y += xi * xi; -+ } -+ -+ // sum up partial sums and write back result -+ barrier(); -+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { -+ if (tid < s) { -+ sum[tid] += sum[tid + s]; -+ } -+ barrier(); -+ } -+ -+ const float mean = sum[0].x / p.KX; -+ const float var = sum[0].y / p.KX - mean * mean; -+ const float inv_std = inversesqrt(var + p.param1); -+ -+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { -+ data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp -new file mode 100644 -index 00000000..450b67fc ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp -@@ -0,0 +1,28 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_unary_head.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+ -+ if (idx >= p.ne) { -+ return; -+ } -+ -+ const uint i3 = idx / (p.ne12*p.ne11*p.ne10); -+ const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; -+ const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10); -+ const uint i2_offset = i2*p.ne11*p.ne10; -+ const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; -+ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; -+ -+ const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; -+ const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; -+ -+ const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; -+ -+ data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? 
data_a[get_aoffset() + src0_idx] : 0.0f); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp -new file mode 100644 -index 00000000..b6124411 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp -@@ -0,0 +1,74 @@ -+#version 450 -+ -+#include "types.comp" -+ -+#extension GL_EXT_shader_16bit_storage : require -+ -+layout(push_constant) uniform parameter { -+ uint IW; uint IH; -+ uint OW; uint OH; -+ uint OC; -+ uint pelements; -+ uint op; -+ int k0; int k1; -+ int s0; int s1; -+ int p0; int p1; -+} p; -+ -+#define BLOCK_SIZE 512 -+#define FLT_MAX 3.402823466e+38F -+#define OP_POOL_MAX 0u -+#define OP_POOL_AVG 1u -+ -+layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -+ -+layout(binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout(binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const uint idx = gl_GlobalInvocationID.x; -+ if (idx >= p.pelements) { -+ return; -+ } -+ -+ const uint O_HW = p.OW * p.OH; -+ -+ const uint nc = idx / O_HW; -+ const uint cur_oh = (idx % O_HW) / p.OW; -+ const uint cur_ow = (idx % O_HW) % p.OW; -+ -+ const int start_h = int(cur_oh) * p.s0 - p.p0; -+ const uint bh = max(start_h, 0); -+ const uint eh = min(start_h + p.k0, p.IH); -+ -+ const int start_w = int(cur_ow) * p.s1 - p.p1; -+ const uint bw = max(start_w, 0); -+ const uint ew = min(start_w + p.k1, p.IW); -+ -+ const float scale = 1.0 / float(p.k0 * p.k1); -+ float res; -+ -+ if (p.op == OP_POOL_AVG) { -+ res = 0.0; -+ } else if (p.op == OP_POOL_MAX) { -+ res = -FLT_MAX; -+ } else { -+ return; -+ } -+ -+ #pragma unroll -+ for (uint i = bh; i < eh; i++) { -+ #pragma unroll -+ for (uint j = bw; j < ew; j++) { -+ const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]); -+ -+ if (p.op == OP_POOL_AVG) { -+ res += cur * scale; -+ } else if (p.op == OP_POOL_MAX) { -+ res = max(res, cur); -+ } -+ } -+ } -+ -+ data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res; -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp -new file mode 100644 -index 00000000..52a19b62 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp -@@ -0,0 +1,21 @@ -+#version 450 -+ -+#include "generic_head.comp" -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+ -+ if (i >= p.KX) { -+ return; -+ } -+ -+ data_d[i] = max(float(data_a[i]), 0); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp -new file mode 100644 -index 00000000..1568b141 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp -@@ -0,0 +1,26 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_unary_head.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+uint src0_idx_mod(uint idx) { -+ const uint i13 = idx / (p.ne12*p.ne11*p.ne10); -+ const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; -+ const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); -+ const uint 
i12_offset = i12*p.ne11*p.ne10; -+ const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; -+ const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; -+ return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00; -+} -+ -+void main() { -+ const uint idx = get_idx(); -+ -+ if (idx >= p.ne) { -+ return; -+ } -+ -+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp -new file mode 100644 -index 00000000..b554400b ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp -@@ -0,0 +1,42 @@ -+#version 450 -+ -+#include "generic_head.comp" -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+#define BLOCK_SIZE 512 -+ -+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+shared FLOAT_TYPE sum[BLOCK_SIZE]; -+ -+void main() { -+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; -+ const uint tid = gl_LocalInvocationID.x; -+ -+ sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp -+ -+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { -+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); -+ sum[tid] += xi * xi; -+ } -+ -+ // sum up partial sums and write back result -+ barrier(); -+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { -+ if (tid < s) { -+ sum[tid] += sum[tid + s]; -+ } -+ barrier(); -+ } -+ -+ const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX); -+ const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); -+ -+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { -+ data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp -new file mode 100644 -index 00000000..574b51ca ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp -@@ -0,0 +1,49 @@ -+#include "types.comp" -+ -+#extension GL_EXT_shader_16bit_storage : require -+#extension GL_EXT_spirv_intrinsics: enable -+ -+#if RTE16 -+spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -+#endif -+ -+layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) readonly buffer Y {int data_pos[];}; -+layout (binding = 2) readonly buffer Z {float data_ff[];}; -+layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; -+ -+layout (push_constant) uniform parameter { -+ uint ncols; -+ uint n_dims; -+ float freq_scale; -+ uint p_delta_rows; -+ float freq_base; -+ float ext_factor; -+ float attn_factor; -+ float corr_dims[2]; -+ float theta_scale; -+ uint has_ff; -+} p; -+ -+float rope_yarn_ramp(const float low, const float high, const uint i0) { -+ const float y = (i0 / 2 - low) / max(0.001f, high - low); -+ return 1.0f - min(1.0f, max(0.0f, y)); -+} -+ -+void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { -+ float mscale = p.attn_factor; -+ // Get n-d rotational scaling corrected for extrapolation -+ float theta_interp = 
p.freq_scale * theta_extrap; -+ float theta = theta_interp; -+ if (p.ext_factor != 0.0f) { -+ float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; -+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; -+ -+ // Get n-d magnitude scaling corrected for interpolation -+ mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); -+ } -+ cos_theta = cos(theta) * mscale; -+ sin_theta = sin(theta) * mscale; -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp -new file mode 100644 -index 00000000..83b46b69 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp -@@ -0,0 +1,37 @@ -+#version 450 -+ -+#include "rope_head.comp" -+ -+void main() { -+ const uint col = gl_GlobalInvocationID.y * 2; -+ const uint row = gl_GlobalInvocationID.x; -+ -+ if (col >= p.ncols) { -+ return; -+ } -+ -+ if (col >= p.n_dims) { -+ const uint i = row*p.ncols + col; -+ -+ data_d[i + 0] = data_a[i + 0]; -+ data_d[i + 1] = data_a[i + 1]; -+ -+ return; -+ } -+ -+ const uint i = row*p.ncols + col/2; -+ const uint i2 = row/p.p_delta_rows; -+ -+ const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); -+ -+ const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; -+ -+ float cos_theta, sin_theta; -+ rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); -+ -+ const float x0 = float(data_a[i + 0]); -+ const float x1 = float(data_a[i + p.n_dims/2]); -+ -+ data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); -+ data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp -new file mode 100644 -index 00000000..e416ad93 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp -@@ -0,0 +1,37 @@ -+#version 450 -+ -+#include "rope_head.comp" -+ -+void main() { -+ const uint col = gl_GlobalInvocationID.y * 2; -+ const uint row = gl_GlobalInvocationID.x; -+ -+ if (col >= p.ncols) { -+ return; -+ } -+ -+ if (col >= p.n_dims) { -+ const uint i = row*p.ncols + col; -+ -+ data_d[i + 0] = data_a[i + 0]; -+ data_d[i + 1] = data_a[i + 1]; -+ -+ return; -+ } -+ -+ const uint i = row*p.ncols + col; -+ const uint i2 = row/p.p_delta_rows; -+ -+ const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); -+ -+ const float freq_factor = p.has_ff != 0 ? 
data_ff[col/2] : 1.0f; -+ -+ float cos_theta, sin_theta; -+ rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); -+ -+ const float x0 = float(data_a[i + 0]); -+ const float x1 = float(data_a[i + 1]); -+ -+ data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); -+ data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp -new file mode 100644 -index 00000000..4663428d ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp -@@ -0,0 +1,24 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_unary_head.comp" -+ -+const uint num_threads = 128; -+ -+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ uint idx = get_idx(); -+ -+ // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation -+ const uint num_iter = 4; -+ -+ [[unroll]] for (uint i = 0; i < num_iter; ++i) { -+ if (idx >= p.ne) { -+ continue; -+ } -+ -+ data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); -+ idx += num_threads; -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp -new file mode 100644 -index 00000000..4d36f88e ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp -@@ -0,0 +1,22 @@ -+#version 450 -+ -+#include "generic_head.comp" -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+ -+ if (i >= p.KX) { -+ return; -+ } -+ -+ const float xi = float(data_a[i]); -+ data_d[i] = D_TYPE(xi / (1.0f + exp(-xi))); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp -new file mode 100644 -index 00000000..d7c15a16 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp -@@ -0,0 +1,17 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_unary_head.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ const uint idx = get_idx(); -+ -+ if (idx >= p.ne) { -+ return; -+ } -+ -+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); -+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp -new file mode 100644 -index 00000000..a25808e1 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp -@@ -0,0 +1,174 @@ -+#version 450 -+ -+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -+#extension GL_EXT_control_flow_attributes : enable -+ -+layout (push_constant) uniform parameter -+{ -+ uint KX; -+ uint KY; -+ float scale; -+ float max_bias; -+ float m0; -+ float m1; -+ uint n_head_log2; -+ uint nrows_x; -+} p; -+ -+#include "types.comp" -+ -+layout(constant_id = 0) const uint BLOCK_SIZE = 32; -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ 
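-+ // binding 0 holds the input rows, binding 1 the optional mask added with the ALiBi slope when KY > 0, and binding 2 the output;
-+ // the output buffer is read-write because columns past the per-thread cache are written during the sum pass and rescaled in place afterwards.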
-+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; -+layout (binding = 2) buffer D {D_TYPE data_d[];}; -+ -+shared FLOAT_TYPE vals[BLOCK_SIZE]; -+ -+// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate -+// over all the columns. The main function tries to pass a constant here, -+// as if it were a template function, to allow unrolling. -+void soft_max(uint num_iters) { -+ const uint tid = gl_LocalInvocationID.x; -+ const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; -+ const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; -+ -+ if (rowx >= p.nrows_x) { -+ return; -+ } -+ -+ float slope = 1.0f; -+ -+ // ALiBi -+ if (p.max_bias > 0.0f) { -+ const uint h = rowx/p.KY; // head index -+ -+ const float base = h < p.n_head_log2 ? p.m0 : p.m1; -+ const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; -+ -+ slope = pow(base, exp); -+ } -+ -+ // Find max -+ FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); -+ -+ // Cache values while we compute the max, so we don't need to read them -+ // again when we're ready to compute exp(x-max). -+ const uint DATA_CACHE_SIZE = 16; -+ FLOAT_TYPE data_cache[DATA_CACHE_SIZE]; -+ -+ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { -+ const uint col = col0 + tid; -+ -+ FLOAT_TYPE a = FLOAT_TYPE(0); -+ if (col < p.KX) { -+ a = data_a[rowx * p.KX + col]; -+ } -+ -+ FLOAT_TYPE b = FLOAT_TYPE(0); -+ if (p.KY > 0 && col < p.KX) { -+ b = data_b[rowy * p.KX + col]; -+ } -+ -+ FLOAT_TYPE v = a * p.scale + slope * b; -+ -+ if (col < p.KX) { -+ max_val = max(max_val, v); -+ } -+ -+ if (idx < DATA_CACHE_SIZE) { -+ data_cache[idx] = v; -+ } -+ } -+ -+ // reduce across the workgroup -+ vals[tid] = max_val; -+ barrier(); -+ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { -+ if (tid < s) { -+ vals[tid] = max(vals[tid], vals[tid + s]); -+ } -+ barrier(); -+ } -+ -+ max_val = vals[0]; -+ barrier(); -+ -+ FLOAT_TYPE sum = FLOAT_TYPE(0.0f); -+ -+ // Compute sum{exp(x - max)} -+ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { -+ const uint col = col0 + tid; -+ -+ if (col >= p.KX) { -+ break; -+ } -+ -+ // compute exp(a*scale+b*slope), add it to sum, and cache the new value -+ // in data_cache if possible. -+ const uint i = rowx * p.KX + col; -+ FLOAT_TYPE val; -+ if (idx < DATA_CACHE_SIZE) { -+ val = exp(data_cache[idx] - max_val); -+ } else { -+ val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? 
slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); -+ } -+ sum += val; -+ if (idx < DATA_CACHE_SIZE) { -+ data_cache[idx] = val; -+ } else { -+ data_d[i] = D_TYPE(val); -+ } -+ } -+ -+ // reduce across the workgroup -+ vals[tid] = sum; -+ barrier(); -+ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { -+ if (tid < s) { -+ vals[tid] += vals[tid + s]; -+ } -+ barrier(); -+ } -+ sum = vals[0]; -+ -+ FLOAT_TYPE rcpdivisor = 1.0/sum; -+ -+ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { -+ const uint col = col0 + tid; -+ -+ if (col >= p.KX) { -+ continue; -+ } -+ -+ if (idx < DATA_CACHE_SIZE) { -+ data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor); -+ } else { -+ data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); -+ } -+ } -+} -+ -+void main() { -+ // instantiate the soft_max function for several different -+ // dimensions, to allow loop unrolling -+ uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; -+ if (num_blocks > 32) { -+ soft_max(num_blocks); -+ } else if (num_blocks > 16) { -+ soft_max(32); -+ } else if (num_blocks > 8) { -+ soft_max(16); -+ } else if (num_blocks > 4) { -+ soft_max(8); -+ } else if (num_blocks == 4) { -+ soft_max(4); -+ } else if (num_blocks == 3) { -+ soft_max(3); -+ } else if (num_blocks == 2) { -+ soft_max(2); -+ } else if (num_blocks == 1) { -+ soft_max(1); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp -new file mode 100644 -index 00000000..ef43598b ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp -@@ -0,0 +1,17 @@ -+#version 450 -+ -+#include "types.comp" -+#include "generic_unary_head.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+void main() { -+ const uint idx = get_idx(); -+ -+ if (idx >= p.ne) { -+ return; -+ } -+ -+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); -+ data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp -new file mode 100644 -index 00000000..961e5ffa ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp -@@ -0,0 +1,37 @@ -+#version 450 -+ -+#include "generic_head.comp" -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+layout (constant_id = 0) const uint BLOCK_SIZE = 32; -+ -+shared FLOAT_TYPE tmp[BLOCK_SIZE]; -+ -+void main() { -+ const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; -+ const uint col = gl_LocalInvocationID.x; -+ -+ tmp[col] = FLOAT_TYPE(0.0f); -+ -+ for (uint i = col; i < p.KX; i += BLOCK_SIZE) { -+ tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]); -+ } -+ -+ barrier(); -+ [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { -+ if (col < s) { -+ tmp[col] += tmp[col + s]; -+ } -+ barrier(); -+ } -+ -+ if (col == 0) { -+ data_d[row] = D_TYPE(tmp[0]); -+ } -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp -new file mode 100644 -index 00000000..495f966b ---- /dev/null -+++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp -@@ -0,0 +1,20 @@ -+#version 450 -+ -+#include "generic_head.comp" -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+ -+ if (i >= p.KX) { -+ return; -+ } -+ data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.)); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp -new file mode 100644 -index 00000000..28eb24e1 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp -@@ -0,0 +1,7 @@ -+#version 460 -+ -+#extension GL_NV_cooperative_matrix2 : require -+ -+void main() -+{ -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp -new file mode 100644 -index 00000000..79e065a9 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp -@@ -0,0 +1,41 @@ -+#version 450 -+ -+#extension GL_EXT_shader_16bit_storage : require -+ -+layout (push_constant) uniform parameter -+{ -+ uint nb1; -+ uint dim; -+ uint max_period; -+} p; -+ -+#include "types.comp" -+ -+#extension GL_EXT_control_flow_attributes : enable -+#define BLOCK_SIZE 256 -+ -+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const uint i = gl_WorkGroupID.y; -+ const uint j = gl_GlobalInvocationID.x; -+ const uint d_offset = i * p.nb1; -+ -+ if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) { -+ data_d[d_offset + p.dim] = 0.f; -+ } -+ -+ const uint half_dim = p.dim / 2; -+ if (j >= half_dim) { -+ return; -+ } -+ -+ const float timestep = float(data_a[i]); -+ const float freq = float(exp(-log(p.max_period) * j / half_dim)); -+ const float arg = timestep * freq; -+ data_d[d_offset + j] = D_TYPE(cos(arg)); -+ data_d[d_offset + j + half_dim] = D_TYPE(sin(arg)); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp -new file mode 100644 -index 00000000..eecc47f3 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp -@@ -0,0 +1,323 @@ -+ -+#if !defined(GGML_TYPES_COMP) -+#define GGML_TYPES_COMP -+ -+#extension GL_EXT_shader_explicit_arithmetic_types : require -+ -+#if defined(DATA_A_F32) -+#define QUANT_K 1 -+#define QUANT_R 1 -+ -+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 -+#define A_TYPE float -+#elif LOAD_VEC_A == 4 -+#define A_TYPE vec4 -+#elif LOAD_VEC_A == 8 -+#define A_TYPE mat2x4 -+#endif -+#endif -+ -+#if defined(DATA_A_F16) -+#define QUANT_K 1 -+#define QUANT_R 1 -+ -+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 -+#define A_TYPE float16_t -+#elif LOAD_VEC_A == 4 -+#define A_TYPE f16vec4 -+#elif LOAD_VEC_A == 8 -+#define A_TYPE f16mat2x4 -+#endif -+#endif -+ -+#define QUANT_K_Q4_0 32 -+#define QUANT_R_Q4_0 2 -+ -+struct block_q4_0 -+{ -+ float16_t d; -+ uint8_t qs[16]; -+}; -+struct block_q4_0_packed16 -+{ -+ 
float16_t d; -+ uint16_t qs[16/2]; -+}; -+ -+#if defined(DATA_A_Q4_0) -+#define QUANT_K QUANT_K_Q4_0 -+#define QUANT_R QUANT_R_Q4_0 -+#define A_TYPE block_q4_0 -+#define A_TYPE_PACKED16 block_q4_0_packed16 -+#endif -+ -+#define QUANT_K_Q4_1 32 -+#define QUANT_R_Q4_1 2 -+ -+struct block_q4_1 -+{ -+ float16_t d; -+ float16_t m; -+ uint8_t qs[16]; -+}; -+ -+struct block_q4_1_packed16 -+{ -+ float16_t d; -+ float16_t m; -+ uint16_t qs[16/2]; -+}; -+ -+#if defined(DATA_A_Q4_1) -+#define QUANT_K QUANT_K_Q4_1 -+#define QUANT_R QUANT_R_Q4_1 -+#define A_TYPE block_q4_1 -+#define A_TYPE_PACKED16 block_q4_1_packed16 -+#endif -+ -+#define QUANT_K_Q5_0 32 -+#define QUANT_R_Q5_0 2 -+ -+struct block_q5_0 -+{ -+ float16_t d; -+ uint16_t qh[2]; -+ uint8_t qs[16]; -+}; -+ -+struct block_q5_0_packed16 -+{ -+ float16_t d; -+ uint16_t qh[2]; -+ uint16_t qs[16/2]; -+}; -+ -+#if defined(DATA_A_Q5_0) -+#define QUANT_K QUANT_K_Q5_0 -+#define QUANT_R QUANT_R_Q5_0 -+#define A_TYPE block_q5_0 -+#define A_TYPE_PACKED16 block_q5_0_packed16 -+#endif -+ -+#define QUANT_K_Q5_1 32 -+#define QUANT_R_Q5_1 2 -+ -+struct block_q5_1 -+{ -+ float16_t d; -+ float16_t m; -+ uint qh; -+ uint8_t qs[16]; -+}; -+ -+struct block_q5_1_packed16 -+{ -+ float16_t d; -+ float16_t m; -+ uint qh; -+ uint16_t qs[16/2]; -+}; -+ -+#if defined(DATA_A_Q5_1) -+#define QUANT_K QUANT_K_Q5_1 -+#define QUANT_R QUANT_R_Q5_1 -+#define A_TYPE block_q5_1 -+#define A_TYPE_PACKED16 block_q5_1_packed16 -+#endif -+ -+#define QUANT_K_Q8_0 32 -+#define QUANT_R_Q8_0 1 -+ -+struct block_q8_0 -+{ -+ float16_t d; -+ int8_t qs[32]; -+}; -+struct block_q8_0_packed16 -+{ -+ float16_t d; -+ uint16_t qs[32/2]; -+}; -+ -+#if defined(DATA_A_Q8_0) -+#define QUANT_K QUANT_K_Q8_0 -+#define QUANT_R QUANT_R_Q8_0 -+#define A_TYPE block_q8_0 -+#define A_TYPE_PACKED16 block_q8_0_packed16 -+#endif -+ -+// K-quants -+#define QUANT_K_Q2_K 256 -+ -+struct block_q2_K -+{ -+ uint8_t scales[QUANT_K_Q2_K/16]; -+ uint8_t qs[QUANT_K_Q2_K/4]; -+ f16vec2 d; -+}; -+ -+struct block_q2_K_packed16 -+{ -+ uint16_t scales[QUANT_K_Q2_K/16/2]; -+ uint16_t qs[QUANT_K_Q2_K/4/2]; -+ f16vec2 d; -+}; -+ -+struct block_q2_K_packed32 -+{ -+ uint32_t scales[QUANT_K_Q2_K/16/4]; -+ uint32_t qs[QUANT_K_Q2_K/4/4]; -+ f16vec2 d; -+}; -+ -+#if defined(DATA_A_Q2_K) -+#define QUANT_K QUANT_K_Q2_K -+#define A_TYPE block_q2_K -+#define A_TYPE_PACKED16 block_q2_K_packed16 -+#define A_TYPE_PACKED32 block_q2_K_packed32 -+#endif -+ -+#define QUANT_K_Q3_K 256 -+ -+struct block_q3_K -+{ -+ uint8_t hmask[QUANT_K_Q3_K/8]; -+ uint8_t qs[QUANT_K_Q3_K/4]; -+ uint8_t scales[12]; -+ float16_t d; -+}; -+ -+struct block_q3_K_packed16 -+{ -+ uint16_t hmask[QUANT_K_Q3_K/8/2]; -+ uint16_t qs[QUANT_K_Q3_K/4/2]; -+ uint16_t scales[12/2]; -+ float16_t d; -+}; -+ -+#if defined(DATA_A_Q3_K) -+#define QUANT_K QUANT_K_Q3_K -+#define A_TYPE block_q3_K -+#define A_TYPE_PACKED16 block_q3_K_packed16 -+#endif -+ -+#define QUANT_K_Q4_K 256 -+ -+struct block_q4_K -+{ -+ f16vec2 d; -+ uint8_t scales[3*QUANT_K_Q4_K/64]; -+ uint8_t qs[QUANT_K_Q4_K/2]; -+}; -+ -+struct block_q4_K_packed16 -+{ -+ f16vec2 d; -+ uint16_t scales[3*QUANT_K_Q4_K/64/2]; -+ uint16_t qs[QUANT_K_Q4_K/2/2]; -+}; -+ -+struct block_q4_K_packed32 -+{ -+ f16vec2 d; -+ uint32_t scales[3*QUANT_K_Q4_K/64/4]; -+ uint32_t qs[QUANT_K_Q4_K/2/4]; -+}; -+ -+#if defined(DATA_A_Q4_K) -+#define QUANT_K QUANT_K_Q4_K -+#define A_TYPE block_q4_K -+#define A_TYPE_PACKED16 block_q4_K_packed16 -+#define A_TYPE_PACKED32 block_q4_K_packed32 -+#endif -+ -+#define QUANT_K_Q5_K 256 -+ -+struct block_q5_K 
-+{ -+ f16vec2 d; -+ uint8_t scales[12]; -+ uint8_t qh[QUANT_K_Q5_K/8]; -+ uint8_t qs[QUANT_K_Q5_K/2]; -+}; -+ -+struct block_q5_K_packed16 -+{ -+ f16vec2 d; -+ uint16_t scales[12/2]; -+ uint16_t qh[QUANT_K_Q5_K/8/2]; -+ uint16_t qs[QUANT_K_Q5_K/2/2]; -+}; -+ -+#if defined(DATA_A_Q5_K) -+#define QUANT_K QUANT_K_Q5_K -+#define A_TYPE block_q5_K -+#define A_TYPE_PACKED16 block_q5_K_packed16 -+#endif -+ -+#define QUANT_K_Q6_K 256 -+ -+struct block_q6_K -+{ -+ uint8_t ql[QUANT_K_Q6_K/2]; -+ uint8_t qh[QUANT_K_Q6_K/4]; -+ int8_t scales[QUANT_K_Q6_K/16]; -+ float16_t d; -+}; -+ -+struct block_q6_K_packed16 -+{ -+ uint16_t ql[QUANT_K_Q6_K/2/2]; -+ uint16_t qh[QUANT_K_Q6_K/4/2]; -+ int8_t scales[QUANT_K_Q6_K/16]; -+ float16_t d; -+}; -+ -+#if defined(DATA_A_Q6_K) -+#define QUANT_K QUANT_K_Q6_K -+#define A_TYPE block_q6_K -+#define A_TYPE_PACKED16 block_q6_K_packed16 -+#endif -+ -+// IQuants -+ -+#define QUANT_K_IQ4_NL 32 -+#define QUANT_R_IQ4_NL 2 -+ -+struct block_iq4_nl -+{ -+ float16_t d; -+ uint8_t qs[QUANT_K_IQ4_NL/2]; -+}; -+ -+struct block_iq4_nl_packed16 -+{ -+ float16_t d; -+ uint16_t qs[QUANT_K_IQ4_NL/2/2]; -+}; -+ -+#if defined(DATA_A_IQ4_NL) -+ -+const int8_t kvalues_iq4nl_const[16] = { -+ int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), -+ int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) -+}; -+ -+shared FLOAT_TYPE kvalues_iq4nl[16]; -+ -+void init_iq4nl_shmem() -+{ -+ // copy the table into shared memory and sync -+ if (gl_LocalInvocationIndex.x < 16) { -+ kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]); -+ } -+ barrier(); -+} -+ -+#define QUANT_K QUANT_K_IQ4_NL -+#define QUANT_R QUANT_R_IQ4_NL -+#define A_TYPE block_iq4_nl -+#define A_TYPE_PACKED16 block_iq4_nl_packed16 -+#endif -+ -+#endif // !defined(GGML_TYPES_COMP) -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp -new file mode 100644 -index 00000000..6f607380 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp -@@ -0,0 +1,36 @@ -+#version 450 -+ -+layout (push_constant) uniform parameter -+{ -+ uint ne; uint a_offset; uint d_offset; -+ uint nb00; uint nb01; uint nb02; uint nb03; -+ uint ne10; uint ne11; uint ne12; uint ne13; -+ float sf0; float sf1; float sf2; float sf3; -+} p; -+ -+#include "types.comp" -+ -+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; -+ -+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -+ -+void main() { -+ const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -+ -+ if (idx >= p.ne) { -+ return; -+ } -+ -+ const uint i10 = idx % p.ne10; -+ const uint i11 = (idx / p.ne10) % p.ne11; -+ const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; -+ const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; -+ -+ const uint i00 = uint(i10 / p.sf0); -+ const uint i01 = uint(i11 / p.sf1); -+ const uint i02 = uint(i12 / p.sf2); -+ const uint i03 = uint(i13 / p.sf3); -+ -+ data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp -new file mode 100644 -index 
00000000..8111c063 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp -@@ -0,0 +1,594 @@ -+ -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#ifdef _WIN32 -+ #include -+ #include // For _mkdir on Windows -+ #include // For std::replace on w64devkit -+#else -+ #include -+ #include -+ #include -+#endif -+ -+#include -+ -+#define ASYNCIO_CONCURRENCY 64 -+ -+std::mutex lock; -+std::vector> shader_fnames; -+ -+std::string GLSLC = "glslc"; -+std::string input_dir = "vulkan-shaders"; -+std::string output_dir = "/tmp"; -+std::string target_hpp = "ggml-vulkan-shaders.hpp"; -+std::string target_cpp = "ggml-vulkan-shaders.cpp"; -+bool no_clean = false; -+ -+const std::vector type_names = { -+ "f32", -+ "f16", -+ "q4_0", -+ "q4_1", -+ "q5_0", -+ "q5_1", -+ "q8_0", -+ "q2_k", -+ "q3_k", -+ "q4_k", -+ "q5_k", -+ "q6_k", -+ "iq4_nl" -+}; -+ -+namespace { -+void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { -+#ifdef _WIN32 -+ HANDLE stdout_read, stdout_write; -+ HANDLE stderr_read, stderr_write; -+ SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; -+ -+ if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) || -+ !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) { -+ throw std::runtime_error("Failed to create stdout pipe"); -+ } -+ -+ if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) || -+ !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) { -+ throw std::runtime_error("Failed to create stderr pipe"); -+ } -+ -+ PROCESS_INFORMATION pi; -+ STARTUPINFOA si = {}; -+ si.cb = sizeof(STARTUPINFOA); -+ si.dwFlags = STARTF_USESTDHANDLES; -+ si.hStdOutput = stdout_write; -+ si.hStdError = stderr_write; -+ -+ std::vector cmd(command.begin(), command.end()); -+ cmd.push_back('\0'); -+ -+ if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) { -+ throw std::runtime_error("Failed to create process"); -+ } -+ -+ CloseHandle(stdout_write); -+ CloseHandle(stderr_write); -+ -+ std::array buffer; -+ DWORD bytes_read; -+ -+ while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { -+ stdout_str.append(buffer.data(), bytes_read); -+ } -+ -+ while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { -+ stderr_str.append(buffer.data(), bytes_read); -+ } -+ -+ CloseHandle(stdout_read); -+ CloseHandle(stderr_read); -+ WaitForSingleObject(pi.hProcess, INFINITE); -+ CloseHandle(pi.hProcess); -+ CloseHandle(pi.hThread); -+#else -+int stdout_pipe[2]; -+ int stderr_pipe[2]; -+ -+ if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { -+ throw std::runtime_error("Failed to create pipes"); -+ } -+ -+ pid_t pid = fork(); -+ if (pid < 0) { -+ throw std::runtime_error("Failed to fork process"); -+ } -+ -+ if (pid == 0) { -+ close(stdout_pipe[0]); -+ close(stderr_pipe[0]); -+ dup2(stdout_pipe[1], STDOUT_FILENO); -+ dup2(stderr_pipe[1], STDERR_FILENO); -+ close(stdout_pipe[1]); -+ close(stderr_pipe[1]); -+ execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr); -+ _exit(EXIT_FAILURE); -+ } else { -+ close(stdout_pipe[1]); -+ close(stderr_pipe[1]); -+ -+ std::array buffer; -+ ssize_t bytes_read; -+ -+ while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) { -+ 
stdout_str.append(buffer.data(), bytes_read); -+ } -+ -+ while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) { -+ stderr_str.append(buffer.data(), bytes_read); -+ } -+ -+ close(stdout_pipe[0]); -+ close(stderr_pipe[0]); -+ waitpid(pid, nullptr, 0); -+ } -+#endif -+} -+ -+bool directory_exists(const std::string& path) { -+ struct stat info; -+ if (stat(path.c_str(), &info) != 0) { -+ return false; // Path doesn't exist or can't be accessed -+ } -+ return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory -+} -+ -+bool create_directory(const std::string& path) { -+#ifdef _WIN32 -+ return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists -+#else -+ return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions -+#endif -+} -+ -+std::string to_uppercase(const std::string& input) { -+ std::string result = input; -+ for (char& c : result) { -+ c = std::toupper(c); -+ } -+ return result; -+} -+ -+bool string_ends_with(const std::string& str, const std::string& suffix) { -+ if (suffix.size() > str.size()) { -+ return false; -+ } -+ return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); -+} -+ -+static const char path_separator = '/'; -+ -+std::string join_paths(const std::string& path1, const std::string& path2) { -+ return path1 + path_separator + path2; -+} -+ -+std::string basename(const std::string &path) { -+ return path.substr(path.find_last_of("/\\") + 1); -+} -+ -+// variables to track number of compiles in progress -+static uint32_t compile_count = 0; -+static std::mutex compile_count_mutex; -+static std::condition_variable compile_count_cond; -+ -+void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { -+ std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); -+ std::string out_fname = join_paths(output_dir, name + ".spv"); -+ std::string in_path = join_paths(input_dir, in_fname); -+ -+ std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; -+ -+ // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 -+ std::string opt_level = coopmat ? 
"" : "-O"; -+ -+ #ifdef _WIN32 -+ std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; -+ #else -+ std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname}; -+ #endif -+ -+ #ifdef GGML_VULKAN_SHADER_DEBUG_INFO -+ cmd.push_back("-g"); -+ #endif -+ -+ for (const auto& define : defines) { -+ cmd.push_back("-D" + define.first + "=" + define.second); -+ } -+ -+ std::string command; -+ for (const auto& part : cmd) { -+ command += part + " "; -+ } -+ -+ std::string stdout_str, stderr_str; -+ try { -+ // std::cout << "Executing command: "; -+ // for (const auto& part : cmd) { -+ // std::cout << part << " "; -+ // } -+ // std::cout << std::endl; -+ -+ execute_command(command, stdout_str, stderr_str); -+ if (!stderr_str.empty()) { -+ std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl; -+ return; -+ } -+ -+ std::lock_guard guard(lock); -+ shader_fnames.push_back(std::make_pair(name, out_fname)); -+ } catch (const std::exception& e) { -+ std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; -+ } -+ { -+ std::lock_guard guard(compile_count_mutex); -+ assert(compile_count > 0); -+ compile_count--; -+ } -+ compile_count_cond.notify_all(); -+} -+ -+std::map merge_maps(const std::map& a, const std::map& b) { -+ std::map result = a; -+ result.insert(b.begin(), b.end()); -+ return result; -+} -+ -+static std::vector> compiles; -+void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { -+ { -+ // wait until fewer than N compiles are in progress. -+ // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. -+ uint32_t N = 16; -+ std::unique_lock guard(compile_count_mutex); -+ while (compile_count >= N) { -+ compile_count_cond.wait(guard); -+ } -+ compile_count++; -+ } -+ compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); -+} -+ -+void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) { -+ std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; -+ std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; -+ std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; -+ -+ std::map base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}}; -+ std::string shader_name = "matmul"; -+ -+ if (matmul_id) { -+ base_dict["MUL_MAT_ID"] = "1"; -+ shader_name = "matmul_id"; -+ } -+ -+ if (fp16) { -+ base_dict["FLOAT16"] = "1"; -+ } -+ -+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; -+ -+ if (coopmat) { -+ base_dict["COOPMAT"] = "1"; -+ } -+ -+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; -+ -+ std::string source_name = coopmat2 ? 
"mul_mm_cm2.comp" : "mul_mm.comp"; -+ -+ // Shaders with f16 B_TYPE -+ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); -+ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); -+ -+ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); -+ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); -+ -+ for (const auto& tname : type_names) { -+ std::string data_a_key = "DATA_A_" + to_uppercase(tname); -+ // For unaligned, load one at a time for f32/f16, or two at a time for quants -+ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2"; -+ // For aligned matmul loads -+ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; -+ -+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); -+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); -+ -+ if (tname != "f16" && tname != "f32") { -+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); -+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); -+ } -+ } -+} -+ -+void process_shaders() { -+ std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; -+ std::map base_dict = {{"FLOAT_TYPE", "float"}}; -+ -+ // matmul -+ for (const auto& matmul_id : {false, true}) { -+ // No coopmats -+ // fp32 -+ matmul_shaders(false, matmul_id, false, false, false); -+ -+ // fp16, fp32acc and fp16acc -+ matmul_shaders(true, matmul_id, false, false, false); -+ matmul_shaders(true, matmul_id, false, false, true); -+ -+ // Coopmat, fp32acc and fp16acc -+ matmul_shaders(true, matmul_id, true, false, false); -+ matmul_shaders(true, matmul_id, true, false, true); -+ -+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -+ // Coopmat2, fp32acc and fp16acc -+ matmul_shaders(true, matmul_id, false, true, false); -+ matmul_shaders(true, matmul_id, false, true, true); -+#endif -+ } -+ -+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -+ // flash attention -+ for (const auto& f16acc : {false, true}) { -+ std::string acctype = f16acc ? 
"float16_t" : "float"; -+ -+ for (const auto& tname : type_names) { -+ if (tname == "f32") { -+ continue; -+ } -+ -+ if (tname == "f16") { -+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", -+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); -+ } else { -+ std::string data_a_key = "DATA_A_" + to_uppercase(tname); -+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", -+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); -+ } -+ } -+ } -+#endif -+ -+ for (const auto& tname : type_names) { -+ // mul mat vec -+ std::string data_a_key = "DATA_A_" + to_uppercase(tname); -+ std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; -+ -+ string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); -+ string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); -+ -+ string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); -+ -+ // Dequant shaders -+ if (tname != "f16") { -+ string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); -+ } -+ -+ if (!string_ends_with(tname, "_k")) { -+ shader = (tname == "f32" || tname == "f16") ? 
"get_rows.comp" : "get_rows_quant.comp"; -+ -+ if (tname == "f16") { -+ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); -+ } else { -+ string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); -+ } -+ string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); -+ } -+ } -+ -+ string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); -+ -+ // Norms -+ string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); -+ string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); -+ string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); -+ -+ string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); -+ string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); -+ string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); -+ string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); -+ -+ string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); -+ string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); -+ -+ string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); -+ -+ string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); -+ -+ string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); -+ -+ string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); -+ -+ string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ -+ string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); -+ -+ string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); -+ -+ string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); -+ -+ string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); -+ -+ string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); -+ -+ string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ -+ string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("concat_f16", "concat.comp", 
{{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); -+ string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); -+ -+ string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); -+ -+ string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ -+ string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ -+ string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); -+ string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); -+ -+ string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); -+ string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); -+ -+ string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); -+ string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); -+ string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); -+ -+ string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); -+ -+ string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); -+ -+ string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); -+ string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); -+ string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); -+ -+ string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); -+ -+ string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); -+ -+ string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); -+ -+ for (auto &c : compiles) { -+ c.wait(); -+ } -+} -+ -+void write_output_files() { -+ FILE* hdr = fopen(target_hpp.c_str(), "w"); -+ FILE* src = fopen(target_cpp.c_str(), "w"); -+ -+ fprintf(hdr, "#include \n\n"); -+ fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); -+ -+ for (const auto& pair : shader_fnames) { -+ const std::string& name = pair.first; -+ #ifdef _WIN32 -+ std::string path = pair.second; -+ std::replace(path.begin(), path.end(), '/', '\\' ); -+ #else -+ const std::string& path = pair.second; -+ #endif -+ -+ FILE* spv = fopen(path.c_str(), "rb"); -+ if (!spv) { -+ std::cerr << "Error opening SPIR-V file: " << path << " (" << 
strerror(errno) << ")\n"; -+ continue; -+ } -+ -+ fseek(spv, 0, SEEK_END); -+ size_t size = ftell(spv); -+ fseek(spv, 0, SEEK_SET); -+ -+ std::vector data(size); -+ size_t read_size = fread(data.data(), 1, size, spv); -+ fclose(spv); -+ if (read_size != size) { -+ std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; -+ continue; -+ } -+ -+ fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size); -+ fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size); -+ -+ fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size); -+ for (size_t i = 0; i < size; ++i) { -+ fprintf(src, "0x%02x,", data[i]); -+ if ((i + 1) % 12 == 0) fprintf(src, "\n"); -+ } -+ fprintf(src, "\n};\n\n"); -+ -+ if (!no_clean) { -+ std::remove(path.c_str()); -+ } -+ } -+ -+ fclose(hdr); -+ fclose(src); -+} -+} -+ -+int main(int argc, char** argv) { -+ std::map args; -+ for (int i = 1; i < argc; ++i) { -+ std::string arg = argv[i]; -+ if (arg.rfind("--", 0) == 0) { -+ if (i + 1 < argc && argv[i + 1][0] != '-') { -+ args[arg] = argv[i + 1]; -+ ++i; -+ } else { -+ args[arg] = ""; -+ } -+ } -+ } -+ -+ if (args.find("--glslc") != args.end()) { -+ GLSLC = args["--glslc"]; // Path to glslc -+ } -+ if (args.find("--input-dir") != args.end()) { -+ input_dir = args["--input-dir"]; // Directory containing shader sources -+ } -+ if (args.find("--output-dir") != args.end()) { -+ output_dir = args["--output-dir"]; // Directory for containing SPIR-V output -+ } -+ if (args.find("--target-hpp") != args.end()) { -+ target_hpp = args["--target-hpp"]; // Path to generated header file -+ } -+ if (args.find("--target-cpp") != args.end()) { -+ target_cpp = args["--target-cpp"]; // Path to generated cpp file -+ } -+ if (args.find("--no-clean") != args.end()) { -+ no_clean = true; // Keep temporary SPIR-V files in output-dir after build -+ } -+ -+ if (!directory_exists(input_dir)) { -+ std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl; -+ return EXIT_FAILURE; -+ } -+ -+ if (!directory_exists(output_dir)) { -+ if (!create_directory(output_dir)) { -+ std::cerr << "Error creating output directory: " << output_dir << "\n"; -+ return EXIT_FAILURE; -+ } -+ } -+ -+ process_shaders(); -+ -+ write_output_files(); -+ -+ return EXIT_SUCCESS; -+} -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp -new file mode 100644 -index 00000000..35cc6c45 ---- /dev/null -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp -@@ -0,0 +1,87 @@ -+#version 450 -+ -+#extension GL_EXT_control_flow_attributes : require -+ -+#define BLOCK_SIZE 64 -+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -+ -+layout(push_constant) uniform Parameters { -+ uint B; -+ uint T; -+ uint C; -+ uint H; -+}; -+ -+layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; }; -+layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; }; -+layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; }; -+layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; }; -+layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; }; -+layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; }; -+layout(binding = 6) buffer DstBuf { A_TYPE dst[]; }; -+ -+shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE]; -+ -+void main() { -+ const uint head_size = BLOCK_SIZE; -+ const uint batch_id = gl_WorkGroupID.x / H; -+ const uint 
head_id = gl_WorkGroupID.x % H; -+ const uint tid = gl_LocalInvocationID.x; -+ -+ const uint state_size = C * head_size; -+ const uint n_seq_tokens = T / B; -+ -+ if (batch_id >= B || head_id >= H) { -+ return; -+ } -+ -+ A_TYPE state[BLOCK_SIZE]; -+ [[unroll]] for (uint i = 0; i < head_size; i++) { -+ state[i] = state_in[batch_id * state_size + head_id * head_size * head_size -+ + i * head_size + tid]; -+ } -+ -+ barrier(); -+ _tf[tid] = tf[head_id * head_size + tid]; -+ barrier(); -+ -+ const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; -+ const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; -+ -+ for (uint t = start_t; t < end_t; t += C) { -+ barrier(); -+ _k[tid] = k[t]; -+ _r[tid] = r[t]; -+ _td[tid] = td[t]; -+ barrier(); -+ -+ const A_TYPE v_val = v[t]; -+ A_TYPE y = 0.0; -+ -+ [[unroll]] for (uint j = 0; j < head_size; j += 4) { -+ vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); -+ vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); -+ vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); -+ vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); -+ vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); -+ -+ vec4 kv = k_vec * v_val; -+ -+ vec4 temp = tf_vec * kv + s_vec; -+ y += dot(r_vec, temp); -+ -+ s_vec = s_vec * td_vec + kv; -+ state[j] = s_vec.x; -+ state[j+1] = s_vec.y; -+ state[j+2] = s_vec.z; -+ state[j+3] = s_vec.w; -+ } -+ -+ dst[t] = y; -+ } -+ -+ [[unroll]] for (uint i = 0; i < head_size; i++) { -+ dst[T * C + batch_id * state_size + head_id * head_size * head_size -+ + i * head_size + tid] = state[i]; -+ } -+} --- -2.43.0 - From d1939aa1c64ed4816cba2239f5678b721ab3c310 Mon Sep 17 00:00:00 2001 From: Vadim Grinco Date: Sat, 15 Mar 2025 20:28:57 +0100 Subject: [PATCH 030/172] Fixes SIGSEGV: segmentation violation running gemma3 models on ollama 0.6.0 #21 Patch provided by McBane87 on https://github.com/whyvl/ollama-vulkan/issues/21 Signed-off-by: Vadim Grinco --- ml/backend/ggml/ggml.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 03b9acb32..2237e7f51 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -20,6 +20,7 @@ import ( "strings" "unicode" "unsafe" + "sync" "github.com/ollama/ollama/format" fs "github.com/ollama/ollama/fs/ggml" @@ -299,6 +300,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { // concurrently read in tensor data. 
uses a section reader which is safe for concurrent reads
 sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
+ var tensorSetMutex sync.Mutex
 var g errgroup.Group
 for _, t := range meta.Tensors().Items() {
 for _, target := range targets[t.Name] {
@@ -322,7 +324,9 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 return errors.New("short read")
 }
+ tensorSetMutex.Lock()
 C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
+ tensorSetMutex.Unlock()
 return nil
 })
 }

From c2e440879aea4b3000cd250165d536c0c56f9936 Mon Sep 17 00:00:00 2001
From: Vadim Grinco
Date: Sun, 16 Mar 2025 10:52:49 +0100
Subject: [PATCH 031/172] Applied 04-disable-mmap-vulkan.patch

From: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871

Signed-off-by: Vadim Grinco
---
 llm/server.go | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llm/server.go b/llm/server.go
index c6f117125..537cc1e1a 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -207,10 +207,12 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 }
 // Windows CUDA should not use mmap for best performance
+ // Vulkan should not use mmap because of double allocation (VRAM + RAM)
 // Linux with a model larger than free space, mmap leads to thrashing
 // For CPU loads we want the memory to be allocated, not FS cache
 if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) ||
 (runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) ||
+ (gpus[0].Library == "vulkan" && opts.UseMMap == nil) ||
 (gpus[0].Library == "cpu" && opts.UseMMap == nil) ||
 (opts.UseMMap != nil && !*opts.UseMMap) {
 params = append(params, "--no-mmap")

From 640f0bb250f1bf9f53451166f1f04a4ebaecb8d3 Mon Sep 17 00:00:00 2001
From: Vadim Grinco
Date: Sun, 16 Mar 2025 12:22:22 +0100
Subject: [PATCH 032/172] Pulled new upstream code for ggml-vulkan backend

Signed-off-by: Vadim Grinco
---
 CMakePresets.json | 4 +-
 .../ggml/ggml/src/ggml-vulkan/CMakeLists.txt | 78 +-
 .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 1844 +++++++++++------
 .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 6 +-
 .../ggml-vulkan/vulkan-shaders/argmax.comp | 51 +
 .../vulkan-shaders/copy_from_quant.comp | 51 +
 .../vulkan-shaders/copy_to_quant.comp | 237 +++
 .../vulkan-shaders/count_equal.comp | 31 +
 .../vulkan-shaders/dequant_funcs.comp | 340 ++-
 .../vulkan-shaders/dequant_funcs_cm2.comp | 336 ++-
 .../vulkan-shaders/dequant_iq1_m.comp | 42 +
 .../vulkan-shaders/dequant_iq1_s.comp | 35 +
 .../vulkan-shaders/dequant_iq2_s.comp | 44 +
 .../vulkan-shaders/dequant_iq2_xs.comp | 43 +
 .../vulkan-shaders/dequant_iq2_xxs.comp | 48 +
 .../vulkan-shaders/dequant_iq3_s.comp | 39 +
 .../vulkan-shaders/dequant_iq3_xxs.comp | 49 +
 .../vulkan-shaders/dequant_iq4_nl.comp | 2 +-
 .../vulkan-shaders/dequant_iq4_xs.comp | 34 +
 .../vulkan-shaders/diag_mask_inf.comp | 2 +-
 .../vulkan-shaders/flash_attn_cm2.comp | 26 +-
 .../vulkan-shaders/generic_unary_head.comp | 20 +
 .../vulkan-shaders/get_rows_quant.comp | 4 +-
 .../vulkan-shaders/mul_mat_vec.comp | 13 +-
 .../vulkan-shaders/mul_mat_vec_iq1_m.comp | 82 +
 .../vulkan-shaders/mul_mat_vec_iq1_s.comp | 79 +
 .../vulkan-shaders/mul_mat_vec_q2_k.comp | 156 +-
 .../vulkan-shaders/mul_mat_vec_q3_k.comp | 143 +-
 .../vulkan-shaders/mul_mat_vec_q4_k.comp | 171 +-
 .../vulkan-shaders/mul_mat_vec_q5_k.comp | 231 ++-
 .../vulkan-shaders/mul_mat_vec_q6_k.comp | 146 +-
 .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 194 +-
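
The SIGSEGV fix in PATCH 030 above comes down to one pattern: the loader goroutines may read their tensor sections concurrently, but the call the patch treats as unsafe to make concurrently (C.ggml_backend_tensor_set) is serialized behind a mutex. The following is a minimal, self-contained Go sketch of that pattern only, not part of any patch in this series; the names storeMu and total are invented stand-ins for tensorSetMutex and the backend tensor store, and it uses the same errgroup dependency as ggml.go.

package main

import (
	"fmt"
	"sync"

	"golang.org/x/sync/errgroup"
)

func main() {
	var (
		storeMu sync.Mutex // plays the role of tensorSetMutex in the patch
		total   int        // stands in for the backend tensor being written
		g       errgroup.Group
	)

	for i := 1; i <= 8; i++ {
		i := i
		g.Go(func() error {
			data := i * i // concurrent part: e.g. reading one tensor's bytes

			storeMu.Lock() // serialized part: the call that must not overlap
			total += data
			storeMu.Unlock()
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		fmt.Println("load failed:", err)
		return
	}
	fmt.Println("loaded, checksum:", total)
}

Only the store step is serialized, so the section reads still overlap; that keeps most of the load parallelism while removing the concurrent access the patch blames for the gemma3 crash.
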
.../vulkan-shaders/mul_mm_cm2.comp | 39 +- .../vulkan-shaders/opt_step_adamw.comp | 42 + .../vulkan-shaders/repeat_back.comp | 37 + .../vulkan-shaders/rms_norm_back.comp | 55 + .../ggml-vulkan/vulkan-shaders/rope_head.comp | 9 + .../vulkan-shaders/rope_multi.comp | 60 + .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 34 +- .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 34 +- .../vulkan-shaders/rope_vision.comp | 47 + .../ggml-vulkan/vulkan-shaders/sigmoid.comp | 20 + .../ggml-vulkan/vulkan-shaders/silu_back.comp | 26 + .../ggml-vulkan/vulkan-shaders/soft_max.comp | 1 - .../vulkan-shaders/soft_max_back.comp | 50 + .../src/ggml-vulkan/vulkan-shaders/sub.comp | 29 + .../vulkan-shaders/test_coopmat_support.comp | 7 + .../src/ggml-vulkan/vulkan-shaders/types.comp | 968 ++++++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 55 +- 49 files changed, 4970 insertions(+), 1124 deletions(-) create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp diff --git a/CMakePresets.json b/CMakePresets.json index 09e924011..6181eb732 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -62,7 +62,7 @@ { "name": "Vulkan", "inherits": [ "Default" ] - } + } ], "buildPresets": [ { @@ -114,6 +114,6 @@ "name": "Vulkan", "targets": [ "ggml-vulkan" ], "configurePreset": "Vulkan" - } + } ] } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt index 9501de736..d970f7e20 100644 --- 
a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt @@ -1,5 +1,20 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + find_package(Vulkan COMPONENTS glslc REQUIRED) +function(detect_host_compiler) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + else() + find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + endif() + set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE) + set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) +endfunction() + if (Vulkan_FOUND) message(STATUS "Vulkan found") @@ -8,6 +23,20 @@ if (Vulkan_FOUND) ../../include/ggml-vulkan.h ) + # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. + # If it's not, there will be an error to stderr. + # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") + message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") + else() + message(STATUS "GL_KHR_cooperative_matrix supported by glslc") + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + endif() + # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. # If it's not, there will be an error to stderr. # If it's supported, set a define to indicate that we should compile those shaders @@ -59,21 +88,62 @@ if (Vulkan_FOUND) add_compile_definitions(GGML_VULKAN_RUN_TESTS) endif() - add_subdirectory(vulkan-shaders) + if (NOT CMAKE_CROSSCOMPILING) + add_subdirectory(vulkan-shaders) + if (MSVC) + foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) + string(TOUPPER ${CONFIG} CONFIG) + set_target_properties(vulkan-shaders-gen PROPERTIES + RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() + endif() + else() + if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) + set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) + else() + detect_host_compiler() + if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER) + message(FATAL_ERROR "Host compiler not found") + else() + message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}") + endif() + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) + set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) + endif() + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") - set (_ggml_vk_genshaders_cmd vulkan-shaders-gen) + include(ExternalProject) + # Native build through ExternalProject_Add + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} + -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + BUILD_COMMAND ${CMAKE_COMMAND} --build . + INSTALL_COMMAND ${CMAKE_COMMAND} --install . 
+ INSTALL_DIR ${CMAKE_BINARY_DIR} + ) + ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + endif() + set (_ggml_vk_host_suffix $,.exe,>) + set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix}) set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") + set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen) + + if (CMAKE_CROSSCOMPILING) + set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) + endif() add_custom_command( OUTPUT ${_ggml_vk_header} ${_ggml_vk_source} - COMMAND "$/${_ggml_vk_genshaders_cmd}" + COMMAND ${_ggml_vk_genshaders_cmd} --glslc ${Vulkan_GLSLC_EXECUTABLE} --input-dir ${_ggml_vk_input_dir} --output-dir ${_ggml_vk_output_dir} @@ -81,7 +151,7 @@ if (Vulkan_FOUND) --target-cpp ${_ggml_vk_source} --no-clean - DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd} + DEPENDS ${_ggml_vk_shader_deps} COMMENT "Generate vulkan shaders" ) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d75cd6d61..abe3e7908 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -29,8 +29,6 @@ #include "ggml-vulkan-shaders.hpp" -#define VK_API_VERSION VK_API_VERSION_1_2 - #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) #define VK_VENDOR_ID_AMD 0x1002 @@ -87,6 +85,10 @@ struct vk_pipeline_struct { uint32_t parameter_count; std::array wg_denoms; uint32_t align; + // set to true to request the pipeline is compiled after the dryrun + bool needed {}; + // set to true when the shader has been compiled + bool compiled {}; }; typedef std::shared_ptr vk_pipeline; @@ -154,6 +156,7 @@ struct vk_device_struct { vk::PhysicalDeviceProperties properties; std::string name; uint64_t max_memory_allocation_size; + uint64_t suballocation_block_size; bool fp16; bool pipeline_robustness; vk::Device device; @@ -164,6 +167,7 @@ struct vk_device_struct { uint32_t subgroup_size; uint32_t shader_core_count; bool uma; + bool prefer_host_memory; bool float_controls_rte_fp16; bool subgroup_size_control; @@ -181,15 +185,18 @@ struct vk_device_struct { size_t idx; - bool mul_mat_l; - bool mul_mat_m; - bool mul_mat_s; - bool mul_mat_id_l; - bool mul_mat_id_m; - bool mul_mat_id_s; + bool mul_mat_l[GGML_TYPE_COUNT]; + bool mul_mat_m[GGML_TYPE_COUNT]; + bool mul_mat_s[GGML_TYPE_COUNT]; + bool mul_mat_id_l[GGML_TYPE_COUNT]; + bool mul_mat_id_m[GGML_TYPE_COUNT]; + bool mul_mat_id_s[GGML_TYPE_COUNT]; - vk_matmul_pipeline pipeline_matmul_f32; - vk_matmul_pipeline pipeline_matmul_f32_f16; + // set to true to indicate that some shaders need to be compiled after the dryrun + bool need_compiles {}; + + vk_matmul_pipeline pipeline_matmul_f32 {}; + vk_matmul_pipeline pipeline_matmul_f32_f16 {}; vk_matmul_pipeline2 pipeline_matmul_f16; vk_matmul_pipeline2 pipeline_matmul_f16_f32; vk_pipeline pipeline_matmul_split_k_reduce; @@ -197,7 +204,7 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; - vk_matmul_pipeline pipeline_matmul_id_f32; + vk_matmul_pipeline 
pipeline_matmul_id_f32 {}; vk_matmul_pipeline2 pipeline_matmul_id_f16; vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; @@ -215,6 +222,7 @@ struct vk_device_struct { vk_pipeline pipeline_acc_f32; vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; + vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat; vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; @@ -225,29 +233,40 @@ struct vk_device_struct { vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; - vk_pipeline pipeline_repeat_f32; + vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; + vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_gelu_f32; vk_pipeline pipeline_gelu_quick_f32; vk_pipeline pipeline_silu_f32; + vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_relu_f32; vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_tanh_f32; + vk_pipeline pipeline_sigmoid_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; + vk_pipeline pipeline_soft_max_back_f32; vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; + vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; + vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_argmax_f32; + vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; + vk_pipeline pipeline_opt_step_adamw_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; @@ -384,10 +403,13 @@ struct vk_flash_attn_push_constants { uint32_t nev3; uint32_t nem1; + uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t nb21; uint32_t nb22; uint32_t nb23; uint32_t nb31; @@ -482,6 +504,11 @@ struct vk_op_rope_push_constants { float corr_dims[2]; float theta_scale; uint32_t has_ff; + uint32_t ne02; + uint32_t s1; + uint32_t s2; + int32_t sections[4]; + uint32_t is_back; }; struct vk_op_soft_max_push_constants { @@ -764,22 +791,15 @@ static uint32_t compile_count = 0; static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; -static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, - uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector specialization_constants, - uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { - 
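
The new needed and compiled flags on vk_pipeline_struct, together with device->need_compiles, turn shader setup into a two-phase scheme: the dryrun only records which pipelines will actually be used, and a later pass compiles exactly those, skipping anything already built. A rough Go sketch of that scheme follows, as an illustration only, with every name invented; the real code does this in C++ around ggml_vk_create_pipeline_func and the compile_count counter.

package main

import (
	"fmt"
	"runtime"
	"sync"
)

type pipeline struct {
	name     string
	needed   bool
	compiled bool
}

type device struct {
	pipelines    []*pipeline
	needCompiles bool
}

// request mirrors ggml_pipeline_request_descriptor_sets: record the need, defer the work.
func (d *device) request(p *pipeline) {
	if !p.compiled {
		p.needed = true
		d.needCompiles = true
	}
}

// compileNeeded mirrors the post-dryrun compile pass.
func (d *device) compileNeeded() {
	sem := make(chan struct{}, runtime.NumCPU()) // cap how many pipeline builds run at once
	var wg sync.WaitGroup
	for _, p := range d.pipelines {
		if !p.needed || p.compiled {
			continue // same early-out as the pipeline-creation lambda in the diff
		}
		wg.Add(1)
		sem <- struct{}{}
		go func(p *pipeline) {
			defer wg.Done()
			defer func() { <-sem }()
			p.compiled = true // stand-in for the actual vkCreateComputePipeline call
		}(p)
	}
	wg.Wait()
	d.needCompiles = false
}

func main() {
	d := &device{pipelines: []*pipeline{{name: "matmul_q4_0_f32"}, {name: "soft_max_f32"}}}
	d.request(d.pipelines[0]) // the dry run touches only one pipeline
	d.compileNeeded()
	fmt.Println(d.pipelines[0].compiled, d.pipelines[1].compiled) // true false
}

The buffered channel bounds concurrent compiles, which is the role compile_count and its condition variable play in the diff.
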
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << - ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << - ", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, + uint32_t parameter_count, std::array wg_denoms, std::vector specialization_constants, + bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { + VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count << + ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << + disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); GGML_ASSERT(parameter_count > 0); GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT - pipeline = std::make_shared(); - pipeline->name = name; - pipeline->parameter_count = parameter_count; - pipeline->push_constant_size = push_constant_size; - pipeline->wg_denoms = wg_denoms; - pipeline->align = align; - vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); @@ -861,7 +881,14 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin compute_pipeline_create_info.setPNext(&rci); } - pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + try { + pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } + pipeline->compiled = true; { std::lock_guard guard(device->mutex); @@ -872,12 +899,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::lock_guard guard(compile_count_mutex); assert(compile_count > 0); compile_count--; - - // "Progress bar" for shader compiles - static uint32_t total_compile_count = 0; - if ((total_compile_count++ % 10) == 0) { - std::cerr << "."; - } } compile_count_cond.notify_all(); } @@ -903,6 +924,10 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); device->pipeline_descriptor_set_requirements[pipeline->name] += n; + if (!pipeline->compiled) { + pipeline->needed = true; + device->need_compiles = true; + } } static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { @@ -1285,7 +1310,9 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk: static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { vk_buffer buf; try { - if (device->uma) { + if (device->prefer_host_memory) { + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, 
vk::MemoryPropertyFlagBits::eDeviceLocal); + } else if (device->uma) { // Fall back to host memory type buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); } else { @@ -1369,7 +1396,37 @@ static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ return {64, 64}; }; -static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id) { +static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { + + uint32_t lut_size = 0; + switch (src0_type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + lut_size = 2*2048; + break; + case GGML_TYPE_IQ2_XXS: + lut_size = 8*256; + break; + case GGML_TYPE_IQ2_XS: + lut_size = 8*512; + break; + case GGML_TYPE_IQ2_S: + lut_size = 8*1024; + break; + case GGML_TYPE_IQ3_XXS: + lut_size = 4*256; + break; + case GGML_TYPE_IQ3_S: + lut_size = 4*512; + break; + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + lut_size = 4*16; + break; + default: + break; + } + // Needs to be kept up to date on shader changes const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); @@ -1379,15 +1436,20 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; - return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize; + const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported); + + return supported; } static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); - std::cerr << "ggml_vulkan: Compiling shaders"; - // some shaders have a minimum subgroup size + const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); @@ -1450,13 +1512,13 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; const uint32_t tk_s = device->coopmat_support ? 
device->coopmat_k : 1; - l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; - m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; - s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; - l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; - m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; - s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; + l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; @@ -1465,74 +1527,62 @@ static void ggml_vk_load_shaders(vk_device& device) { m_align = 64; s_align = 32; - // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders - // and tile sizes, this should handle 16KB, 32KB, and 48KB+. - // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders. - // But the numbers happen to work out for 32KB shared memory size that when using the medium - // size there's enough room for everything, and we assert for this. - uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); - if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { - l_warptile = m_warptile; - l_wg_denoms = m_wg_denoms; - shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); - GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); - } - if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { - // assert mul_mat_mat_id shaders will fit. - GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); - } - - shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); - if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { - if (device->properties.limits.maxComputeSharedMemorySize == 32768) { - l_warptile_mmq = m_warptile_mmq; - l_mmq_wg_denoms = m_mmq_wg_denoms; - } else { - l_warptile_mmq = s_warptile_mmq; - l_mmq_wg_denoms = s_mmq_wg_denoms; + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { + ggml_type t = (ggml_type)i; + // Disable medium and large matrix multiplication if not enough shared memory is available + // Check mmq warptiles as the largest configuration + // Throw an error if not enough for any matrix multiplication is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) { + std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." 
<< std::endl; + throw std::runtime_error("Shared memory size too small for matrix multiplication."); + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) { + device->mul_mat_m[i] = false; + device->mul_mat_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) { + device->mul_mat_l[i] = false; } - shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); - GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); - } - if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { - // assert mul_mat_mat_id shaders will fit. - GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); - } - // Disable medium and large matrix multiplication if not enough shared memory is available - // Check mmq warptiles as the largest configuration - // Throw an error if not enough for any matrix multiplication is available - if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) { - std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl; - throw std::runtime_error("Shared memory size too small for matrix multiplication."); - } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) { - device->mul_mat_m = false; - device->mul_mat_l = false; - } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) { - device->mul_mat_l = false; - } - // Disable mul_mat_id if not enough shared memory is available - if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) { - device->mul_mat_id_s = false; - device->mul_mat_id_m = false; - device->mul_mat_id_l = false; - } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) { - device->mul_mat_id_m = false; - device->mul_mat_id_l = false; - } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) { - device->mul_mat_id_l = false; + // Disable mul_mat_id if not enough shared memory is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) { + device->mul_mat_id_s[i] = false; + device->mul_mat_id_m[i] = false; + device->mul_mat_id_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) { + device->mul_mat_id_m[i] = false; + device->mul_mat_id_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) { + device->mul_mat_id_l[i] = false; + } } } - device->pipeline_matmul_f32 = std::make_shared(); - device->pipeline_matmul_f32_f16 = std::make_shared(); - - device->pipeline_matmul_id_f32 = std::make_shared(); + if (!device->pipeline_matmul_f32) { + device->pipeline_matmul_f32 = std::make_shared(); + } + if (!device->pipeline_matmul_f32_f16) { + device->pipeline_matmul_f32_f16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_f32) { + device->pipeline_matmul_id_f32 = std::make_shared(); + } std::vector> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + + if (!pipeline) { + pipeline = std::make_shared(); + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = 
push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + } + + if (!pipeline->needed || pipeline->compiled) { + return; + } { // wait until fewer than N compiles are in progress uint32_t N = std::max(1u, std::thread::hardware_concurrency()); @@ -1542,8 +1592,8 @@ static void ggml_vk_load_shaders(vk_device& device) { } compile_count++; } - compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, - parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size)); + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, + parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -1592,6 +1642,14 @@ static void ggml_vk_load_shaders(vk_device& device) { //CREATE_FA(GGML_TYPE_Q4_K, q4_k) //CREATE_FA(GGML_TYPE_Q5_K, q5_k) //CREATE_FA(GGML_TYPE_Q6_K, q6_k) + //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s) + //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m) + //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs) + //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs) + //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s) + //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs) + //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s) + //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs) CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) #undef CREATE_FA @@ -1609,11 +1667,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ - CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -1624,234 +1678,305 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + 
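
The reworked ggml_vk_matmul_shmem_support a few hunks above budgets shared memory per quantization type: the i-quants need an extra in-shader lookup table (lut_size) on top of the tile buffers, row ids and coopmat staging, and a tile configuration is only kept for a type when the total fits in maxComputeSharedMemorySize. That is also why mul_mat_l/m/s and mul_mat_id_l/m/s became per-type arrays, which the CREATE_MM guards around here test. Below is a small Go sketch of that arithmetic, reusing the LUT sizes from the diff but with made-up values for the buffer terms whose formulas live outside this hunk; it is an illustration, not the backend's code.

package main

import "fmt"

// lutBytes mirrors the lut_size switch: extra shared memory a type's shaders need, in bytes.
func lutBytes(typ string) int {
	switch typ {
	case "IQ1_S", "IQ1_M":
		return 2 * 2048
	case "IQ2_XXS":
		return 8 * 256
	case "IQ2_XS":
		return 8 * 512
	case "IQ2_S":
		return 8 * 1024
	case "IQ3_XXS":
		return 4 * 256
	case "IQ3_S":
		return 4 * 512
	case "IQ4_NL", "IQ4_XS":
		return 4 * 16
	default:
		return 0
	}
}

// fitsSharedMemory mirrors the total_size <= maxComputeSharedMemorySize check.
func fitsSharedMemory(typ string, loadBufs, mmidRowIDs, coopmatStage, maxShared int) bool {
	total := loadBufs + mmidRowIDs + coopmatStage + lutBytes(typ)
	return total <= maxShared
}

func main() {
	// Worked example with invented tile buffer sizes and a 48 KiB limit.
	// IQ2_S adds 8*1024 = 8192 bytes of LUT on top of the tile buffers.
	fmt.Println(fitsSharedMemory("IQ2_S", 32768, 3072*4, 0, 49152)) // false: 53248 > 49152
	fmt.Println(fitsSharedMemory("Q4_0", 32768, 3072*4, 0, 49152))  // true: 45056 <= 49152
}

With a 48 KiB limit, the 8 KiB IQ2_S table pushes this example configuration over budget while Q4_0 still fits, so on such a device only the smaller tile sizes would stay enabled for IQ2_S.
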
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, 
, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 } else #endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat_support) { // Create 6 variants, {s,m,l}x{unaligned,aligned} -#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l) \ +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## 
WARPTILE, 1, false, true); \ - if (device->mul_mat ## ID ## _m) \ + if (device->mul_mat ## ID ## _m[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ - if (device->mul_mat ## ID ## _s) \ + if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ - if (device->mul_mat ## ID ## _l) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ - if (device->mul_mat ## ID ## _m) \ + if (device->mul_mat ## ID ## _m[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ - if (device->mul_mat ## ID ## _s) \ + if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ // Create 2 variants, {f16,f32} accumulator -#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->coopmat_acc_f16_support) { \ - CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ } \ if (device->coopmat_acc_f32_support) { \ - CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(TYPE, PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ } \ - CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q2_K, 
pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, 
warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } - // If 
there's not enough shared memory for row_ids and the result tile, don't create these pipelines. - if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { - CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + if (device->coopmat_acc_f16_support) { + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - 
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - } else { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } else { + 
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - } + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 
4, _id);
+        CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         }
 #undef CREATE_MM2
 #undef CREATE_MM
-    } else if (device->fp16) {
+    } else
+#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+    if (device->fp16) {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
-#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-        if (device->mul_mat ## ID ## _l) \
+#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
-        if (device->mul_mat ## ID ## _m) \
+        if (device->mul_mat ## ID ## _m[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
-        if (device->mul_mat ## ID ## _s) \
+        if (device->mul_mat ## ID ## _s[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
-        if (device->mul_mat ## ID ## _l) \
+        if (device->mul_mat ## ID ## _l[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
-        if (device->mul_mat ## ID ## _m) \
+        if (device->mul_mat ## ID ## _m[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
-        if (device->mul_mat ## ID ## _s) \
+        if (device->mul_mat ## ID ## _s[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
         // Create 2 variants, {f16,f32} accumulator
-#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-    CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-    CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+    CREATE_MM(TYPE, PIPELINE_NAME .
f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, 
mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. 
- if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { - CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - } + 
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM2 #undef CREATE_MM } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} -#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l) \ +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _m) \ + if (device->mul_mat ## ID ## _m[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _s) \ + if (device->mul_mat ## ID ## 
_s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _l) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ - if (device->mul_mat ## ID ## _m) \ + if (device->mul_mat ## ID ## _m[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ - if (device->mul_mat ## ID ## _s) \ + if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ - CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, 
vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. 
- if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { - CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - } + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, 
, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM } @@ -1881,7 +2006,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], 
"mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), 
mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); @@ -1895,7 +2028,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), 
mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -1910,7 +2051,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], 
"mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -1924,7 +2073,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, 
{}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -1934,7 +2091,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, 
sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -1943,7 +2108,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, 
"split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); @@ -1953,6 +2126,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -1962,6 +2136,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", 
cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); @@ -1969,6 +2157,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); @@ -1991,36 +2181,50 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); 
ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, 
device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } else { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); @@ -2034,10 +2238,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + for (auto &c : compiles) { c.wait(); } - std::cerr << "Done!" 
<< std::endl; + device->need_compiles = false; } static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); @@ -2069,6 +2275,9 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device = physical_devices[dev_num]; const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); + device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; @@ -2150,6 +2359,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device.getProperties2(&props2); device->properties = props2.properties; + device->vendor_id = device->properties.vendorID; const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); @@ -2161,7 +2371,20 @@ static vk_device ggml_vk_get_device(size_t idx) { device->max_memory_allocation_size = props3.maxMemoryAllocationSize; } - device->vendor_id = device->properties.vendorID; + const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE"); + + if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { + device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE); +#if defined(_WIN32) + } else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) { + // Limit batching of allocations to 1GB by default to avoid fragmentation issues + device->suballocation_block_size = 1024*1024*1024; +#endif + } else { + device->suballocation_block_size = device->max_memory_allocation_size; + } + device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); + device->subgroup_size = subgroup_props.subgroupSize; device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; if (sm_builtins) { @@ -2242,6 +2465,7 @@ static vk_device ggml_vk_get_device(size_t idx) { last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; } +#if defined(VK_KHR_cooperative_matrix) VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; coopmat_features.pNext = nullptr; coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; @@ -2251,6 +2475,7 @@ static vk_device ggml_vk_get_device(size_t idx) { last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; last_struct = (VkBaseOutStructure *)&coopmat_features; } +#endif #if defined(VK_NV_cooperative_matrix2) VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; @@ -2263,6 +2488,14 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + VkPhysicalDeviceMaintenance4Features maint4_features {}; + maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&maint4_features; + last_struct = (VkBaseOutStructure *)&maint4_features; + device_extensions.push_back("VK_KHR_maintenance4"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; @@ -2272,6 +2505,7 @@ static vk_device ggml_vk_get_device(size_t idx) { if (device->subgroup_size_control) { device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; + device_extensions.push_back("VK_EXT_subgroup_size_control"); } 
device->subgroup_size_control = device->subgroup_size_control && @@ -2280,10 +2514,11 @@ static vk_device ggml_vk_get_device(size_t idx) { if (device->subgroup_size_control) { device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; - device_extensions.push_back("VK_EXT_subgroup_size_control"); } +#if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; +#endif if (coopmat2_support) { #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -2376,6 +2611,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_float16_int8"); } +#if defined(VK_KHR_cooperative_matrix) if (device->coopmat_support) { // Query supported shapes std::vector cm_props; @@ -2442,7 +2678,7 @@ static vk_device ggml_vk_get_device(size_t idx) { if (device->coopmat_support) { device_extensions.push_back("VK_KHR_cooperative_matrix"); } - +#endif device->name = GGML_VK_NAME + std::to_string(idx); device_create_info = { @@ -2459,34 +2695,36 @@ static vk_device ggml_vk_get_device(size_t idx) { // Shaders // Disable matmul tile sizes early if performance low or not supported - switch (device->vendor_id) { + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { + switch (device->vendor_id) { #ifndef GGML_VULKAN_RUN_TESTS - case VK_VENDOR_ID_AMD: - case VK_VENDOR_ID_INTEL: - device->mul_mat_l = false; - device->mul_mat_m = true; - device->mul_mat_s = true; - device->mul_mat_id_l = false; - device->mul_mat_id_m = true; - device->mul_mat_id_s = true; - break; - case VK_VENDOR_ID_APPLE: - device->mul_mat_l = false; - device->mul_mat_m = true; - device->mul_mat_s = false; - device->mul_mat_id_l = false; - device->mul_mat_id_m = true; - device->mul_mat_id_s = false; - break; + case VK_VENDOR_ID_AMD: + case VK_VENDOR_ID_INTEL: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; + case VK_VENDOR_ID_APPLE: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = false; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = false; + break; #endif - default: - device->mul_mat_l = true; - device->mul_mat_m = true; - device->mul_mat_s = true; - device->mul_mat_id_l = true; - device->mul_mat_id_m = true; - device->mul_mat_id_s = true; - break; + default: + device->mul_mat_l[i] = true; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = true; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; + } } ggml_vk_load_shaders(device); @@ -2553,9 +2791,11 @@ static void ggml_vk_print_gpu_info(size_t idx) { fp16_storage = true; } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { fp16_compute = true; - } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT")) { coopmat_support = true; +#endif #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { @@ -2593,6 +2833,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { // Pointer to the last chain element 
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; coopmat_features.pNext = nullptr; coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; @@ -2608,12 +2849,14 @@ static void ggml_vk_print_gpu_info(size_t idx) { fp16 = fp16 && vk12_features.shaderFloat16; coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix; +#endif std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n", - idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str()); + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, + props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); @@ -2623,15 +2866,20 @@ static void ggml_vk_print_gpu_info(size_t idx) { static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); -void ggml_vk_instance_init() { +static void ggml_vk_instance_init() { if (vk_instance_initialized) { return; } VK_LOG_DEBUG("ggml_vk_instance_init()"); - vk_instance_initialized = true; + uint32_t api_version = vk::enumerateInstanceVersion(); - vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; + if (api_version < VK_API_VERSION_1_2) { + std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl; + GGML_ABORT("fatal error"); + } + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); @@ -2674,6 +2922,7 @@ void ggml_vk_instance_init() { GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); } vk_instance.instance = vk::createInstance(instance_create_info); + vk_instance_initialized = true; size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); @@ -2698,7 +2947,7 @@ void ggml_vk_instance_init() { // Make sure at least one device exists if (devices.empty()) { std::cerr << "ggml_vulkan: Error: No devices found." 
<< std::endl; - GGML_ABORT("fatal error"); + return; } // Default to using all dedicated GPUs @@ -2832,6 +3081,14 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -2880,6 +3137,14 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -2911,6 +3176,14 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -2941,7 +3214,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } - GGML_ASSERT(src1_type == GGML_TYPE_F32); + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { case GGML_TYPE_Q4_0: @@ -2954,6 +3227,14 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -2980,6 +3261,14 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -3516,6 +3805,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr } } +static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")"); + + ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); +} + static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); @@ -3550,31 +3845,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int return split_k; } -static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); +static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << 
aligned << ", " << ggml_type_name(src0_type) << ")"); if (ctx->device->coopmat2) { - if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) { + if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { return aligned ? mmp->a_l : mmp->l; } - if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) { + if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) { + if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) { + if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); - return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align; +static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align; } static void ggml_vk_matmul( @@ -3601,31 +3896,31 @@ static void ggml_vk_matmul( ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); if (ctx->device->coopmat2) { - if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) { + if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { return aligned ? mmp->a_l : mmp->l; } - if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) { + if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) { return aligned ? 
mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) { + if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) { + if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); - return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align; +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; } static void ggml_vk_matmul_id( @@ -3677,6 +3972,33 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f16_f16; } } + if (src->type == GGML_TYPE_F32) { + switch (to) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_f32_quant[to]; + default: + break; + } + } + + if (to == GGML_TYPE_F32) { + switch (src->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_quant_f32[src->type]; + default: + break; + } + } std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; GGML_ABORT("fatal error"); @@ -3754,8 +4076,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub src1_uma = d_Qy != nullptr; } - const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); - // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src1); @@ -3778,10 +4101,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned); + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? 
GGML_TYPE_F16 : src0->type); const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); @@ -3869,7 +4192,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } if (qy_needs_dequant) { d_Y = ctx->prealloc_y; - GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -4335,8 +4658,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ids_uma = d_ids != nullptr; } - const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src1); const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; @@ -4346,7 +4672,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; if (qx_needs_dequant) { - GGML_ABORT("fatal error"); + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); } // Not implemented @@ -4356,10 +4683,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t y_ne = ne11 * ne10; const uint64_t d_ne = ne21 * ne20; - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1)); + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? 
GGML_TYPE_F16 : src0->type); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); @@ -4442,7 +4769,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (qy_needs_dequant) { d_Y = ctx->prealloc_y; - GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -4754,7 +5081,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } assert(pipelines); - bool aligned = (KV % pipelines[1]->align) == 0; + const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); + const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); + const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + + bool aligned = (KV % pipelines[1]->align) == 0 && + // the "aligned" shader variant will forcibly align strides, for performance + (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); @@ -4790,15 +5124,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (ctx->device->uma) { ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); - ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset); - ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset); - ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset); + ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset); Q_uma = d_Q != nullptr; K_uma = d_K != nullptr; V_uma = d_V != nullptr; D_uma = d_D != nullptr; if (mask) { - ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset); + ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); M_uma = d_M != nullptr; } } @@ -4836,7 +5170,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } - const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 }; + const vk_flash_attn_push_constants pc = { N, KV, + (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, + (uint32_t)neq2, (uint32_t)neq3, + (uint32_t)nek2, (uint32_t)nek3, + (uint32_t)nev2, (uint32_t)nev3, + nem1, + q_stride, (uint32_t)nbq2, (uint32_t)nbq3, + k_stride, (uint32_t)nbk2, (uint32_t)nbk3, + v_stride, (uint32_t)nbv2, (uint32_t)nbv3, + nbm1, + scale, max_bias, logit_softcap, + mask != nullptr, n_head_log2, m0, m1 }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, @@ -4872,6 +5217,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; } return nullptr; + case GGML_OP_SUB: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32; + } + return nullptr; case GGML_OP_MUL: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; @@ -4933,10 +5283,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_repeat_f32; } return nullptr; + case GGML_OP_REPEAT_BACK: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_repeat_back_f32; + } + return nullptr; case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); + case GGML_OP_SILU_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_silu_back_f32; + } + return nullptr; case GGML_OP_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_norm_f32; @@ -4952,6 +5312,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rms_norm_f32; } return nullptr; + case GGML_OP_RMS_NORM_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rms_norm_back_f32; + } + return nullptr; case GGML_OP_UNARY: switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_SILU: @@ -4979,6 +5344,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_tanh_f32; } break; + case GGML_UNARY_OP_SIGMOID: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sigmoid_f32; + } + break; default: break; } @@ -4998,10 +5368,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return src0->ne[0] > 1024 ? 
ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; } return nullptr; + case GGML_OP_SOFT_MAX_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_soft_max_back_f32; + } + return nullptr; case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: { const int mode = ((const int32_t *) dst->op_params)[2]; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { @@ -5010,6 +5388,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_rope_neox_f16; } + } else if (is_mrope && !is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_multi_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_multi_f16; + } + } else if (is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_vision_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_vision_f16; + } } else { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_rope_norm_f32; @@ -5025,11 +5417,22 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_argsort_f32; } return nullptr; + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_sum_rows_f32; } return nullptr; + case GGML_OP_ARGMAX: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_argmax_f32; + } + return nullptr; + case GGML_OP_COUNT_EQUAL: + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) { + return ctx->device->pipeline_count_equal_i32; + } + return nullptr; case GGML_OP_IM2COL: if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_im2col_f32; @@ -5053,6 +5456,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv6_f32; } return nullptr; + case GGML_OP_OPT_STEP_ADAMW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_opt_step_adamw_f32; + } + return nullptr; case GGML_OP_LEAKY_RELU: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_leaky_relu_f32; @@ -5070,6 +5478,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_CPY: case GGML_OP_GET_ROWS: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -5080,6 +5489,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_ROPE: return true; default: return false; @@ -5148,7 +5559,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] 
<< ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); - GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT GGML_ASSERT(dst->buffer != nullptr); const uint64_t ne00 = src0->ne[0]; @@ -5291,8 +5702,11 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co switch (op) { case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: { const uint32_t nr = ggml_nrows(src0); if (nr > 262144) { @@ -5303,6 +5717,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_SUM: + // We use GGML_OP_SUM_ROWS with 1 row. + elements = { 1, 1, 1 }; + break; case GGML_OP_GROUP_NORM: { const uint32_t num_groups = dst->op_params[0]; @@ -5310,6 +5728,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } break; case GGML_OP_DIAG_MASK_INF: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; break; case GGML_OP_GET_ROWS: @@ -5349,6 +5768,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { N * OC * OH * OW, 1, 1}; } break; case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_DIV: case GGML_OP_MUL: case GGML_OP_SCALE: @@ -5358,6 +5778,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_CPY: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: @@ -5403,7 +5824,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); - } else if (op == GGML_OP_ROPE) { + } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { // Empty src2 is possible in rope, but the shader needs a buffer vk_subbuffer subbuf_z; if (use_src2) { @@ -5418,6 +5839,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co // im2col uses only src1 and dst buffers ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (op == GGML_OP_COUNT_EQUAL) { + ggml_vk_sync_buffers(subctx); + // count_equal assumes that destination buffer is initialized with zeroes + ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); } else if (use_src2) { ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, 
vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); @@ -5480,6 +5907,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); @@ -5621,9 +6063,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc } static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { - const size_t seq_length = dst->src[0]->ne[3]; + const size_t seq_length = dst->src[0]->ne[2]; const size_t n_embed = dst->ne[0]; - const size_t n_heads = dst->src[0]->ne[2]; + const size_t n_heads = dst->src[0]->ne[1]; const size_t n_seqs = dst->src[5]->ne[1]; ggml_vk_op_f32_rwkv6( @@ -5638,6 +6080,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ); } +static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * x = dst->src[0]; + const ggml_tensor * g = dst->src[1]; + const ggml_tensor * gm = dst->src[2]; + const ggml_tensor * gv = dst->src[3]; + const ggml_tensor * p = dst->src[4]; + + GGML_ASSERT(x->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(gm->type == GGML_TYPE_F32); + GGML_ASSERT(gv->type == GGML_TYPE_F32); + GGML_ASSERT(p->type == GGML_TYPE_F32); + GGML_ASSERT(dst->buffer != nullptr); + GGML_ASSERT(ggml_is_contiguous(x)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(gm)); + GGML_ASSERT(ggml_is_contiguous(gv)); + GGML_ASSERT(ggml_is_contiguous(p)); + GGML_ASSERT(ggml_are_same_shape(x, g)); + GGML_ASSERT(ggml_are_same_shape(x, gm)); + GGML_ASSERT(ggml_are_same_shape(x, gv)); + GGML_ASSERT(ggml_nelements(p) == 7); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + 
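The opt_step_adamw path added here performs one AdamW update over x, its gradient, and the two moment buffers. Below is a minimal host-side sketch of that update, assuming the standard decoupled AdamW formulation with bias correction; the packing of the 7-element parameter tensor is not spelled out in this hunk, so plain function arguments are used instead.

// One AdamW step per element; a reference sketch under the assumptions above.
#include <cmath>
#include <cstddef>
#include <cstdint>

void adamw_step_ref(float * x, const float * g, float * m, float * v, size_t n,
                    float alpha, float beta1, float beta2, float eps, float wd,
                    int64_t t /* 1-based step counter */) {
    const float m_corr = 1.0f - std::pow(beta1, (float) t);
    const float v_corr = 1.0f - std::pow(beta2, (float) t);
    for (size_t i = 0; i < n; ++i) {
        m[i] = beta1 * m[i] + (1.0f - beta1) * g[i];
        v[i] = beta2 * v[i] + (1.0f - beta2) * g[i] * g[i];
        const float mh = m[i] / m_corr;                        // bias-corrected first moment
        const float vh = std::sqrt(v[i] / v_corr) + eps;       // bias-corrected second moment
        x[i] = x[i] * (1.0f - alpha * wd) - alpha * mh / vh;   // decoupled weight decay
    }
}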
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context; + ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context; + ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context; + ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context; + ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context; + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr; + size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0; + bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, x->data, d_X, x_offset); + ggml_vk_host_get(ctx->device, g->data, d_G, g_offset); + ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset); + ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset); + ggml_vk_host_get(ctx->device, p->data, d_P, p_offset); + + X_uma = d_X != nullptr; + G_uma = d_G != nullptr; + GM_uma = d_GM != nullptr; + GV_uma = d_GV != nullptr; + P_uma = d_P != nullptr; + } + + if (!X_uma) { + d_X = x_buf_ctx->dev_buffer; + x_offset = vk_tensor_offset(x) + x->view_offs; + } + if (!G_uma) { + d_G = g_buf_ctx->dev_buffer; + g_offset = vk_tensor_offset(g) + g->view_offs; + } + if (!GM_uma) { + d_GM = gm_buf_ctx->dev_buffer; + gm_offset = vk_tensor_offset(gm) + gm->view_offs; + } + if (!GV_uma) { + d_GV = gv_buf_ctx->dev_buffer; + gv_offset = vk_tensor_offset(gv) + gv->view_offs; + } + if (!P_uma) { + d_P = p_buf_ctx->dev_buffer; + p_offset = vk_tensor_offset(p) + p->view_offs; + } + + const uint64_t x_size = ggml_nbytes(x); + const uint64_t g_size = ggml_nbytes(g); + const uint64_t gm_size = ggml_nbytes(gm); + const uint64_t gv_size = ggml_nbytes(gv); + const uint64_t p_size = ggml_nbytes(p); + + std::array elements = { (uint32_t)ggml_nelements(x), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_X, x_offset, x_size }, + vk_subbuffer{ d_G, g_offset, g_size }, + vk_subbuffer{ d_GM, gm_offset, gm_size }, + vk_subbuffer{ d_GV, gv_offset, gv_size }, + vk_subbuffer{ d_P, p_offset, p_size }, + }, sizeof(vk_op_push_constants), &pc, elements); +} + +static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t n = ggml_nelements(dst->src[0]); + + ggml_vk_op_f32_opt_step_adamw( + ctx, subctx, dst, + { (uint32_t)n, 0, 0.0f, 0.0f }, + dryrun + ); +} + static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { int * op_params = (int *)dst->op_params; @@ -5771,6 +6318,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co }, dryrun); } +static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], 
(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -5785,6 +6346,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; @@ -5807,6 +6372,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } +static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } @@ -5842,9 +6412,14 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, }, dryrun); } -static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { +static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun); +} + +static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { const int n_dims = ((int32_t *) dst->op_params)[1]; - // const int mode = ((int32_t *) dst->op_params)[2]; + const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; const float freq_base = ((float *) 
dst->op_params)[5]; @@ -5853,16 +6428,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons const float attn_factor = ((float *) dst->op_params)[8]; const float beta_fast = ((float *) dst->op_params)[9]; const float beta_slow = ((float *) dst->op_params)[10]; + int sections[4] {}; + if (mode & GGML_ROPE_TYPE_MROPE) { + memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); + } float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims); + uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type); + uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type); + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, - src2 != nullptr, + src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, + sections[0], sections[1], sections[2], sections[3], backprop }, dryrun); } @@ -5885,10 +6468,22 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const int32_t s0 = dst->op_params[0]; const int32_t s1 = dst->op_params[1]; @@ -6747,15 +7342,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: break; default: return false; } break; case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: case GGML_OP_ADD: case GGML_OP_ACC: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -6769,22 +7367,30 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: + case GGML_OP_SILU_BACK: case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case 
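theta_scale above is the standard RoPE base: the rotation angle for dimension pair i at position pos is pos * freq_base^(-2*i/n_dims), which can be generated incrementally by repeated multiplication with theta_scale. A tiny sketch of that recurrence (freq_scale and the YaRN correction terms are deliberately left out):

// Base RoPE angle recurrence; illustrative values only.
#include <cmath>
#include <cstdio>

int main() {
    const float freq_base = 10000.0f;
    const int   n_dims    = 128;
    const int   pos       = 42;                                     // example token position
    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);  // same formula as above

    float theta = (float) pos;                                      // pair i = 0
    for (int i = 0; i < 4; ++i) {                                   // first few dimension pairs
        std::printf("pair %d: theta = %g\n", i, theta);
        theta *= theta_scale;
    }
}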
GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_OPT_STEP_ADAMW: break; default: std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; @@ -6805,9 +7411,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod } else { switch (node->op) { case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_ACC: case GGML_OP_GET_ROWS: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -6821,15 +7429,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: + case GGML_OP_SILU_BACK: case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_UNARY: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: @@ -6850,6 +7465,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_REPEAT: ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_REPEAT_BACK: + ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_ACC: ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); @@ -6862,6 +7481,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ADD: ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_SUB: + ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_MUL: ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); @@ -6908,6 +7531,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_DUP: ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_SILU_BACK: + ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_NORM: ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); @@ -6920,6 +7547,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RMS_NORM: ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_RMS_NORM_BACK: + ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { @@ -6928,6 +7559,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); break; default: @@ -6941,18 +7573,38 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_SOFT_MAX: ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_SOFT_MAX_BACK: + ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_ROPE: - ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun); + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun); + + break; + case GGML_OP_ROPE_BACK: + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, 
node, true, dryrun); break; case GGML_OP_ARGSORT: ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_SUM: + ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_ARGMAX: + ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COUNT_EQUAL: + ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); @@ -6987,6 +7639,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RWKV_WKV6: ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + break; + + case GGML_OP_OPT_STEP_ADAMW: + ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); + break; default: return false; @@ -7038,6 +7695,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_ADD: case GGML_OP_ACC: case GGML_OP_GET_ROWS: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -7051,25 +7709,34 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: + case GGML_OP_SILU_BACK: case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_NONE: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_OPT_STEP_ADAMW: buf = tensor->buffer; break; @@ -7080,6 +7747,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: buf = tensor->buffer; break; default: @@ -7261,6 +7929,15 @@ static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggm } } +static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + uint32_t val32 = (uint32_t)value * 0x01010101; + ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size); +} + static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; @@ -7305,7 +7982,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, /* .get_base = */ ggml_backend_vk_buffer_get_base, /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, - /* 
.memset_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, @@ -7343,7 +8020,7 @@ static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; - return ctx->device->max_memory_allocation_size; + return ctx->device->suballocation_block_size; } static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { @@ -7569,6 +8246,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg for (int i = 0; i < cgraph->n_nodes; i++) { ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } ggml_vk_preallocate_buffers(ctx); ggml_pipeline_allocate_descriptor_sets(ctx->device); @@ -7769,6 +8449,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: return ggml_is_contiguous(op->src[0]); default: return false; @@ -7777,13 +8458,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { + ggml_type src0_type = op->src[0]->type; ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; const vk_device& device = ggml_vk_get_device(ctx->device); - if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) { + if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { // If there's not enough shared memory for row_ids and the result tile, fallback to CPU return false; } - switch (op->src[0]->type) { + switch (src0_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -7796,6 +8478,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -7864,6 +8554,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_Q4_K: //case GGML_TYPE_Q5_K: //case GGML_TYPE_Q6_K: + //case GGML_TYPE_IQ1_S: + //case GGML_TYPE_IQ1_M: + //case GGML_TYPE_IQ2_XXS: + //case GGML_TYPE_IQ2_XS: + //case GGML_TYPE_IQ2_S: + //case GGML_TYPE_IQ3_XXS: + //case GGML_TYPE_IQ3_S: + //case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -7881,6 +8579,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: return true; default: @@ -7893,12 +8599,36 @@ static bool 
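The memset_tensor implementation wired up just above turns the byte value into a 32-bit fill word by multiplying it with 0x01010101, which replicates the byte into every lane. A quick check of that arithmetic:

// Byte-splat used by the memset path.
#include <cstdint>
#include <cstdio>

int main() {
    const uint8_t  value = 0xAB;
    const uint32_t val32 = (uint32_t) value * 0x01010101u;
    std::printf("0x%02X -> 0x%08X\n", value, val32);   // prints 0xAB -> 0xABABABAB
}

This lets ggml_vk_buffer_memset fill the buffer in 32-bit words while still honoring an arbitrary byte value.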
ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm { ggml_type src0_type = op->src[0]->type; ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type; - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { - return true; + + if (src0_type == GGML_TYPE_F32) { + switch (src1_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } } - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { - return true; + if (src1_type == GGML_TYPE_F32) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { return true; } @@ -7906,30 +8636,28 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } break; case GGML_OP_REPEAT: return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + case GGML_OP_REPEAT_BACK: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return ggml_is_contiguous(op->src[0]); - } + case GGML_OP_ROPE_BACK: case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: + return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: case GGML_OP_ACC: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: + case GGML_OP_SILU_BACK: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: @@ -7939,13 +8667,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_PAD: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: + case GGML_OP_OPT_STEP_ADAMW: return true; default: return false; @@ -8045,8 +8778,13 @@ ggml_backend_reg_t ggml_backend_vk_reg() { /* .iface = */ ggml_backend_vk_reg_i, /* .context = */ nullptr, }; - - return ® + try { + ggml_vk_instance_init(); + return ® + } catch (const vk::SystemError& e) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what()); + return nullptr; + } } // Extension availability @@ -8213,8 +8951,6 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { ggml_tensor * src0 = tensor->src[0]; ggml_tensor * src1 = tensor->src[1]; - ggml_tensor * src2 = tensor->src[2]; - ggml_tensor * src3 = tensor->src[3]; struct ggml_init_params iparams = { /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, @@ -8224,239 +8960,121 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { struct ggml_context * ggml_ctx = ggml_init(iparams); - struct ggml_tensor * src0_clone = nullptr; - struct ggml_tensor * src1_clone = nullptr; - struct ggml_tensor * src2_clone = nullptr; - struct ggml_tensor * src3_clone = nullptr; + std::array src_clone = 
{nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + std::array src_size = {0, 0, 0, 0, 0, 0}; + std::array src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"}; + struct ggml_tensor * tensor_clone = nullptr; - size_t src0_size; - size_t src1_size; - size_t src2_size; - size_t src3_size; + for (int i = 0; i < 6; i++) { + ggml_tensor * srci = tensor->src[i]; + if (srci == nullptr) { + continue; + } + ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci); + size_t srci_size = ggml_nbytes(srci); - void * src0_buffer = nullptr; - void * src1_buffer = nullptr; - void * src2_buffer = nullptr; - void * src3_buffer = nullptr; + src_clone[i] = srci_clone; + src_size[i] = ggml_nbytes(srci); + src_buffer[i] = malloc(srci_size); - if (src0 != nullptr) { - src0_clone = ggml_dup_tensor(ggml_ctx, src0); - - src0_size = ggml_nbytes(src0); - - src0_buffer = malloc(src0_size); - src0_clone->data = src0_buffer; - if (ggml_backend_buffer_is_host(src0->buffer)) { - memcpy(src0_clone->data, src0->data, src0_size); - memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src0->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + srci_clone->data = src_buffer[i]; + if (ggml_backend_buffer_is_host(srci->buffer)) { + memcpy(srci_clone->data, srci->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(srci->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context; vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src0) + src0->view_offs; - if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { - for (int i3 = 0; i3 < src0->ne[3]; i3++) { - for (int i2 = 0; i2 < src0->ne[2]; i2++) { - const int idx = i3*src0->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]); + uint64_t offset = vk_tensor_offset(srci) + srci->view_offs; + if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) { + for (int i3 = 0; i3 < srci->ne[3]; i3++) { + for (int i2 = 0; i2 < srci->ne[2]; i2++) { + const int idx = i3*srci->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]); } } - src0_clone->nb[0] = src0->nb[0]; - src0_clone->nb[1] = src0->nb[1]; + srci_clone->nb[0] = srci->nb[0]; + srci_clone->nb[1] = srci->nb[1]; for (int i = 2; i < GGML_MAX_DIMS; i++) { - src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1]; + srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1]; } } else { - if (offset + src0_size >= buffer_gpu->size) { - src0_size = buffer_gpu->size - offset; + if (offset + srci_size >= buffer_gpu->size) { + srci_size = buffer_gpu->size - offset; } - ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size); - memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); + ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); } } else { GGML_ABORT("fatal error"); } if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src0, "src0"); - } - } - if (src1 != nullptr) { - src1_clone = 
ggml_dup_tensor(ggml_ctx, src1); - - src1_size = ggml_nbytes(src1); - - src1_buffer = malloc(src1_size); - src1_clone->data = src1_buffer; - if (ggml_backend_buffer_is_host(src1->buffer)) { - memcpy(src1_clone->data, src1->data, src1_size); - memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src1->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; - vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src1) + src1->view_offs; - if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { - for (int i3 = 0; i3 < src1->ne[3]; i3++) { - for (int i2 = 0; i2 < src1->ne[2]; i2++) { - const int idx = i3*src1->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]); - } - } - - src1_clone->nb[0] = src1->nb[0]; - src1_clone->nb[1] = src1->nb[1]; - for (int i = 2; i < GGML_MAX_DIMS; i++) { - src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1]; - } - } else { - if (offset + src1_size >= buffer_gpu->size) { - src1_size = buffer_gpu->size - offset; - } - ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size); - memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); - } - } else { - GGML_ABORT("fatal error"); - } - - if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src1, "src1"); - } - } - if (src2 != nullptr) { - src2_clone = ggml_dup_tensor(ggml_ctx, src2); - - src2_size = ggml_nbytes(src2); - - src2_buffer = malloc(src2_size); - src2_clone->data = src2_buffer; - if (ggml_backend_buffer_is_host(src2->buffer)) { - memcpy(src2_clone->data, src2->data, src2_size); - memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src2->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context; - vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src2) + src2->view_offs; - if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) { - for (int i3 = 0; i3 < src2->ne[3]; i3++) { - for (int i2 = 0; i2 < src2->ne[2]; i2++) { - const int idx = i3*src2->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]); - } - } - - src2_clone->nb[0] = src2->nb[0]; - src2_clone->nb[1] = src2->nb[1]; - for (int i = 2; i < GGML_MAX_DIMS; i++) { - src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1]; - } - } else { - if (offset + src2_size >= buffer_gpu->size) { - src2_size = buffer_gpu->size - offset; - } - ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size); - memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); - } - } else { - GGML_ABORT("fatal error"); - } - - if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src2, "src2"); - } - } - if (src3 != nullptr) { - src3_clone = ggml_dup_tensor(ggml_ctx, src3); - - src3_size = ggml_nbytes(src3); - - src3_buffer = malloc(src3_size); - src3_clone->data = src3_buffer; - if (ggml_backend_buffer_is_host(src3->buffer)) { - memcpy(src3_clone->data, src3->data, src3_size); - memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src3->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = 
(ggml_backend_vk_buffer_context *)src3->buffer->context; - vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src3) + src3->view_offs; - if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) { - for (int i3 = 0; i3 < src3->ne[3]; i3++) { - for (int i2 = 0; i2 < src3->ne[2]; i2++) { - const int idx = i3*src3->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]); - } - } - - src3_clone->nb[0] = src3->nb[0]; - src3_clone->nb[1] = src3->nb[1]; - for (int i = 2; i < GGML_MAX_DIMS; i++) { - src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1]; - } - } else { - if (offset + src3_size >= buffer_gpu->size) { - src3_size = buffer_gpu->size - offset; - } - ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size); - memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); - } - } else { - GGML_ABORT("fatal error"); - } - - if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src3, "src3"); + ggml_vk_print_tensor(srci, srci_name[i]); } } if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { const float *params = (const float *)tensor->op_params; - tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]); + tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); } else if (tensor->op == GGML_OP_MUL_MAT) { - tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL_MAT_ID) { - tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone); + tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } else if (tensor->op == GGML_OP_SUB) { + tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL) { - tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_DIV) { - tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_CONCAT) { - tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params); + tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_SCALE) { - tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]); + tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]); } else if (tensor->op == GGML_OP_SQR) { - tensor_clone = ggml_sqr(ggml_ctx, src0_clone); + tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SIN) { - tensor_clone = ggml_sin(ggml_ctx, src0_clone); + tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COS) { - tensor_clone = ggml_cos(ggml_ctx, src0_clone); + tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { - tensor_clone = 
ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_PAD) { - tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]); + tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); } else if (tensor->op == GGML_OP_REPEAT) { - tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor); + tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_REPEAT_BACK) { + tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); } else if (tensor->op == GGML_OP_ADD) { - tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ACC) { - tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { - tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { - tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); + tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { - tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); + } else if (tensor->op == GGML_OP_SILU_BACK) { + tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_SOFT_MAX) { if (src1 != nullptr) { - tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else { - tensor_clone = ggml_soft_max(ggml_ctx, src0_clone); + tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); } + } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { + tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { - tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params); - } else if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; 
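For the new SOFT_MAX_BACK check path, the reference result is the usual softmax backward: with y = softmax(x) and upstream gradient dy, dx = scale * y * (dy - dot(dy, y)) per row. A CPU sketch of that formula; the operand roles (src0 = dy, src1 = y) are the common convention for this op and are assumed here rather than stated by the hunk.

// Per-row softmax backward reference.
#include <cstddef>

void soft_max_back_row_ref(const float * dy, const float * y, float * dx,
                           size_t n, float scale) {
    float dot = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        dot += dy[i] * y[i];
    }
    for (size_t i = 0; i < n; ++i) {
        dx[i] = scale * y[i] * (dy[i] - dot);
    }
}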
//const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; @@ -8467,23 +9085,39 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const float attn_factor = ((float *) tensor->op_params)[8]; const float beta_fast = ((float *) tensor->op_params)[9]; const float beta_slow = ((float *) tensor->op_params)[10]; - tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + if (mode & GGML_ROPE_TYPE_MROPE) { + int32_t *sections = ((int32_t *) tensor->op_params) + 11; + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + } else { + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + } } else if (tensor->op == GGML_OP_UNARY) { switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_SILU: - tensor_clone = ggml_silu(ggml_ctx, src0_clone); + tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_GELU: - tensor_clone = ggml_gelu(ggml_ctx, src0_clone); + tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_GELU_QUICK: - tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone); + tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_RELU: - tensor_clone = ggml_relu(ggml_ctx, src0_clone); + tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_TANH: - tensor_clone = ggml_tanh(ggml_ctx, src0_clone); + tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_SIGMOID: + tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]); break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; @@ -8491,28 +9125,34 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { if (src1 == nullptr) { - tensor_clone = ggml_dup(ggml_ctx, src0_clone); + tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); tensor_clone->type = tensor->type; } else { - tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); } } else if (tensor->op == GGML_OP_CONT) { - tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_RESHAPE) { - tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_VIEW) { - tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], 
tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); + tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); } else if (tensor->op == GGML_OP_PERMUTE) { int32_t * params = (int32_t *)tensor->op_params; - tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]); + tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]); } else if (tensor->op == GGML_OP_TRANSPOSE) { - tensor_clone = ggml_transpose(ggml_ctx, src0_clone); + tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_GET_ROWS) { - tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ARGSORT) { - tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params); + tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_SUM) { + tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { - tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone); + tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_ARGMAX) { + tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COUNT_EQUAL) { + tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_IM2COL) { const int32_t s0 = tensor->op_params[0]; const int32_t s1 = tensor->op_params[1]; @@ -8522,11 +9162,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t d1 = tensor->op_params[5]; const bool is_2D = tensor->op_params[6] == 1; - tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; - tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); + tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); } else if (tensor->op == GGML_OP_POOL_2D) { enum ggml_op_pool op = static_cast(tensor->op_params[0]); const int32_t k0 = tensor->op_params[1]; @@ -8536,13 +9176,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t p0 = tensor->op_params[5]; const int32_t p1 = tensor->op_params[6]; - tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1); + tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; - tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); + tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); } else if (tensor->op == GGML_OP_RWKV_WKV6) { - tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], - tensor->src[4], tensor->src[5]); + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + } else if 
(tensor->op == GGML_OP_OPT_STEP_ADAMW) { + src_clone[0]->flags = src0->flags; + tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4]); } else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; @@ -8564,11 +9208,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { memcpy(comp_result, tensor_clone->data, comp_size); memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); - if (src0 != nullptr) { - free(src0_buffer); - } - if (src1 != nullptr) { - free(src1_buffer); + for (int i = 0; i < 6; i++) { + if (src_buffer[i] != nullptr) { + free(src_buffer[i]); + } } ggml_free(ggml_ctx); @@ -8589,6 +9232,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { ggml_tensor * src0 = tensor->src[0]; ggml_tensor * src1 = tensor->src[1]; ggml_tensor * src2 = tensor->src[2]; + ggml_tensor * src3 = tensor->src[3]; void * tensor_data = tensor->data; @@ -8631,6 +9275,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { } else if (tensor->type == GGML_TYPE_I32) { correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_I64) { + correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); } else { std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; } @@ -8651,6 +9298,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); @@ -8695,6 +9345,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) 
<< " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); @@ -8717,6 +9370,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index bd0c74cb1..074031087 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -1,9 +1,11 @@ find_package (Threads REQUIRED) -find_package(Vulkan COMPONENTS glslc REQUIRED) +find_program(GLSLC_EXECUTABLE glslc) +if(NOT GLSLC_EXECUTABLE) + message(FATAL_ERROR "glslc not found.") +endif() set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) target_compile_features(${TARGET} PRIVATE cxx_std_17) target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) -target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp new file mode 100644 index 000000000..eaf4da341 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmpmax[BLOCK_SIZE]; +shared uint tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + 
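The argmax.comp shader added here follows a two-phase reduction: each invocation scans a strided slice of the row keeping its best (value, index) pair, then a shared-memory tree reduction merges the per-invocation results. A CPU sketch of the same pattern, with BLOCK_SIZE and the row layout simplified:

// CPU model of the argmax reduction; not the shader code.
#include <cstddef>
#include <vector>

size_t argmax_row_ref(const float * row, size_t n, size_t block_size = 32) {
    std::vector<float>  best_val(block_size);
    std::vector<size_t> best_idx(block_size);
    // phase 1: strided scan, one slot per "invocation"
    for (size_t col = 0; col < block_size && col < n; ++col) {
        best_val[col] = row[col];
        best_idx[col] = col;
        for (size_t i = col + block_size; i < n; i += block_size) {
            if (row[i] > best_val[col]) {
                best_val[col] = row[i];
                best_idx[col] = i;
            }
        }
    }
    // phase 2: pairwise (tree) reduction over the per-slot results
    const size_t active = n < block_size ? n : block_size;
    for (size_t s = block_size / 2; s > 0; s /= 2) {
        for (size_t col = 0; col < s && col + s < active; ++col) {
            if (best_val[col] < best_val[col + s]) {
                best_val[col] = best_val[col + s];
                best_idx[col] = best_idx[col + s];
            }
        }
    }
    return best_idx[0];
}

The p.KX guard in the shader's reduction loop plays the role of the active bound here, so slots belonging to invocations that returned early are never read.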
gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + if (col >= p.KX) { + return; + } + A_TYPE amax = data_a[row*p.KX + col]; + tmp[col] = col; + + for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) { + A_TYPE val = data_a[row*p.KX + i]; + if (val > amax) { + amax = val; + tmp[col] = i; + } + } + tmpmax[col] = amax; + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s && col + s < p.KX) { + if (tmpmax[col] < tmpmax[col + s]) { + tmpmax[col] = tmpmax[col + s]; + tmp[col] = tmp[col + s]; + } + } + barrier(); + } + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp new file mode 100644 index 000000000..dbc7daa33 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" +#include "dequant_funcs.comp" + +#if defined(DATA_A_IQ4_NL) +// 16 invocations needed for init_iq4nl_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = get_doffset() + dst_idx(idx); + uint src_idx = src0_idx_quant(idx, QUANT_K); + + const uint a_offset = 0; + const uint ib = src_idx; + const vec2 dm = get_dm(ib, a_offset); + + [[unroll]] for (int j = 0; j < QUANT_K; j += 4) { + vec4 v = dequantize4(ib, j / QUANT_R, a_offset); + v = v * dm.x + vec4(dm.y); + +#if QUANT_R == 2 + data_d[dst_idx + j/2 + 0] = v[0]; + data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1]; + data_d[dst_idx + j/2 + 1] = v[2]; + data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3]; +#else + data_d[dst_idx + j + 0] = v[0]; + data_d[dst_idx + j + 1] = v[1]; + data_d[dst_idx + j + 2] = v[2]; + data_d[dst_idx + j + 3] = v[3]; +#endif + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp new file mode 100644 index 000000000..c813f1404 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -0,0 +1,237 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +#if defined(DATA_A_IQ4_NL) +// 16 invocations needed for init_iq4nl_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +layout (binding = 0) readonly buffer S {float data_s[];}; +layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; + +#if defined(DATA_A_Q4_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = (d != 0.0) ? 
1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id; + + const uint xi0 = min(15, int(x0 + 8.5)); + const uint xi1 = min(15, int(x1 + 8.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q4_1) +void quantize(uint dst_idx, uint src_idx) +{ + float vmin = 1.0/0.0; + float vmax = -vmin; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) { + const float v = data_s[src_idx + j]; + + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(vmin); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - vmin)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id; + + const uint xi0 = min(15, int(x0 + 0.5)); + const uint xi1 = min(15, int(x1 + 0.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q5_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id; + + const uint xi0 = min(31, int(x0 + 16.5)); + const uint xi1 = min(31, int(x1 + 16.5)); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2); + } + data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF); + data_q[dst_idx].qh[1] = uint16_t(qh >> 16); +} +#endif + +#if defined(DATA_A_Q5_1) +void quantize(uint dst_idx, uint src_idx) +{ + float min = data_s[src_idx + 0]; + float max = min; + + [[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) { + const float v = data_s[src_idx + j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = (d != 0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(min); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - min)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id; + + const uint xi0 = uint(x0 + 0.5); + const uint xi1 = uint(x1 + 0.5); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2); + } + data_q[dst_idx].qh = qh; +} +#endif + +#if defined(DATA_A_Q8_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; // absolute max + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) { + const float v = data_s[src_idx + j]; + amax = max(amax, abs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = (d != 0.0) ? 
1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) { + const float x0 = data_s[src_idx + j]*id; + + data_q[dst_idx].qs[j] = int8_t(round(x0)); + } +} +#endif + +#if defined(DATA_A_IQ4_NL) +uint best_index(float x) { + if (x <= kvalues_iq4nl[0]) return 0; + if (x >= kvalues_iq4nl[15]) return 15; + int ml = 0, mu = 15; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav; + } + return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu; +} + +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + float sumqx = 0, sumq2 = 0; + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id; + const uint xi0 = best_index(x0); + const uint xi1 = best_index(x1); + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j]; + const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d); + +} +#endif + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = dst_idx_quant(idx, QUANT_K); + uint src_idx = get_aoffset() + src0_idx(idx); + + quantize(dst_idx, src_idx); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp new file mode 100644 index 000000000..d9345497c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp @@ -0,0 +1,31 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.comp" +#include "generic_head.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +const uint CHUNK_SIZE = 512; + +void main() { + const uint base = gl_WorkGroupID.x * CHUNK_SIZE; + const uint col = gl_LocalInvocationID.x; + + uint count = 0; + [[unroll]] + for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) { + const uint idx = base + i + col; + if (idx >= p.KX) { + break; + } + count += uint(data_a[idx] == data_b[idx]); + } + + atomicAdd(data_d[0], D_TYPE(count)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 91bb8f8db..10318e876 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -88,6 +88,335 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } 
#endif +#if defined(DATA_A_IQ1_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1; + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + // Signed bitfield extract. + const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + +#if defined(DATA_A_IQ1_M) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. 
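+    // bitfieldExtract on a signed operand sign-extends the extracted bits, so the
+    // int16_t cast of the grid entry makes each 2-bit field decode to a small signed
+    // value before the per-block delta and scale are applied.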
+ const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + +#if defined(DATA_A_IQ2_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? 
-1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid[iqs % 4] * (sign0 ? -1.0 : 1.0), + grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? 
-1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint qs = data_a[a_offset + ib].qs[iqs / 4]; + const uint qh = data_a[a_offset + ib].qh[iqs / 32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[iqs / 64]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4)); + return db * vec2( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[ib32 / 2]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4)); + return db * vec4( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0), + int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0), + int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ4_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint iq = 16 * ib32 + (iqs % 16); + + const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; + const uint qshift = (iqs & 16) >> 2; + u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float dl = float(int(sl | (sh << 4)) - 32); + return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint iq = 16 * ib32 + (iqs % 16); + + const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; + const uint qshift = (iqs & 16) >> 2; + u8vec4 qs = u8vec4( + data_a[a_offset + ib].qs[iq + 0], + data_a[a_offset + ib].qs[iq + 1], + data_a[a_offset + ib].qs[iq + 2], + data_a[a_offset + ib].qs[iq + 3] + ); + qs = (qs >> qshift) & uint8_t(0xF); + + const float dl = float(int(sl | (sh << 4)) - 32); + return dl * vec4( + kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], + kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); +} +#endif + #if defined(DATA_A_IQ4_NL) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); @@ -105,7 +434,16 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_IQ1_M) +vec2 get_dm(uint ib, uint a_offset) { + const uint16_t[4] scales = data_a[a_offset + ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + return vec2(d, 0); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || 
defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), 0); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 94b78598e..4770469ed 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -101,19 +101,25 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_ block_q2_K block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 { + block_q2_K_packed16 block; +}; + float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { + decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); const f16vec2 d = bl.block.d; const uint idx = coordInBlock[1]; - const uint iqs = idx; - const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31 - const uint scalesi = iqs / 16; // 0..15 - const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6 + const uint scalesi = (idx & 0xF0) >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qs = (qs >> qsshift) & 0x0303; + qs = unpack8(qs)[idx & 1]; - uint32_t qs = bl.block.qs[qsi]; const uint scales = bl.block.scales[scalesi]; - float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4); + float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); return ret; } @@ -157,39 +163,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 block_q4_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { + block_q4_K_packed128 block; +}; + float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); const uint idx = coordInBlock[1]; const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 - const f16vec2 loadd = bl.block.d; + uvec4 v = bl128.block.q4k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); uint32_t sc; uint32_t mbyte; - uint32_t scidx0 = (is < 4) ? is : (is + 4); - uint32_t scidx1 = (is < 4) ? is : (is - 4); - uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t scidxshift1 = (is < 4) ? 0 : 2; - uint32_t mbidx0 = is + 4; - uint32_t mbidx1 = (is < 4) ? is + 4 : is; - uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; - uint32_t mbidxshift0 = (is < 4) ? 0 : 4; - uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t mbidxshift1 = (is < 4) ? 
0 : 2; + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; - sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); - mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; const float16_t d = loadd.x * float16_t(sc); const float16_t m = loadd.y * float16_t(mbyte); uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); - qs = (qs >> (b * 4)) & 0x0F0F; - qs = unpack8(qs)[idx & 1]; + qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; float16_t ret = d * float16_t(qs) - m; @@ -204,47 +218,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5 block_q5_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 { + block_q5_K_packed128 block; +}; + float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); const uint idx = coordInBlock[1]; const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 - const uint32_t hm = 0x0101 << is; + uvec4 v = bl128.block.q5k[0]; - const f16vec2 loadd = bl.block.d; + const f16vec2 loadd = unpackFloat2x16(v.x); uint32_t sc; uint32_t mbyte; - uint32_t scidx0 = (is < 4) ? is : (is + 4); - uint32_t scidx1 = (is < 4) ? is : (is - 4); - uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t scidxshift1 = (is < 4) ? 0 : 2; - uint32_t mbidx0 = is + 4; - uint32_t mbidx1 = (is < 4) ? is + 4 : is; - uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; - uint32_t mbidxshift0 = (is < 4) ? 0 : 4; - uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t mbidxshift1 = (is < 4) ? 0 : 2; + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; - sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); - mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; const float16_t d = loadd.x * float16_t(sc); const float16_t m = loadd.y * float16_t(mbyte); uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); - qh = qh & hm; - qh = unpack8(qh)[idx & 1]; + qh = ((qh >> is) & 0x101) << 4; uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); qs = (qs >> (b * 4)) & 0x0F0F; - qs = unpack8(qs)[idx & 1]; + qs = unpack8(qs | qh)[idx & 1]; - float16_t ret = d * (float16_t(qs) + (qh != 0 ? 
float16_t(16) : float16_t(0))) - m; + float16_t ret = d * (float16_t(qs)) - m; return ret; } @@ -281,6 +301,228 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2 return ret; } +#if defined(DATA_A_IQ1_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S { + block_iq1_s block; +}; + +float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx / 32; + const uint ib8 = idx / 8; + + const uint qh = bl.block.qh[ib32]; + const uint qs = bl.block.qs[ib8]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]; + + float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta)); + return ret; +} +#endif + +#if defined(DATA_A_IQ1_M) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M { + block_iq1_m block; +}; + +float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12; + const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12)); + const uint idx = coordInBlock[1]; + + const uint ib8 = idx / 8; + const uint ib16 = idx / 16; + const int i8 = int(idx % 8); + const uint sc = bl.block.scales[ib8 / 8]; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | ((qh & 7) << 8)]; + + float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta)); + return ret; +} +#endif + +#if defined(DATA_A_IQ2_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS { + block_iq2_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 { + block_iq2_xxs_packed16 block; +}; + +float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0x18) >> 3; // 0..3 + const uint iqs = 8 * ib32 + ib8; + + const uint8_t qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + + uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); + + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? 
-1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ2_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS { + block_iq2_xs block; +}; + +float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint is = (idx & 0xE0) >> 5; // 0..8 + const uint sshift = (idx & 0x10) >> 2; // 0,4 + const uint iqs = (idx & 0xF8) >> 3; // 0..63 + + const uint16_t qs = bl.block.qs[iqs]; + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); + + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ2_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S { + block_iq2_s block; +}; + +float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0xF8) >> 3; // 0..31 + const uint qhshift = 2 * (ib8 % 4); + + const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib32]; + const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6); + + const float d = float(bl.block.d); + const float db = d * 0.25 * (0.5 + scale); + const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign)); + uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 v = db * vec2(sign01) * vec2(unpack8(g2)); + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS { + block_iq3_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 { + block_iq3_xxs_packed16 block; +}; + +float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); + uint idx = coordInBlock[1]; + + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint signs = pack32(u16vec2( + bl16.block.qs[is/2+0], + bl16.block.qs[is/2+1] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6); + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); + const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ3_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S { + block_iq3_s block; +}; + +float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint iqh = (idx & 0xE0) 
>> 5; + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint qh = bl.block.qh[iqh]; + const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6)); + const uint scale = bl.block.scales[iqs / 16]; + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ4_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS { + block_iq4_xs block; +}; + +float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + + const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 16) >> 2; + const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF; + + float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]); + return ret; +} +#endif + #if defined(DATA_A_IQ4_NL) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { block_iq4_nl block; @@ -320,6 +562,22 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #define dequantFuncA dequantFuncQ5_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K +#elif defined(DATA_A_IQ1_S) +#define dequantFuncA dequantFuncIQ1_S +#elif defined(DATA_A_IQ1_M) +#define dequantFuncA dequantFuncIQ1_M +#elif defined(DATA_A_IQ2_XXS) +#define dequantFuncA dequantFuncIQ2_XXS +#elif defined(DATA_A_IQ2_XS) +#define dequantFuncA dequantFuncIQ2_XS +#elif defined(DATA_A_IQ2_S) +#define dequantFuncA dequantFuncIQ2_S +#elif defined(DATA_A_IQ3_XXS) +#define dequantFuncA dequantFuncIQ3_XXS +#elif defined(DATA_A_IQ3_S) +#define dequantFuncA dequantFuncIQ3_S +#elif defined(DATA_A_IQ4_XS) +#define dequantFuncA dequantFuncIQ4_XS #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL #endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp new file mode 100644 index 000000000..39184ef58 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq1_m data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint ib64 = ib32 / 2; + const uint b_idx = 256 * ib + 32 * ib32; + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = data_a[ib].scales[ib64]; + [[unroll]] for (int l = 0; 
l < 4; ++l) { + const uint ib16 = 2 * ib32 + l / 2; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1)); + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp new file mode 100644 index 000000000..fd1e4e30d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq1_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + uint qh = data_a[ib].qh[ib32]; + const float d = float(data_a[ib].d); + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint hi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp new file mode 100644 index 000000000..48f6b65bc --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + uint qh = data_a[ib].qh[ib32]; + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l]; + qs |= (qh << (8 - 2 * l)) & 0x300; + const uvec2 grid = iq2s_grid[qs & 511]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? 
-1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp new file mode 100644 index 000000000..a08331c40 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint16_t qs = data_a[ib].qs[4 * ib32 + l]; + const uint sign7 = qs >> 9; + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uvec2 grid = iq2xs_grid[qs & 511]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? 
-1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp new file mode 100644 index 000000000..e370690bc --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -0,0 +1,48 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[8*is + 4], + data_a[ib].qs[8*is + 5], + data_a[ib].qs[8*is + 6], + data_a[ib].qs[8*is + 7] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp new file mode 100644 index 000000000..c3f4bca5d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale nibble. + // Each block contains 4 scale bytes (8 scales) for 256 output values. + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf)); + + // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes. 
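+    // Each loop iteration below expands one qs byte into four outputs: qh supplies
+    // the 9th grid-index bit via ((qh << (8 - l)) & 256), the grid entry packs four
+    // magnitudes, and the selected 4-bit sign nibble flips the sign of each one.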
+ uint qh = data_a[ib].qh[is]; + [[unroll]] for (uint l = 0; l < 8; ++l) { + uint qs = data_a[ib].qs[8 * is + l]; + uint gidx = qs | ((qh << (8 - l)) & 256); + uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1)); + u8vec4 grid = unpack8(iq3s_grid[gidx]); + data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp new file mode 100644 index 000000000..a92b82961 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // 8 threads handle 1 superblock + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + const uint s_idx = QUANT_K / 4 + 4 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[s_idx + 0], + data_a[ib].qs[s_idx + 1], + data_a[ib].qs[s_idx + 2], + data_a[ib].qs[s_idx + 3] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.5; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + // Restore parity bit. + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]); + const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? 
-1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp index 8de14fc03..46d9ad15e 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; - init_iq4nl_shmem(); + init_iq_shmem(gl_WorkGroupSize); const uint tid = gl_LocalInvocationID.x % 64; const uint il = tid/32; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp new file mode 100644 index 000000000..f930852a4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (1 scale and 32 quantized values) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + + const float d = float(data_a[ib].d); + // Scales are 6 bits + const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF) + | (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4); + const float dl = d * (int(scale) - 32); + + const uint b_idx = 256 * ib + 32 * ib32; + const uint q_idx = 16 * ib32; + [[unroll]] for (uint l = 0; l < 16; ++l) { + data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp index 4e68742b5..26d8bc22a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -12,7 +12,7 @@ layout (push_constant) uniform parameter #include "types.comp" -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index c5be8131b..df30355f6 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -42,10 +42,13 @@ layout (push_constant) uniform parameter { uint32_t nev3; uint32_t nem1; + uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t nb21; uint32_t nb22; uint32_t nb23; uint32_t nb31; @@ -101,8 +104,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele #endif void main() { -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#ifdef NEEDS_INIT_IQ_SHMEM + 
init_iq_shmem(gl_WorkGroupSize); #endif const uint32_t N = p.N; @@ -146,7 +149,24 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); - coopmat Q; + // nb?1 are already divided by the type size and are in units of elements + uint32_t q_stride = p.nb01; + uint32_t k_stride = p.nb11; + uint32_t v_stride = p.nb21; + // hint to the compiler that strides are aligned for the aligned variant of the shader + if (Clamp != gl_CooperativeMatrixClampModeConstantNV) + { + q_stride &= ~7; +#if !defined(BLOCK_SIZE) + k_stride &= ~7; + v_stride &= ~7; +#endif + } + tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); + tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); + tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); + + coopmat Q; coopmat Qf16; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp index 68d1bc9f1..8dc9d360d 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp @@ -54,3 +54,23 @@ uint dst_idx(uint idx) { const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; } + +uint src0_idx_quant(uint idx, uint qk) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00; +} + +uint dst_idx_quant(uint idx, uint qk) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index 1426fde65..c9f855687 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -12,8 +12,8 @@ void main() { const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); #endif if (i00 >= p.ne00) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 24875cdcf..31ecd9f81 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -1,9 +1,6 @@ #version 450 -#ifdef FLOAT16 -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#endif -#extension 
GL_EXT_shader_explicit_arithmetic_types : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #include "mul_mat_vec_base.comp" @@ -27,8 +24,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const #if K_PER_ITER == 8 #if QUANT_R == 2 - const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; - const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]; + const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); #else @@ -136,8 +133,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void main() { const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); #endif // do NUM_ROWS at a time, unless there aren't enough remaining rows diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp new file mode 100644 index 000000000..e4acbd4f9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -0,0 +1,82 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint16_t[4] scales = data_a[ibi].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1)); + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1)); + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1); + + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp new file mode 100644 index 000000000..309da0991 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -0,0 +1,79 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint qh = data_a[ibi].qh[ib32]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 934213446..8cdc640e8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -1,10 +1,84 @@ #version 450 -#extension GL_EXT_shader_explicit_arithmetic_types : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #include "mul_mat_vec_base.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + barrier(); + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + } + barrier(); + + if (i >= num_blocks_per_row) + continue; + } 
else { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + barrier(); + } + + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im], + fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im], + fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im], + fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im], + fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im], + fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im], + fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -14,88 +88,28 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint step = 8; - - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... 
+ const uint v_in = itid - 8*v_im; // 0...7 const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; - const uint s_offset = 8*v_im; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - f16vec2 d = data_a[ib0 + i].d; - const FLOAT_TYPE dall = d.x; - const FLOAT_TYPE dmin = d.y; - - uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0]; - uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; - - uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; - uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F; - uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F; - uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F; - - uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32)); - uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32)); - uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); - uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); - - uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0]; - uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; - uvec2 qs0 = uvec2(unpack8(qs0_u16)); - uvec2 qs16 = uvec2(unpack8(qs16_u16)); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; - B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; - B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; - B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; - B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; - B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; - B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; - B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; - - FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); - FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), - fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), - fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), - fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), - fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), - fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), - fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3), - fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); - sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), - fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), - fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), - fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), - fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), - fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), - fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), - fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); - } - temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); - } - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + 
[[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 86b0159d9..3116fad16 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -1,10 +1,78 @@ #version 450 -#extension GL_EXT_shader_explicit_arithmetic_types : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #include "mul_mat_vec_base.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + barrier(); + if (i < num_blocks_per_row) + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); + + // 0, 1, 16, 17 + uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); + qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + if (all_threads) { + barrier(); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = 
vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -14,76 +82,37 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + const uint itid8 = itid%8; - const uint step = 8; + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_im4 = v_im*4; + const uint v_in = itid - 8*v_im; // 0...7 - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
- const uint v_in = itid - step*v_im; // 0...15 or 0...7 - - const uint8_t m = uint8_t(1 << (4 * v_im)); + const uint32_t m = 0x01010101 << (4 * v_im); + uint32_t hm_m[4]; + [[unroll]] for (uint j = 0; j < 4; ++j) + hm_m[j] = m << j; const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - const uint s_shift = 4 * v_im; + const uint s_shift = v_im4 + 2*(itid8/4); - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0]; - uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1]; - uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2]; - uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3]; - uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4]; - uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5]; - u8vec2 s0 = unpack8(s0_16); - u8vec2 s2 = unpack8(s2_16); - u8vec2 s4 = unpack8(s4_16); - u8vec2 s6 = unpack8(s6_16); - u8vec2 s8 = unpack8(s8_16); - u8vec2 s10 = unpack8(s10_16); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - - B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; - B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; - B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; - B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; - B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; - B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; - B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; - B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 
0 : 4)), - fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); - } - temp[j][n] = fma(d, sum, temp[j][n]); - } - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index cd1dd8e89..f9cde0648 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -1,11 +1,91 @@ #version 450 -#extension GL_EXT_shader_explicit_arithmetic_types : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #include "mul_mat_vec_base.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + const uint32_t qs64_u32 = data_a_packed32[ib0 + 
i].qs[q_offset / 4 + 16]; + + const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); + const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); + const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); + const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_lo4.x; + const FLOAT_TYPE q4_1 = qs0_lo4.y; + const FLOAT_TYPE q4_2 = qs0_lo4.z; + const FLOAT_TYPE q4_3 = qs0_lo4.w; + const FLOAT_TYPE q4_4 = qs0_hi4.x; + const FLOAT_TYPE q4_5 = qs0_hi4.y; + const FLOAT_TYPE q4_6 = qs0_hi4.z; + const FLOAT_TYPE q4_7 = qs0_hi4.w; + const FLOAT_TYPE q4_8 = qs64_lo4.x; + const FLOAT_TYPE q4_9 = qs64_lo4.y; + const FLOAT_TYPE q4_10 = qs64_lo4.z; + const FLOAT_TYPE q4_11 = qs64_lo4.w; + const FLOAT_TYPE q4_12 = qs64_hi4.x; + const FLOAT_TYPE q4_13 = qs64_hi4.y; + const FLOAT_TYPE q4_14 = qs64_hi4.z; + const FLOAT_TYPE q4_15 = qs64_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); + vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); + vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); + vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -15,13 +95,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint step = 4; - - const uint il = itid/step; // 0...3 - const uint ir = itid - step*il; // 0...7 or 0...3 + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 const uint n = 4; const uint v_im = il / 2; // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 @@ -31,89 +109,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - f16vec2 d = data_a[ib0 + i].d; - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - - uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec4 scale0 = uvec4(unpack8(scale0_u32)); - uvec4 scale4 = uvec4(unpack8(scale4_u32)); - uvec4 scale8 = uvec4(unpack8(scale8_u32)); - - const uint32_t sc0 = ( scale0.x & 0x3f); - const uint32_t sc1 = ( scale0.y & 0x3f); - const uint32_t sc2 = ( scale4.x & 0x3f); - const uint32_t sc3 = ( scale4.y & 0x3f); - const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); - - uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; - uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; - - uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; - uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; - uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; - uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; - - uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4)); - uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4)); - uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4)); - uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4)); - - const uint32_t q4_0 = qs0_lo4.x; - const uint32_t q4_1 = qs0_lo4.y; - const uint32_t q4_2 = qs0_lo4.z; - const uint32_t q4_3 = qs0_lo4.w; - const uint32_t q4_4 = qs0_hi4.x; - const uint32_t q4_5 = qs0_hi4.y; - const uint32_t q4_6 = qs0_hi4.z; - const uint32_t q4_7 = qs0_hi4.w; - const uint32_t q4_8 = qs64_lo4.x; - const uint32_t q4_9 = qs64_lo4.y; - const uint32_t q4_10 = qs64_lo4.z; - const uint32_t q4_11 = qs64_lo4.w; - const uint32_t q4_12 = qs64_hi4.x; - const uint32_t q4_13 = qs64_hi4.y; - const uint32_t q4_14 = qs64_hi4.z; - const uint32_t q4_15 = qs64_hi4.w; - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4]; - B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]; - B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4]; - B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]; - - const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); - const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); - const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); - const FLOAT_TYPE sw = 
fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, - fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, - fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, - fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); - } - } - } + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 0a68891c3..6c84ef3cd 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -1,11 +1,123 @@ #version 450 -#extension GL_EXT_shader_explicit_arithmetic_types : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #include "mul_mat_vec_base.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + + uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; + uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; + uint32_t 
qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); + const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + + const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); + const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); + const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); + const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_16_lo4.x; + const FLOAT_TYPE q4_1 = qs0_16_lo4.y; + const FLOAT_TYPE q4_2 = qs0_16_lo4.z; + const FLOAT_TYPE q4_3 = qs0_16_lo4.w; + const FLOAT_TYPE q4_4 = qs0_16_hi4.x; + const FLOAT_TYPE q4_5 = qs0_16_hi4.y; + const FLOAT_TYPE q4_6 = qs0_16_hi4.z; + const FLOAT_TYPE q4_7 = qs0_16_hi4.w; + const FLOAT_TYPE q4_8 = qs64_80_lo4.x; + const FLOAT_TYPE q4_9 = qs64_80_lo4.y; + const FLOAT_TYPE q4_10 = qs64_80_lo4.z; + const FLOAT_TYPE q4_11 = qs64_80_lo4.w; + const FLOAT_TYPE q4_12 = qs64_80_hi4.x; + const FLOAT_TYPE q4_13 = qs64_80_hi4.y; + const FLOAT_TYPE q4_14 = qs64_80_hi4.z; + const FLOAT_TYPE q4_15 = qs64_80_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); + vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); + vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); + vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); + vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); + vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); + vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); + vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, 
d_offset); @@ -15,11 +127,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; const uint il = itid/4; // 0...3 - const uint ir = itid - 4*il; // 0...7 or 0...3 + const uint ir = itid - 4*il; // 0...3 const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_in = il % 2; @@ -28,121 +140,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - f16vec2 d = data_a[ib0 + i].d; - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - - uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec4 scale0 = uvec4(unpack8(scale0_u32)); - uvec4 scale4 = uvec4(unpack8(scale4_u32)); - uvec4 scale8 = uvec4(unpack8(scale8_u32)); - - const uint32_t sc0 = ( scale0.x & 0x3f); - const uint32_t sc1 = ( scale0.y & 0x3f); - const uint32_t sc2 = ( scale4.x & 0x3f); - const uint32_t sc3 = ( scale4.y & 0x3f); - const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); - - uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); - uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); - - uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; - uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; - uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; - uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; - - uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); - - uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; - uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; - uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0; - uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; - - qs0_16_u32_lo4 += qs0_16_lo4_offset16; - qs0_16_u32_hi4 += qs0_16_hi4_offset16; - qs64_80_u32_lo4 += qs64_80_lo4_offset16; - qs64_80_u32_hi4 += qs64_80_hi4_offset16; - - uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); - uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); - uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); - uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4)); - - const uint32_t q4_0 = qs0_16_lo4.x; - const uint32_t q4_1 = qs0_16_lo4.y; - const 
uint32_t q4_2 = qs0_16_lo4.z; - const uint32_t q4_3 = qs0_16_lo4.w; - const uint32_t q4_4 = qs0_16_hi4.x; - const uint32_t q4_5 = qs0_16_hi4.y; - const uint32_t q4_6 = qs0_16_hi4.z; - const uint32_t q4_7 = qs0_16_hi4.w; - const uint32_t q4_8 = qs64_80_lo4.x; - const uint32_t q4_9 = qs64_80_lo4.y; - const uint32_t q4_10 = qs64_80_lo4.z; - const uint32_t q4_11 = qs64_80_lo4.w; - const uint32_t q4_12 = qs64_80_hi4.x; - const uint32_t q4_13 = qs64_80_hi4.y; - const uint32_t q4_14 = qs64_80_hi4.z; - const uint32_t q4_15 = qs64_80_hi4.w; - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2]; - B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]; - B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]; - B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]; - B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2]; - B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]; - B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]; - B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]; - - const FLOAT_TYPE sx = - fma(FLOAT_TYPE(by10.x), q4_0, - fma(FLOAT_TYPE(by10.y), q4_1, - fma(FLOAT_TYPE(by116.x), q4_2, - FLOAT_TYPE(by116.y) * q4_3))); - const FLOAT_TYPE sy = - fma(FLOAT_TYPE(by132.x), q4_4, - fma(FLOAT_TYPE(by132.y), q4_5, - fma(FLOAT_TYPE(by148.x), q4_6, - FLOAT_TYPE(by148.y) * q4_7))); - const FLOAT_TYPE sz = - fma(FLOAT_TYPE(by20.x), q4_8, - fma(FLOAT_TYPE(by20.y), q4_9, - fma(FLOAT_TYPE(by216.x), q4_10, - FLOAT_TYPE(by216.y) * q4_11))); - const FLOAT_TYPE sw = - fma(FLOAT_TYPE(by232.x), q4_12, - fma(FLOAT_TYPE(by232.y), q4_13, - fma(FLOAT_TYPE(by248.x), q4_14, - FLOAT_TYPE(by248.y) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, - fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, - fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, - (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); - } - } - } + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index 70e13a56b..f05f96b5e 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -1,12 +1,82 @@ #version 450 -#extension GL_EXT_shader_explicit_arithmetic_types : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #include "mul_mat_vec_base.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, 
const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + barrier(); + if (i < num_blocks_per_row) + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + const uint32_t qh4_u32 = (qh_u32 & 0x30303030); + const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + const vec4 q0 = vec4(unpack8(q0_u32)) - 32; + const vec4 q1 = vec4(unpack8(q1_u32)) - 32; + const vec4 q2 = vec4(unpack8(q2_u32)) - 32; + const vec4 q3 = vec4(unpack8(q3_u32)) - 32; + + if (all_threads) { + barrier(); + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]); + vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]); + vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); + vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); + + FLOAT_TYPE sum[4] = {0, 0, 0, 0}; + [[unroll]] for (uint l = 0; l < 4; ++l) { + sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]); + sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]); + sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); + sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); + } + temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); + } + } +} + +void compute_outputs(const uint first_row, const uint num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -15,13 +85,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint step = 8; - - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
- const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 const uint is = v_in / 4; @@ -31,68 +99,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint s_offset = 8*v_im + is; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - FLOAT_TYPE scales[4]; - scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); - scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); - scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); - scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); - - uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); - uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); - - uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; - uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; - uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; - uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; - - uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); - uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; - uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; - uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; - uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; - - uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; - uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; - uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; - uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; - - uvec4 q0 = uvec4(unpack8(q0_u32)); - uvec4 q1 = uvec4(unpack8(q1_u32)); - uvec4 q2 = uvec4(unpack8(q2_u32)); - uvec4 q3 = uvec4(unpack8(q3_u32)); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4]; - B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]; - B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]; - B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]; - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 4; ++l) { - sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), - fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), - fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), - fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); - } - temp[j][n] += sum * d; - } - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, 
num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 48122cbef..39657195c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -6,6 +6,9 @@ #ifdef FLOAT16 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif +#if defined(DATA_A_IQ1_M) +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif #ifdef COOPMAT #extension GL_KHR_cooperative_matrix : enable @@ -95,8 +98,8 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif void main() { -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); #endif #ifdef MUL_MAT_ID @@ -343,10 +346,8 @@ void main() { const uint qsshift = halfsplit * 2; // 0,2,4,6 const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 - const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) : - is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) : - is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) : - (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4)); + const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); const float dl = float(data_a[ib].d) * float(us - 32); buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); @@ -439,6 +440,187 @@ void main() { buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ1_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx % 128) / 4; + const int i8 = 2 * int(idx % 4); + + const float d = float(data_a[ib].d); + const uint qh = data_a[ib].qh[ib32]; + const uint qs = data_a[ib].qs[ib8]; + const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + const vec2 v = dl * (vec2(gvec) + delta); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ1_M) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib8 = (idx % 128) / 4; + const uint ib16 = ib8 / 2; + const int i8 = 2 * int(idx % 4); + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + const uint sc = scales[ib8 / 8]; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + const vec2 v = dl * (vec2(gvec) + delta); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ2_XXS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx / 4) % 4; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[8 * ib32 + ib8]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[8*ib32 + 4], + data_a[ib].qs[8*ib32 + 5], + data_a[ib].qs[8*ib32 + 6], + data_a[ib].qs[8*ib32 + 7] + )); + const float db = d * 0.25 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ2_XS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx / 4) % 4; // 0..3 + + const float d = float(data_a[ib].d); + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const float db = d * 0.25 * (0.5 + scale); + const uint qs = data_a[ib].qs[4 * ib32 + ib8]; + const uint sign7 = qs >> 9; + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ2_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / 
LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib8 = (idx % 128) / 4; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 + + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + + const float d = float(data_a[ib].d); + const float db = d * 0.25 * (0.5 + scale); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ3_XXS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = (idx % 128) / 2; // 0..63 + const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[is+0], + data_a[ib].qs[is+1], + data_a[ib].qs[is+2], + data_a[ib].qs[is+3] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ3_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = (idx % 128) / 2; // 0..63 + const uint iqh = iqs / 8; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint qh = data_a[ib].qh[iqh]; + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const uint scale = data_a[ib].scales[iqs / 16]; + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ4_XS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint iq = 16 * ib32 + 2 * (idx % 8); + + const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 8) >> 1; + u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float d = float(data_a[ib].d); + const vec2 v = d * float(int(sl | (sh << 4)) - 
32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index cbfa5dce1..66dd2c860 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -57,17 +57,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #if QUANT_K > 1 #define DECODEFUNCA , dequantFuncA -#define MAT_A_TYPE float16_t #include "dequant_funcs_cm2.comp" #else #define DECODEFUNCA -#define MAT_A_TYPE A_TYPE #endif -#define MAT_B_TYPE B_TYPE - #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; @@ -110,8 +106,8 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem #endif void main() { -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); #endif #ifdef MUL_MAT_ID @@ -236,16 +232,13 @@ void main() { for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopmat mat_a_ft = coopmat(mat_a); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); - coopmat mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } } else #endif // !defined(MUL_MAT_ID) @@ -261,10 +254,8 @@ void main() { [[dont_unroll]] for (uint block_k = start_k; block_k < end_k; block_k += BK) { - coopmat mat_a; - coopmat mat_b; - coopmat mat_a_ft; - coopmat mat_b_ft; + coopmat mat_a; + coopmat mat_b; // Clamping is expensive, so detect different code paths for each combination // of A and B needing clamping. 
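// [editor's note, not part of the upstream patch] The mul_mm_cm2.comp hunks in this file
// drop the intermediate "_ft" copies: previously each tile was loaded into a coopmat of the
// storage type and then converted into a second float16 coopmat before coopMatMulAdd; after
// this change the tiles are declared directly in the accumulator-compatible element type and
// multiplied in place. A minimal sketch of the resulting per-iteration pattern (template
// parameters elided, assuming the same clamped/unclamped tensor layouts used above):
//
//   coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
//   coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
//   sum = coopMatMulAdd(mat_a, mat_b, sum);
//
// The four branches in the hunk below differ only in which clamped or unclamped layout each
// load uses.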
@@ -281,16 +272,12 @@ void main() { #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); #endif - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (unclampedA && !unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (!unclampedA && unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID @@ -298,16 +285,12 @@ void main() { #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); #endif - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (!unclampedA && !unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp new file mode 100644 index 000000000..e0214fe76 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) buffer X {A_TYPE x[];}; +layout (binding = 1) readonly buffer G {A_TYPE grad[];}; +layout (binding = 2) buffer GM {A_TYPE gradm[];}; +layout (binding = 3) buffer GV {A_TYPE gradv[];}; +layout (binding = 4) readonly buffer P {float params[7];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float alpha = params[0]; + const float beta1 = params[1]; + const float beta2 = params[2]; + const float eps = params[3]; + const float wd = params[4]; + const float beta1h = params[5]; + const float beta2h = params[6]; + + const float gi = grad[i]; + const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1); + const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2); + + gradm[i] = gmi; + gradv[i] = gvi; + + const float mh = gmi*beta1h; + const float vh = sqrt(gvi*beta2h) + eps; + + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp new file mode 100644 index 000000000..d86279934 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp @@ -0,0 +1,37 @@ +#version 450 + +#include 
"types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + // Destination multi-index (inlined dst_idx) + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; + + // Accumulate from sources + A_TYPE acc = A_TYPE(0); + for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) { + for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) { + for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) { + for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) { + acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00]; + } + } + } + } + + data_d[get_doffset() + d_idx] = D_TYPE(acc); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp new file mode 100644 index 000000000..76009f3df --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp @@ -0,0 +1,55 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer G {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum_xx[BLOCK_SIZE]; +shared FLOAT_TYPE sum_xg[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + // Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5 + + // partial sums for thread in warp + sum_xx[tid] = FLOAT_TYPE(0.0f); + sum_xg[tid] = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]); + const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]); + sum_xx[tid] += xi * xi; + sum_xg[tid] += xi * gi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum_xx[tid] += sum_xx[tid + s]; + sum_xg[tid] += sum_xg[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE eps = FLOAT_TYPE(p.param1); + const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE scale_g = inversesqrt(mean + eps); + const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE( + scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) + + scale_x * FLOAT_TYPE(data_b[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp index 574b51ca5..96c9c4cbd 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -25,6 +25,11 @@ layout (push_constant) uniform parameter { 
float corr_dims[2]; float theta_scale; uint has_ff; + uint ne02; + uint s1; + uint s2; + int sections[4]; + uint is_back; } p; float rope_yarn_ramp(const float low, const float high, const uint i0) { @@ -44,6 +49,10 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out // Get n-d magnitude scaling corrected for interpolation mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); } + // Backprogagation uses inverted rotation + if (p.is_back != 0) { + theta = -theta; + } cos_theta = cos(theta) * mscale; sin_theta = sin(theta) * mscale; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp new file mode 100644 index 000000000..4f5b1a0ec --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -0,0 +1,60 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + + const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 83b46b69b..db775c456 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -3,15 +3,18 @@ #include "rope_head.comp" void main() { - const uint col = gl_GlobalInvocationID.y * 2; - const uint row = gl_GlobalInvocationID.x; + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; - if (col >= p.ncols) { + if (i0 >= ne0) { return; } - if (col >= p.n_dims) { - const uint i = row*p.ncols + col; + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; data_d[i + 0] = data_a[i + 0]; data_d[i + 1] = data_a[i + 1]; @@ -19,19 +22,22 @@ void main() { return; } - const uint i = row*p.ncols + col/2; - const uint i2 = row/p.p_delta_rows; + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; - const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - const float x0 = float(data_a[i + 0]); - const float x1 = float(data_a[i + p.n_dims/2]); + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); - data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index e416ad938..4ad35e549 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -3,15 +3,18 @@ #include "rope_head.comp" void main() { - const uint col = gl_GlobalInvocationID.y * 2; - const uint row = gl_GlobalInvocationID.x; + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; - if (col >= p.ncols) { + if (i0 >= ne0) { return; } - if (col >= p.n_dims) { - const uint i = row*p.ncols + col; + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; data_d[i + 0] = data_a[i + 0]; data_d[i + 1] = data_a[i + 1]; @@ -19,19 +22,22 @@ void main() { return; } - const uint i = row*p.ncols + col; - const uint i2 = row/p.p_delta_rows; + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; - const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + const uint idst = row_dst*ne0 + i0; + const uint ix = 
channel_x*p.s2 + row_x*p.s1 + i0; - const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - const float x0 = float(data_a[i + 0]); - const float x1 = float(data_a[i + 1]); + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + 1]); - data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp new file mode 100644 index 000000000..cedacc4d1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -0,0 +1,47 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const int sect_dims = p.sections[0] + p.sections[1]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + const uint p0 = sector; + theta_base = data_pos[channel_x]*pow(p.theta_scale, p0); + } + else if (sector >= p.sections[0] && sector < sec_w) { + const uint p0 = sector - p.sections[0]; + theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0); + } + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp new file mode 100644 index 000000000..776581e2c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(1. / (1 + exp(-1. 
*data_a[i]))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp new file mode 100644 index 000000000..f9afa9b13 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_x[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // Compute derivative of SiLU(x): 1/(1+exp(-x)) - x*exp(-x)/(1+exp(-x))^2 + + const float xi = float(data_x[i]); + const float s = 1.0f / (1.0f + exp(-xi)); + data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index a25808e16..51fc2dc7e 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -1,6 +1,5 @@ #version 450 -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_control_flow_attributes : enable layout (push_constant) uniform parameter diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp new file mode 100644 index 000000000..29bd77d7e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -0,0 +1,50 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "generic_head.comp" +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// In this shader Y = softmax(X) and X is not provided as input. 
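// [editor's note, not part of the upstream patch] Derivation sketch for the shader below:
// with y = softmax(scale * x) and incoming gradient g, the Jacobian-vector product is
//
//   dL/dx_i = scale * y_i * (g_i - sum_j g_j * y_j)
//
// so the workgroup first reduces dot(y, g) into sum_yg[0] (dot_yg) via shared memory, then
// stores scale * (g_i - dot_yg) * y_i per element, matching the final write in main().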
+ +layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_y[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum_yg[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + FLOAT_TYPE scale = p.param1; + + // partial sums for thread in warp + sum_yg[tid] = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]); + const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]); + sum_yg[tid] += yi * gi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum_yg[tid] += sum_yg[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE dot_yg = sum_yg[0]; + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale + * (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg) + * FLOAT_TYPE(data_y[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp new file mode 100644 index 000000000..72353cc32 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp @@ -0,0 +1,29 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp new file mode 100644 index 000000000..8c5dd1bd1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_KHR_cooperative_matrix : require + +void main() +{ +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index eecc47f3a..dfa16cda5 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -2,7 +2,10 @@ #if !defined(GGML_TYPES_COMP) #define GGML_TYPES_COMP -#extension GL_EXT_shader_explicit_arithmetic_types : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_16bit_storage : require #if defined(DATA_A_F32) #define QUANT_K 1 @@ -224,6 +227,11 @@ struct block_q4_K_packed32 uint32_t qs[QUANT_K_Q4_K/2/4]; }; +struct block_q4_K_packed128 +{ + uvec4 q4k[9]; +}; + #if defined(DATA_A_Q4_K) #define QUANT_K QUANT_K_Q4_K 
#define A_TYPE block_q4_K @@ -249,6 +257,11 @@ struct block_q5_K_packed16 uint16_t qs[QUANT_K_Q5_K/2/2]; }; +struct block_q5_K_packed128 +{ + uvec4 q5k[11]; +}; + #if defined(DATA_A_Q5_K) #define QUANT_K QUANT_K_Q5_K #define A_TYPE block_q5_K @@ -281,6 +294,941 @@ struct block_q6_K_packed16 // IQuants +#define QUANT_K_IQ1_S 256 +#define QUANT_R_IQ1_S 1 + +struct block_iq1_s { + float16_t d; + uint8_t qs[QUANT_K_IQ1_S/8]; + uint16_t qh[QUANT_K_IQ1_S/32]; +}; + +#define QUANT_K_IQ1_M 256 +#define QUANT_R_IQ1_M 1 + +struct block_iq1_m { + uint8_t qs[QUANT_K_IQ1_M/8]; + uint8_t qh[QUANT_K_IQ1_M/16]; + uint16_t scales[QUANT_K_IQ1_M/64]; +}; + +#if defined(DATA_A_IQ1_S) +#define QUANT_K QUANT_K_IQ1_S +#define QUANT_R QUANT_R_IQ1_S +#define A_TYPE block_iq1_s +#endif + +#if defined(DATA_A_IQ1_M) +#define QUANT_K QUANT_K_IQ1_M +#define QUANT_R QUANT_R_IQ1_M +#define A_TYPE block_iq1_m +#endif + +#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f + +// Packed IQ1S grid where every 2 vec8 are encoded on 32 bits (2 bits per coordinate). +const uint[1024] iq1s_grid_const = { + 0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01, + 0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4, + 0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41, + 0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f, + 0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334, + 0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f, + 0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040, + 0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f, + 0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5, + 0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3, + 0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff, + 0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570, + 0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f, + 0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf, + 0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f, + 0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07, + 0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc, + 0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374, + 0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0, + 0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001, + 0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043, + 0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc, + 0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117, + 0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f, + 0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 
0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5, + 0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474, + 0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d, + 0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd, + 0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50, + 0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10, + 0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30, + 0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1, + 0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c, + 0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074, + 0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134, + 0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7, + 0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3, + 0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450, + 0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577, + 0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c, + 0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5, + 0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c, + 0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00, + 0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300, + 0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc, + 0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034, + 0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077, + 0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5, + 0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117, + 0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f, + 0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5, + 0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404, + 0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1, + 0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd, + 0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71, + 0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7, + 0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00, + 0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44, + 0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00, + 0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0, + 0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 
0x03000303, + 0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343, + 0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd, + 0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031, + 0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011, + 0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c, + 0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4, + 0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c, + 0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174, + 0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7, + 0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d, + 0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4, + 0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c, + 0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7, + 0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510, + 0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33, + 0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4, + 0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73, + 0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f, + 0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337, + 0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343, + 0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030, + 0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075, + 0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4, + 0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170, + 0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705, + 0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c, + 0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c, + 0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514, + 0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c, + 0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3, + 0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70, + 0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03, + 0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c, + 0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c, + 0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074, + 0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104, + 0x7111711c, 
0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7, + 0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757, + 0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c, + 0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c, + 0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4, + 0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc, + 0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03, + 0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc, + 0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54, + 0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f, + 0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf, + 0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c, + 0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c, + 0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4, + 0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174, + 0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700, + 0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7, + 0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d, + 0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531, + 0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf, + 0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57, + 0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13, + 0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01, + 0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f, + 0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7, + 0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074, + 0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107, + 0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd, + 0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0, + 0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7, + 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 +}; + +shared uint16_t iq1s_grid[2048]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq1s_grid_const.length(); i += wgsize.x) { + u16vec2 g = unpack16(iq1s_grid_const[i]); + iq1s_grid[2*i+0] = g.x; + iq1s_grid[2*i+1] = g.y; + } + barrier(); +} +#endif + +#define QUANT_K_IQ2_XXS 256 +#define QUANT_R_IQ2_XXS 1 + +struct block_iq2_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_XXS/4]; +}; + +struct 
block_iq2_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XXS/8]; +}; + +#if defined(DATA_A_IQ2_XXS) + +const uvec2[256] iq2xxs_grid_const = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x082b0808, 0x08080808), + uvec2(0x082b082b, 0x08080808), uvec2(0x082b2b08, 0x08080808), uvec2(0x082b2b2b, 0x08080808), uvec2(0x19080819, 0x08080808), + uvec2(0x19081908, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), + uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b082b2b, 0x08080808), + uvec2(0x2b2b082b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x2b081908, 0x08080819), uvec2(0x2b192b08, 0x08080819), + uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x082b082b, 0x0808082b), uvec2(0x2b08082b, 0x0808082b), + uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x082b0819, 0x08081908), + uvec2(0x082b1908, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x192b0808, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), + uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), uvec2(0x08082b08, 0x08081919), + uvec2(0x082b0808, 0x08081919), uvec2(0x1908192b, 0x08081919), uvec2(0x192b2b19, 0x08081919), uvec2(0x2b080808, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b2b1908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x08081919, 0x08082b08), + uvec2(0x08082b08, 0x08082b08), uvec2(0x08191908, 0x08082b08), uvec2(0x082b2b08, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x2b082b08, 0x08082b08), + uvec2(0x08081908, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x0808082b, 0x08082b2b), uvec2(0x08191908, 0x08082b2b), + uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x082b0819, 0x08190808), + uvec2(0x19080808, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), + uvec2(0x2b191919, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x082b0808, 0x08190819), + uvec2(0x19190808, 0x08190819), uvec2(0x19192b2b, 0x08190819), uvec2(0x2b080808, 0x08190819), uvec2(0x082b1908, 0x0819082b), + uvec2(0x19081919, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x08082b08, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x082b1919, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08192b08, 0x08191919), + uvec2(0x192b082b, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x0819192b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x2b2b0808, 0x08192b19), uvec2(0x19190819, 0x08192b2b), 
+ uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x19081908, 0x082b0808), + uvec2(0x192b0819, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b08082b, 0x082b0808), uvec2(0x082b2b19, 0x082b0819), + uvec2(0x19082b08, 0x082b0819), uvec2(0x08080808, 0x082b082b), uvec2(0x0808082b, 0x082b082b), uvec2(0x08080819, 0x082b1908), + uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x19080808, 0x082b1908), uvec2(0x1919192b, 0x082b1908), + uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x192b1908, 0x082b1919), uvec2(0x2b190808, 0x082b192b), + uvec2(0x08082b08, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), uvec2(0x2b191908, 0x082b2b08), uvec2(0x19081908, 0x082b2b2b), + uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x08192b08, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x19080808, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x1919192b, 0x19080808), uvec2(0x192b0808, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), + uvec2(0x2b190808, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x192b0819, 0x19080819), + uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x19082b08, 0x1908082b), uvec2(0x1919192b, 0x1908082b), uvec2(0x192b2b08, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b192b19, 0x19081908), + uvec2(0x0819082b, 0x19081919), uvec2(0x082b1908, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08080819, 0x19082b08), + uvec2(0x08081908, 0x19082b08), uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x08080808, 0x19082b19), uvec2(0x19192b08, 0x19082b19), uvec2(0x192b0819, 0x19082b19), uvec2(0x2b08082b, 0x19082b19), + uvec2(0x19081919, 0x19082b2b), uvec2(0x2b190808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x08082b08, 0x19190808), + uvec2(0x08190819, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x2b080808, 0x19190808), + uvec2(0x2b082b08, 0x19190808), uvec2(0x08081908, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x2b2b1908, 0x19190819), + uvec2(0x2b190819, 0x1919082b), uvec2(0x2b190808, 0x19191908), uvec2(0x2b19082b, 0x19191908), uvec2(0x08082b2b, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x19191908, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x08190819, 0x19192b08), + uvec2(0x08192b19, 0x19192b08), uvec2(0x192b1908, 0x19192b08), uvec2(0x19080808, 0x19192b19), uvec2(0x08082b08, 0x19192b2b), + uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x192b2b08, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x19191919, 0x192b0819), uvec2(0x08192b08, 0x192b082b), uvec2(0x192b0808, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x08081919, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x0819082b, 0x192b1919), + uvec2(0x2b081908, 0x192b1919), uvec2(0x1908082b, 0x192b2b08), uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), + uvec2(0x08082b2b, 0x2b080808), uvec2(0x19080819, 0x2b080808), uvec2(0x2b08082b, 0x2b080808), uvec2(0x08081908, 0x2b080819), + uvec2(0x08192b08, 0x2b080819), uvec2(0x19080808, 0x2b080819), uvec2(0x08190819, 0x2b08082b), uvec2(0x08080819, 0x2b081908), + uvec2(0x08081908, 
0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x192b0808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x1908192b, 0x2b081919), uvec2(0x2b191908, 0x2b081919), + uvec2(0x08082b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x192b0808, 0x2b08192b), uvec2(0x0808082b, 0x2b082b08), + uvec2(0x08081908, 0x2b082b19), uvec2(0x08190819, 0x2b082b2b), uvec2(0x08081908, 0x2b190808), uvec2(0x08190808, 0x2b190808), + uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x2b2b0819, 0x2b190808), uvec2(0x0819192b, 0x2b190819), + uvec2(0x2b080808, 0x2b190819), uvec2(0x19081919, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x082b082b, 0x2b191908), + uvec2(0x19081908, 0x2b191908), uvec2(0x19190819, 0x2b191919), uvec2(0x2b080819, 0x2b192b08), uvec2(0x082b0808, 0x2b192b19), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b081919, 0x2b2b0808), uvec2(0x08082b19, 0x2b2b0819), + uvec2(0x08080808, 0x2b2b082b), uvec2(0x08192b08, 0x2b2b1908), uvec2(0x19190808, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19) +}; + +shared uvec2 iq2xxs_grid[256]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) { + iq2xxs_grid[i] = iq2xxs_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XXS +#define QUANT_R QUANT_R_IQ2_XXS +#define A_TYPE block_iq2_xxs +#define A_TYPE_PACKED16 block_iq2_xxs_packed16 +#endif + +#define QUANT_K_IQ2_XS 256 +#define QUANT_R_IQ2_XS 1 + +struct block_iq2_xs +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint8_t scales[QUANT_K_IQ2_XS/32]; +}; + +struct block_iq2_xs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint16_t scales[QUANT_K_IQ2_XS/64]; +}; + +#if defined(DATA_A_IQ2_XS) + +const uvec2 iq2xs_grid_const[512] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), + uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), + uvec2(0x2b191908, 0x08080808), uvec2(0x2b192b19, 0x08080808), uvec2(0x2b2b0808, 0x08080808), uvec2(0x08080819, 0x08080819), + uvec2(0x08081908, 0x08080819), uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x0819082b, 0x08080819), uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x08192b2b, 0x08080819), + uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), + uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), 
uvec2(0x2b081908, 0x08080819), + uvec2(0x2b190808, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x08081919, 0x0808082b), + uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), uvec2(0x082b0808, 0x0808082b), + uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), + uvec2(0x0808192b, 0x08081908), uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), + uvec2(0x08191919, 0x08081908), uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), + uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), uvec2(0x1919192b, 0x08081908), uvec2(0x192b0808, 0x08081908), + uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x08080808, 0x08081919), + uvec2(0x0808082b, 0x08081919), uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x082b0808, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x19190808, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x2b080808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x082b192b, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x1908082b, 0x0808192b), uvec2(0x2b081908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08082b2b, 0x08082b08), uvec2(0x08190819, 0x08082b08), + uvec2(0x08191908, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x19192b08, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b2b0808, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), uvec2(0x08081908, 0x08082b19), + uvec2(0x08190808, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x2b080819, 0x08082b19), uvec2(0x2b082b19, 0x08082b19), + uvec2(0x08080808, 0x08082b2b), uvec2(0x082b0808, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x2b19192b, 0x08082b2b), + uvec2(0x2b2b0808, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x0808192b, 0x08190808), + uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), uvec2(0x08191919, 0x08190808), + uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), uvec2(0x19080808, 0x08190808), + uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), + uvec2(0x19191908, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b2b2b, 0x08190808), uvec2(0x2b080819, 0x08190808), + uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), + uvec2(0x08081919, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x082b0808, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), uvec2(0x19190808, 
0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x2b19192b, 0x08190819), uvec2(0x08080819, 0x0819082b), + uvec2(0x08081908, 0x0819082b), uvec2(0x0808192b, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x19080808, 0x0819082b), + uvec2(0x192b0808, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), + uvec2(0x08082b08, 0x08191908), uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x19080819, 0x08191908), uvec2(0x19081908, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x08191908, 0x0819192b), + uvec2(0x19082b19, 0x0819192b), uvec2(0x08080819, 0x08192b08), uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x0819082b, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x19191908, 0x08192b08), uvec2(0x2b08192b, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x192b192b, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x2b2b2b19, 0x08192b2b), uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), + uvec2(0x08082b08, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x19080819, 0x082b0808), uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), + uvec2(0x2b080808, 0x082b0808), uvec2(0x2b2b0808, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), + uvec2(0x08190808, 0x082b0819), uvec2(0x19080808, 0x082b0819), uvec2(0x19082b08, 0x082b0819), uvec2(0x192b1919, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x2b080808, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), + uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x082b2b19, 0x082b1908), + uvec2(0x19080808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x1919082b, 0x082b1919), + uvec2(0x2b192b19, 0x082b1919), uvec2(0x08080819, 0x082b192b), uvec2(0x08192b2b, 0x082b192b), uvec2(0x2b2b192b, 0x082b192b), + uvec2(0x08080808, 0x082b2b08), uvec2(0x08082b08, 0x082b2b08), uvec2(0x08082b2b, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), + uvec2(0x19191919, 0x082b2b08), uvec2(0x2b082b08, 0x082b2b08), uvec2(0x2b2b082b, 0x082b2b08), uvec2(0x192b2b08, 0x082b2b19), + uvec2(0x2b190808, 0x082b2b19), uvec2(0x08082b08, 0x082b2b2b), uvec2(0x082b0808, 0x082b2b2b), uvec2(0x2b08082b, 0x082b2b2b), + uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), + uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x0819082b, 0x19080808), + uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), + uvec2(0x19080808, 0x19080808), uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x19082b2b, 0x19080808), uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x192b0808, 0x19080808), + uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), + 
uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), + uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x19080819, 0x19080819), + uvec2(0x19081908, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), + uvec2(0x2b2b082b, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x0819082b, 0x1908082b), uvec2(0x082b2b19, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), uvec2(0x08082b08, 0x19081908), uvec2(0x08190819, 0x19081908), + uvec2(0x08191908, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x19080819, 0x19081908), + uvec2(0x19081908, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b191908, 0x19081908), + uvec2(0x08080819, 0x19081919), uvec2(0x08081908, 0x19081919), uvec2(0x08190808, 0x19081919), uvec2(0x082b1908, 0x19081919), + uvec2(0x19080808, 0x19081919), uvec2(0x2b192b2b, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08082b2b, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), + uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), uvec2(0x19191908, 0x19082b08), + uvec2(0x192b082b, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x19081908, 0x19082b19), + uvec2(0x19190808, 0x19082b19), uvec2(0x192b2b19, 0x19082b19), uvec2(0x08081908, 0x19082b2b), uvec2(0x08080808, 0x19190808), + uvec2(0x0808082b, 0x19190808), uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), + uvec2(0x08191908, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), + uvec2(0x19081908, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x2b080808, 0x19190808), uvec2(0x08080819, 0x19190819), + uvec2(0x08081908, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x08191919, 0x19190819), uvec2(0x19080808, 0x19190819), + uvec2(0x1908082b, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x2b2b2b2b, 0x1919082b), + uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x082b0819, 0x19191908), + uvec2(0x19080808, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b2b0819, 0x19191908), + uvec2(0x08080808, 0x19191919), uvec2(0x08082b08, 0x19191919), uvec2(0x2b080808, 0x19191919), uvec2(0x2b082b08, 0x19191919), + uvec2(0x082b0819, 0x1919192b), uvec2(0x192b2b08, 0x1919192b), uvec2(0x2b2b0819, 0x1919192b), uvec2(0x08080808, 0x19192b08), + uvec2(0x08191908, 0x19192b08), uvec2(0x19080819, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x2b192b19, 0x19192b08), + uvec2(0x08192b2b, 0x19192b19), uvec2(0x19080808, 0x19192b19), uvec2(0x1908082b, 0x19192b19), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), + uvec2(0x19191908, 0x192b0808), uvec2(0x192b082b, 0x192b0808), uvec2(0x2b08192b, 0x192b0808), uvec2(0x2b2b2b19, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x082b1908, 0x192b082b), uvec2(0x19082b2b, 0x192b082b), uvec2(0x2b19082b, 0x192b082b), + uvec2(0x08080808, 
0x192b1908), uvec2(0x0819192b, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x19080808, 0x192b1919), + uvec2(0x19081919, 0x192b1919), uvec2(0x2b2b1908, 0x192b1919), uvec2(0x08080819, 0x192b2b08), uvec2(0x192b2b2b, 0x192b2b08), + uvec2(0x082b1919, 0x192b2b19), uvec2(0x0808192b, 0x192b2b2b), uvec2(0x19191908, 0x192b2b2b), uvec2(0x192b082b, 0x192b2b2b), + uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), + uvec2(0x08190819, 0x2b080808), uvec2(0x08191908, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b2b2b, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b08082b, 0x2b080808), uvec2(0x2b2b2b08, 0x2b080808), uvec2(0x2b2b2b2b, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x0808192b, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x19080808, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19192b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x082b0808, 0x2b08082b), + uvec2(0x2b080808, 0x2b08082b), uvec2(0x2b08082b, 0x2b08082b), uvec2(0x2b2b0808, 0x2b08082b), uvec2(0x2b2b2b08, 0x2b08082b), + uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b082b19, 0x2b081908), + uvec2(0x08080808, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x2b2b1919, 0x2b081919), uvec2(0x08192b08, 0x2b08192b), + uvec2(0x192b2b2b, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08082b08, 0x2b082b08), uvec2(0x082b1919, 0x2b082b08), + uvec2(0x19192b2b, 0x2b082b08), uvec2(0x2b080808, 0x2b082b08), uvec2(0x2b08082b, 0x2b082b08), uvec2(0x2b2b2b08, 0x2b082b08), + uvec2(0x0808192b, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x2b080808, 0x2b082b2b), uvec2(0x2b082b08, 0x2b082b2b), + uvec2(0x2b19192b, 0x2b082b2b), uvec2(0x2b2b2b08, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), uvec2(0x08081908, 0x2b190808), + uvec2(0x08190808, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x1919192b, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x08080808, 0x2b190819), uvec2(0x082b082b, 0x2b190819), uvec2(0x192b1908, 0x2b190819), uvec2(0x1919192b, 0x2b19082b), + uvec2(0x2b082b19, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x08081919, 0x2b191908), uvec2(0x19081908, 0x2b191908), + uvec2(0x19190808, 0x2b191908), uvec2(0x19192b08, 0x2b191908), uvec2(0x082b2b19, 0x2b191919), uvec2(0x2b190808, 0x2b191919), + uvec2(0x2b19082b, 0x2b191919), uvec2(0x19080819, 0x2b19192b), uvec2(0x19190819, 0x2b192b08), uvec2(0x2b2b192b, 0x2b192b08), + uvec2(0x19082b19, 0x2b192b19), uvec2(0x08191919, 0x2b192b2b), uvec2(0x192b0808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x08082b08, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), uvec2(0x082b0808, 0x2b2b0808), + uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x2b2b0808, 0x2b2b0808), uvec2(0x19190819, 0x2b2b0819), uvec2(0x19192b19, 0x2b2b0819), + uvec2(0x2b2b192b, 0x2b2b0819), uvec2(0x08080808, 0x2b2b082b), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b08, 0x2b2b082b), + uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b080808, 0x2b2b082b), uvec2(0x2b2b0808, 0x2b2b082b), uvec2(0x19080808, 0x2b2b1908), + uvec2(0x2b191919, 0x2b2b1908), uvec2(0x192b1919, 0x2b2b192b), uvec2(0x2b192b08, 0x2b2b192b), uvec2(0x08082b2b, 0x2b2b2b08), + uvec2(0x082b0808, 0x2b2b2b08), 
uvec2(0x082b082b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b0808, 0x2b2b2b08), + uvec2(0x2b2b2b08, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19), uvec2(0x2b081908, 0x2b2b2b19), uvec2(0x2b08192b, 0x2b2b2b19), + uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x082b2b2b, 0x2b2b2b2b), uvec2(0x2b190819, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b), +}; + +shared uvec2 iq2xs_grid[512]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) { + iq2xs_grid[i] = iq2xs_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XS +#define QUANT_R QUANT_R_IQ2_XS +#define A_TYPE block_iq2_xs +#define A_TYPE_PACKED16 block_iq2_xs_packed16 +#endif + +#define QUANT_K_IQ2_S 256 +#define QUANT_R_IQ2_S 1 + +struct block_iq2_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_S/4]; + uint8_t qh[QUANT_K_IQ2_S/32]; + uint8_t scales[QUANT_K_IQ2_S/32]; +}; + +#if defined(DATA_A_IQ2_S) + +const uvec2 iq2s_grid_const[1024] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x192b192b, 0x08080808), + uvec2(0x192b2b19, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), + uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), uvec2(0x2b191908, 0x08080808), uvec2(0x2b2b0808, 0x08080808), + uvec2(0x2b2b1919, 0x08080808), uvec2(0x2b2b2b2b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), + uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), uvec2(0x0819082b, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), + uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), + uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), uvec2(0x1919192b, 0x08080819), uvec2(0x19192b19, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b1919, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), + uvec2(0x2b081908, 0x08080819), uvec2(0x2b190808, 0x08080819), uvec2(0x2b19082b, 0x08080819), uvec2(0x2b191919, 0x08080819), + uvec2(0x2b2b0819, 0x08080819), uvec2(0x2b2b1908, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), + uvec2(0x08081919, 0x0808082b), uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), + uvec2(0x082b0808, 0x0808082b), uvec2(0x082b2b2b, 0x0808082b), uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), + uvec2(0x1908192b, 0x0808082b), uvec2(0x19082b19, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b081919, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), 
uvec2(0x2b191908, 0x0808082b), + uvec2(0x2b2b082b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x0808192b, 0x08081908), + uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), uvec2(0x08191919, 0x08081908), + uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), uvec2(0x082b192b, 0x08081908), + uvec2(0x082b2b19, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), + uvec2(0x19082b08, 0x08081908), uvec2(0x19082b2b, 0x08081908), uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), + uvec2(0x1919192b, 0x08081908), uvec2(0x19192b19, 0x08081908), uvec2(0x192b0808, 0x08081908), uvec2(0x192b082b, 0x08081908), + uvec2(0x192b1919, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b08192b, 0x08081908), + uvec2(0x2b082b19, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x2b191919, 0x08081908), uvec2(0x2b192b08, 0x08081908), + uvec2(0x2b2b0819, 0x08081908), uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), + uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08082b2b, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x0819192b, 0x08081919), uvec2(0x08192b19, 0x08081919), uvec2(0x082b0808, 0x08081919), + uvec2(0x082b1919, 0x08081919), uvec2(0x082b2b08, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x1908192b, 0x08081919), uvec2(0x19082b19, 0x08081919), uvec2(0x19190808, 0x08081919), uvec2(0x1919082b, 0x08081919), + uvec2(0x19191919, 0x08081919), uvec2(0x19192b08, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x192b1908, 0x08081919), + uvec2(0x2b080808, 0x08081919), uvec2(0x2b08082b, 0x08081919), uvec2(0x2b081919, 0x08081919), uvec2(0x2b082b08, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x2b191908, 0x08081919), uvec2(0x2b2b0808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x0808192b, 0x0808192b), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), + uvec2(0x08191919, 0x0808192b), uvec2(0x19080808, 0x0808192b), uvec2(0x19081919, 0x0808192b), uvec2(0x19082b08, 0x0808192b), + uvec2(0x19190819, 0x0808192b), uvec2(0x19191908, 0x0808192b), uvec2(0x192b0808, 0x0808192b), uvec2(0x2b080819, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b190808, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08190819, 0x08082b08), uvec2(0x08191908, 0x08082b08), + uvec2(0x0819192b, 0x08082b08), uvec2(0x08192b19, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), + uvec2(0x082b2b2b, 0x08082b08), uvec2(0x19080819, 0x08082b08), uvec2(0x19081908, 0x08082b08), uvec2(0x1908192b, 0x08082b08), + uvec2(0x19082b19, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x19191919, 0x08082b08), + uvec2(0x19192b08, 0x08082b08), uvec2(0x192b0819, 0x08082b08), uvec2(0x192b1908, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b081919, 0x08082b08), uvec2(0x2b191908, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), + uvec2(0x08081908, 0x08082b19), uvec2(0x08190808, 0x08082b19), uvec2(0x0819082b, 0x08082b19), uvec2(0x08191919, 0x08082b19), + uvec2(0x08192b08, 0x08082b19), uvec2(0x082b0819, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x19081919, 
0x08082b19), + uvec2(0x19082b08, 0x08082b19), uvec2(0x19190819, 0x08082b19), uvec2(0x19191908, 0x08082b19), uvec2(0x192b0808, 0x08082b19), + uvec2(0x2b080819, 0x08082b19), uvec2(0x2b190808, 0x08082b19), uvec2(0x08080808, 0x08082b2b), uvec2(0x08190819, 0x08082b2b), + uvec2(0x08191908, 0x08082b2b), uvec2(0x082b082b, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x082b2b2b, 0x08082b2b), + uvec2(0x19190808, 0x08082b2b), uvec2(0x2b192b19, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), + uvec2(0x0808192b, 0x08190808), uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), + uvec2(0x08191919, 0x08190808), uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), + uvec2(0x082b192b, 0x08190808), uvec2(0x19080808, 0x08190808), uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), + uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), uvec2(0x19191908, 0x08190808), uvec2(0x1919192b, 0x08190808), + uvec2(0x19192b19, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b082b, 0x08190808), uvec2(0x192b1919, 0x08190808), + uvec2(0x192b2b08, 0x08190808), uvec2(0x2b080819, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b08192b, 0x08190808), + uvec2(0x2b190808, 0x08190808), uvec2(0x2b191919, 0x08190808), uvec2(0x2b192b08, 0x08190808), uvec2(0x2b2b0819, 0x08190808), + uvec2(0x2b2b1908, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), uvec2(0x08081919, 0x08190819), + uvec2(0x08082b08, 0x08190819), uvec2(0x08082b2b, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x0819192b, 0x08190819), uvec2(0x08192b19, 0x08190819), uvec2(0x082b0808, 0x08190819), uvec2(0x082b082b, 0x08190819), + uvec2(0x082b1919, 0x08190819), uvec2(0x082b2b08, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), + uvec2(0x1908192b, 0x08190819), uvec2(0x19082b19, 0x08190819), uvec2(0x19190808, 0x08190819), uvec2(0x1919082b, 0x08190819), + uvec2(0x19191919, 0x08190819), uvec2(0x19192b08, 0x08190819), uvec2(0x192b0819, 0x08190819), uvec2(0x192b1908, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b08082b, 0x08190819), uvec2(0x2b081919, 0x08190819), uvec2(0x2b082b08, 0x08190819), + uvec2(0x2b190819, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x08080819, 0x0819082b), uvec2(0x08081908, 0x0819082b), + uvec2(0x08082b19, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x08191919, 0x0819082b), uvec2(0x082b0819, 0x0819082b), + uvec2(0x082b1908, 0x0819082b), uvec2(0x19080808, 0x0819082b), uvec2(0x19081919, 0x0819082b), uvec2(0x19190819, 0x0819082b), + uvec2(0x19191908, 0x0819082b), uvec2(0x2b080819, 0x0819082b), uvec2(0x2b081908, 0x0819082b), uvec2(0x2b190808, 0x0819082b), + uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), uvec2(0x08082b08, 0x08191908), + uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x0819192b, 0x08191908), uvec2(0x08192b19, 0x08191908), + uvec2(0x082b0808, 0x08191908), uvec2(0x082b1919, 0x08191908), uvec2(0x082b2b08, 0x08191908), uvec2(0x19080819, 0x08191908), + uvec2(0x19081908, 0x08191908), uvec2(0x1908192b, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x1919082b, 0x08191908), uvec2(0x19191919, 0x08191908), uvec2(0x19192b08, 0x08191908), uvec2(0x192b0819, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x2b08082b, 0x08191908), uvec2(0x2b081919, 0x08191908), + 
uvec2(0x2b082b08, 0x08191908), uvec2(0x2b190819, 0x08191908), uvec2(0x2b191908, 0x08191908), uvec2(0x2b2b0808, 0x08191908), + uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), uvec2(0x0808192b, 0x08191919), uvec2(0x08082b19, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x0819082b, 0x08191919), uvec2(0x08191919, 0x08191919), uvec2(0x08192b08, 0x08191919), + uvec2(0x082b0819, 0x08191919), uvec2(0x082b1908, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x1908082b, 0x08191919), + uvec2(0x19081919, 0x08191919), uvec2(0x19082b08, 0x08191919), uvec2(0x19190819, 0x08191919), uvec2(0x19191908, 0x08191919), + uvec2(0x192b0808, 0x08191919), uvec2(0x2b080819, 0x08191919), uvec2(0x2b081908, 0x08191919), uvec2(0x2b190808, 0x08191919), + uvec2(0x08080808, 0x0819192b), uvec2(0x08081919, 0x0819192b), uvec2(0x08082b08, 0x0819192b), uvec2(0x08190819, 0x0819192b), + uvec2(0x08191908, 0x0819192b), uvec2(0x082b0808, 0x0819192b), uvec2(0x19080819, 0x0819192b), uvec2(0x19081908, 0x0819192b), + uvec2(0x19190808, 0x0819192b), uvec2(0x2b080808, 0x0819192b), uvec2(0x2b2b2b2b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x0808192b, 0x08192b08), uvec2(0x08082b19, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x08191919, 0x08192b08), uvec2(0x08192b08, 0x08192b08), uvec2(0x082b0819, 0x08192b08), uvec2(0x19080808, 0x08192b08), + uvec2(0x1908082b, 0x08192b08), uvec2(0x19081919, 0x08192b08), uvec2(0x19082b08, 0x08192b08), uvec2(0x19190819, 0x08192b08), + uvec2(0x19191908, 0x08192b08), uvec2(0x192b0808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), uvec2(0x2b081908, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x0808082b, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x08082b08, 0x08192b19), + uvec2(0x08190819, 0x08192b19), uvec2(0x08191908, 0x08192b19), uvec2(0x082b0808, 0x08192b19), uvec2(0x19080819, 0x08192b19), + uvec2(0x19081908, 0x08192b19), uvec2(0x19190808, 0x08192b19), uvec2(0x192b2b19, 0x08192b19), uvec2(0x2b2b082b, 0x08192b19), + uvec2(0x08081908, 0x08192b2b), uvec2(0x08190808, 0x08192b2b), uvec2(0x19080808, 0x08192b2b), uvec2(0x1919192b, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), uvec2(0x08082b08, 0x082b0808), + uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), uvec2(0x0819192b, 0x082b0808), uvec2(0x08192b19, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x082b1919, 0x082b0808), uvec2(0x082b2b2b, 0x082b0808), uvec2(0x19080819, 0x082b0808), + uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), uvec2(0x1919082b, 0x082b0808), uvec2(0x19191919, 0x082b0808), + uvec2(0x192b1908, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b082b2b, 0x082b0808), uvec2(0x2b191908, 0x082b0808), + uvec2(0x2b2b2b2b, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), uvec2(0x08190808, 0x082b0819), + uvec2(0x0819082b, 0x082b0819), uvec2(0x08191919, 0x082b0819), uvec2(0x082b0819, 0x082b0819), uvec2(0x19080808, 0x082b0819), + uvec2(0x1908082b, 0x082b0819), uvec2(0x19081919, 0x082b0819), uvec2(0x19190819, 0x082b0819), uvec2(0x19191908, 0x082b0819), + uvec2(0x192b0808, 0x082b0819), uvec2(0x2b080819, 0x082b0819), uvec2(0x2b081908, 0x082b0819), uvec2(0x2b190808, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x08082b2b, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x082b2b08, 0x082b082b), + uvec2(0x082b2b2b, 0x082b082b), uvec2(0x19081908, 0x082b082b), uvec2(0x19190808, 0x082b082b), uvec2(0x2b082b08, 0x082b082b), + uvec2(0x2b082b2b, 
0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), + uvec2(0x0808192b, 0x082b1908), uvec2(0x08082b19, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x08191919, 0x082b1908), + uvec2(0x08192b08, 0x082b1908), uvec2(0x082b0819, 0x082b1908), uvec2(0x082b1908, 0x082b1908), uvec2(0x19080808, 0x082b1908), + uvec2(0x1908082b, 0x082b1908), uvec2(0x19081919, 0x082b1908), uvec2(0x19082b08, 0x082b1908), uvec2(0x19190819, 0x082b1908), + uvec2(0x19191908, 0x082b1908), uvec2(0x192b0808, 0x082b1908), uvec2(0x2b080819, 0x082b1908), uvec2(0x2b081908, 0x082b1908), + uvec2(0x2b190808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x08081919, 0x082b1919), uvec2(0x08082b08, 0x082b1919), + uvec2(0x08190819, 0x082b1919), uvec2(0x08191908, 0x082b1919), uvec2(0x082b0808, 0x082b1919), uvec2(0x19080819, 0x082b1919), + uvec2(0x19081908, 0x082b1919), uvec2(0x19190808, 0x082b1919), uvec2(0x192b192b, 0x082b1919), uvec2(0x2b080808, 0x082b1919), + uvec2(0x08080819, 0x082b192b), uvec2(0x08081908, 0x082b192b), uvec2(0x08190808, 0x082b192b), uvec2(0x19080808, 0x082b192b), + uvec2(0x19192b19, 0x082b192b), uvec2(0x08080808, 0x082b2b08), uvec2(0x08081919, 0x082b2b08), uvec2(0x08190819, 0x082b2b08), + uvec2(0x08191908, 0x082b2b08), uvec2(0x19080819, 0x082b2b08), uvec2(0x19081908, 0x082b2b08), uvec2(0x19190808, 0x082b2b08), + uvec2(0x2b082b2b, 0x082b2b08), uvec2(0x2b2b2b2b, 0x082b2b08), uvec2(0x08080819, 0x082b2b19), uvec2(0x08081908, 0x082b2b19), + uvec2(0x08190808, 0x082b2b19), uvec2(0x2b191919, 0x082b2b19), uvec2(0x08082b2b, 0x082b2b2b), uvec2(0x082b082b, 0x082b2b2b), + uvec2(0x192b1908, 0x082b2b2b), uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), + uvec2(0x08081908, 0x19080808), uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), + uvec2(0x0819082b, 0x19080808), uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x08192b2b, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x082b192b, 0x19080808), uvec2(0x19080808, 0x19080808), + uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), uvec2(0x19082b2b, 0x19080808), + uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x1919192b, 0x19080808), uvec2(0x19192b19, 0x19080808), + uvec2(0x192b0808, 0x19080808), uvec2(0x192b082b, 0x19080808), uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), + uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), uvec2(0x2b191919, 0x19080808), uvec2(0x2b192b08, 0x19080808), + uvec2(0x2b2b0819, 0x19080808), uvec2(0x2b2b1908, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), + uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), + uvec2(0x0819192b, 0x19080819), uvec2(0x08192b19, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x082b082b, 0x19080819), + uvec2(0x082b1919, 0x19080819), uvec2(0x19080819, 0x19080819), uvec2(0x19081908, 0x19080819), uvec2(0x1908192b, 0x19080819), + uvec2(0x19082b19, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x1919082b, 0x19080819), uvec2(0x19191919, 0x19080819), + uvec2(0x19192b08, 0x19080819), uvec2(0x192b0819, 0x19080819), uvec2(0x192b1908, 0x19080819), uvec2(0x2b080808, 0x19080819), + uvec2(0x2b08082b, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x2b082b08, 0x19080819), uvec2(0x2b190819, 0x19080819), + uvec2(0x2b191908, 0x19080819), 
uvec2(0x2b2b0808, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), + uvec2(0x08190808, 0x1908082b), uvec2(0x0819082b, 0x1908082b), uvec2(0x08191919, 0x1908082b), uvec2(0x08192b08, 0x1908082b), + uvec2(0x082b1908, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x19081919, 0x1908082b), uvec2(0x19082b08, 0x1908082b), + uvec2(0x19190819, 0x1908082b), uvec2(0x19191908, 0x1908082b), uvec2(0x192b0808, 0x1908082b), uvec2(0x2b080819, 0x1908082b), + uvec2(0x2b081908, 0x1908082b), uvec2(0x08080808, 0x19081908), uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x08082b2b, 0x19081908), uvec2(0x08190819, 0x19081908), uvec2(0x08191908, 0x19081908), + uvec2(0x0819192b, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x082b082b, 0x19081908), + uvec2(0x082b1919, 0x19081908), uvec2(0x082b2b08, 0x19081908), uvec2(0x19080819, 0x19081908), uvec2(0x19081908, 0x19081908), + uvec2(0x1908192b, 0x19081908), uvec2(0x19082b19, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x1919082b, 0x19081908), + uvec2(0x19191919, 0x19081908), uvec2(0x19192b08, 0x19081908), uvec2(0x192b0819, 0x19081908), uvec2(0x192b1908, 0x19081908), + uvec2(0x2b080808, 0x19081908), uvec2(0x2b08082b, 0x19081908), uvec2(0x2b081919, 0x19081908), uvec2(0x2b082b08, 0x19081908), + uvec2(0x2b190819, 0x19081908), uvec2(0x2b191908, 0x19081908), uvec2(0x2b2b0808, 0x19081908), uvec2(0x08080819, 0x19081919), + uvec2(0x08081908, 0x19081919), uvec2(0x0808192b, 0x19081919), uvec2(0x08082b19, 0x19081919), uvec2(0x08190808, 0x19081919), + uvec2(0x0819082b, 0x19081919), uvec2(0x08191919, 0x19081919), uvec2(0x08192b08, 0x19081919), uvec2(0x082b0819, 0x19081919), + uvec2(0x082b1908, 0x19081919), uvec2(0x19080808, 0x19081919), uvec2(0x1908082b, 0x19081919), uvec2(0x19081919, 0x19081919), + uvec2(0x19082b08, 0x19081919), uvec2(0x19190819, 0x19081919), uvec2(0x19191908, 0x19081919), uvec2(0x192b0808, 0x19081919), + uvec2(0x192b2b2b, 0x19081919), uvec2(0x2b080819, 0x19081919), uvec2(0x2b081908, 0x19081919), uvec2(0x2b190808, 0x19081919), + uvec2(0x08080808, 0x1908192b), uvec2(0x0808082b, 0x1908192b), uvec2(0x08081919, 0x1908192b), uvec2(0x08082b08, 0x1908192b), + uvec2(0x08190819, 0x1908192b), uvec2(0x08191908, 0x1908192b), uvec2(0x082b0808, 0x1908192b), uvec2(0x19080819, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x2b080808, 0x1908192b), uvec2(0x2b2b1919, 0x1908192b), + uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), uvec2(0x08082b19, 0x19082b08), uvec2(0x08190808, 0x19082b08), + uvec2(0x0819082b, 0x19082b08), uvec2(0x08191919, 0x19082b08), uvec2(0x08192b08, 0x19082b08), uvec2(0x082b0819, 0x19082b08), + uvec2(0x082b1908, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x1908082b, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x19082b08, 0x19082b08), uvec2(0x19190819, 0x19082b08), uvec2(0x19191908, 0x19082b08), uvec2(0x192b0808, 0x19082b08), + uvec2(0x2b081908, 0x19082b08), uvec2(0x2b190808, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x0808082b, 0x19082b19), + uvec2(0x08081919, 0x19082b19), uvec2(0x08082b08, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x08191908, 0x19082b19), + uvec2(0x082b0808, 0x19082b19), uvec2(0x19080819, 0x19082b19), uvec2(0x19081908, 0x19082b19), uvec2(0x19190808, 0x19082b19), + uvec2(0x2b080808, 0x19082b19), uvec2(0x2b19192b, 0x19082b19), uvec2(0x08080819, 0x19082b2b), uvec2(0x08081908, 0x19082b2b), + uvec2(0x08190808, 0x19082b2b), uvec2(0x19080808, 
0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x0808082b, 0x19190808), + uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), uvec2(0x08191908, 0x19190808), + uvec2(0x0819192b, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b082b, 0x19190808), + uvec2(0x082b1919, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), uvec2(0x19081908, 0x19190808), + uvec2(0x1908192b, 0x19190808), uvec2(0x19082b19, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x1919082b, 0x19190808), + uvec2(0x19191919, 0x19190808), uvec2(0x19192b08, 0x19190808), uvec2(0x192b0819, 0x19190808), uvec2(0x192b1908, 0x19190808), + uvec2(0x2b080808, 0x19190808), uvec2(0x2b08082b, 0x19190808), uvec2(0x2b081919, 0x19190808), uvec2(0x2b082b08, 0x19190808), + uvec2(0x2b190819, 0x19190808), uvec2(0x2b191908, 0x19190808), uvec2(0x08080819, 0x19190819), uvec2(0x08081908, 0x19190819), + uvec2(0x0808192b, 0x19190819), uvec2(0x08082b19, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x0819082b, 0x19190819), + uvec2(0x08191919, 0x19190819), uvec2(0x08192b08, 0x19190819), uvec2(0x082b0819, 0x19190819), uvec2(0x082b1908, 0x19190819), + uvec2(0x19080808, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x19081919, 0x19190819), uvec2(0x19082b08, 0x19190819), + uvec2(0x19190819, 0x19190819), uvec2(0x19191908, 0x19190819), uvec2(0x192b0808, 0x19190819), uvec2(0x2b080819, 0x19190819), + uvec2(0x2b081908, 0x19190819), uvec2(0x2b190808, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x08081919, 0x1919082b), + uvec2(0x08082b08, 0x1919082b), uvec2(0x08190819, 0x1919082b), uvec2(0x08191908, 0x1919082b), uvec2(0x082b0808, 0x1919082b), + uvec2(0x19080819, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x19190808, 0x1919082b), uvec2(0x192b2b19, 0x1919082b), + uvec2(0x2b080808, 0x1919082b), uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x0808192b, 0x19191908), + uvec2(0x08082b19, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x0819082b, 0x19191908), uvec2(0x08191919, 0x19191908), + uvec2(0x08192b08, 0x19191908), uvec2(0x082b0819, 0x19191908), uvec2(0x082b1908, 0x19191908), uvec2(0x19080808, 0x19191908), + uvec2(0x1908082b, 0x19191908), uvec2(0x19081919, 0x19191908), uvec2(0x19082b08, 0x19191908), uvec2(0x19190819, 0x19191908), + uvec2(0x19191908, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b081908, 0x19191908), + uvec2(0x2b190808, 0x19191908), uvec2(0x08080808, 0x19191919), uvec2(0x0808082b, 0x19191919), uvec2(0x08081919, 0x19191919), + uvec2(0x08082b08, 0x19191919), uvec2(0x08190819, 0x19191919), uvec2(0x08191908, 0x19191919), uvec2(0x082b0808, 0x19191919), + uvec2(0x19080819, 0x19191919), uvec2(0x19081908, 0x19191919), uvec2(0x19190808, 0x19191919), uvec2(0x2b080808, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x08081908, 0x1919192b), uvec2(0x08190808, 0x1919192b), uvec2(0x082b192b, 0x1919192b), + uvec2(0x19080808, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x0808082b, 0x19192b08), uvec2(0x08081919, 0x19192b08), + uvec2(0x08082b08, 0x19192b08), uvec2(0x08190819, 0x19192b08), uvec2(0x08191908, 0x19192b08), uvec2(0x082b0808, 0x19192b08), + uvec2(0x19080819, 0x19192b08), uvec2(0x19081908, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x19192b2b, 0x19192b08), + uvec2(0x2b080808, 0x19192b08), uvec2(0x08080819, 0x19192b19), uvec2(0x08081908, 0x19192b19), uvec2(0x08190808, 0x19192b19), + uvec2(0x19080808, 0x19192b19), uvec2(0x08080808, 0x19192b2b), 
uvec2(0x08192b19, 0x19192b2b), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x2b2b2b08, 0x19192b2b), uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x0808192b, 0x192b0808), + uvec2(0x08190808, 0x192b0808), uvec2(0x0819082b, 0x192b0808), uvec2(0x08191919, 0x192b0808), uvec2(0x08192b08, 0x192b0808), + uvec2(0x082b0819, 0x192b0808), uvec2(0x082b1908, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x19081919, 0x192b0808), + uvec2(0x19082b08, 0x192b0808), uvec2(0x19190819, 0x192b0808), uvec2(0x19191908, 0x192b0808), uvec2(0x192b0808, 0x192b0808), + uvec2(0x2b081908, 0x192b0808), uvec2(0x2b190808, 0x192b0808), uvec2(0x08080808, 0x192b0819), uvec2(0x0808082b, 0x192b0819), + uvec2(0x08081919, 0x192b0819), uvec2(0x08082b08, 0x192b0819), uvec2(0x08190819, 0x192b0819), uvec2(0x08191908, 0x192b0819), + uvec2(0x082b0808, 0x192b0819), uvec2(0x19080819, 0x192b0819), uvec2(0x19081908, 0x192b0819), uvec2(0x19190808, 0x192b0819), + uvec2(0x2b080808, 0x192b0819), uvec2(0x2b192b19, 0x192b0819), uvec2(0x08081908, 0x192b082b), uvec2(0x08190808, 0x192b082b), + uvec2(0x19080808, 0x192b082b), uvec2(0x1919192b, 0x192b082b), uvec2(0x2b2b0819, 0x192b082b), uvec2(0x08080808, 0x192b1908), + uvec2(0x08081919, 0x192b1908), uvec2(0x08082b08, 0x192b1908), uvec2(0x08190819, 0x192b1908), uvec2(0x08191908, 0x192b1908), + uvec2(0x082b0808, 0x192b1908), uvec2(0x19080819, 0x192b1908), uvec2(0x19081908, 0x192b1908), uvec2(0x19190808, 0x192b1908), + uvec2(0x2b080808, 0x192b1908), uvec2(0x08080819, 0x192b1919), uvec2(0x08081908, 0x192b1919), uvec2(0x08190808, 0x192b1919), + uvec2(0x19080808, 0x192b1919), uvec2(0x19082b2b, 0x192b1919), uvec2(0x192b2b08, 0x192b1919), uvec2(0x2b19082b, 0x192b1919), + uvec2(0x08080808, 0x192b192b), uvec2(0x2b191908, 0x192b192b), uvec2(0x08080819, 0x192b2b08), uvec2(0x08081908, 0x192b2b08), + uvec2(0x08190808, 0x192b2b08), uvec2(0x192b1919, 0x192b2b08), uvec2(0x2b192b08, 0x192b2b08), uvec2(0x08080808, 0x192b2b19), + uvec2(0x082b2b2b, 0x192b2b19), uvec2(0x1908082b, 0x192b2b2b), uvec2(0x2b2b0819, 0x192b2b2b), uvec2(0x08080808, 0x2b080808), + uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), uvec2(0x08190819, 0x2b080808), + uvec2(0x08191908, 0x2b080808), uvec2(0x08192b19, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b1919, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x1919082b, 0x2b080808), + uvec2(0x19191919, 0x2b080808), uvec2(0x19192b08, 0x2b080808), uvec2(0x192b0819, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b081919, 0x2b080808), uvec2(0x2b190819, 0x2b080808), uvec2(0x2b191908, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x08082b19, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x0819082b, 0x2b080819), + uvec2(0x08191919, 0x2b080819), uvec2(0x08192b08, 0x2b080819), uvec2(0x082b0819, 0x2b080819), uvec2(0x082b1908, 0x2b080819), + uvec2(0x19080808, 0x2b080819), uvec2(0x1908082b, 0x2b080819), uvec2(0x19081919, 0x2b080819), uvec2(0x19082b08, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19191908, 0x2b080819), uvec2(0x2b080819, 0x2b080819), uvec2(0x2b081908, 0x2b080819), + uvec2(0x2b190808, 0x2b080819), uvec2(0x2b2b2b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x08081919, 0x2b08082b), + uvec2(0x08082b2b, 0x2b08082b), uvec2(0x08190819, 0x2b08082b), uvec2(0x08191908, 0x2b08082b), uvec2(0x19080819, 0x2b08082b), + uvec2(0x19081908, 0x2b08082b), uvec2(0x19190808, 0x2b08082b), uvec2(0x08080819, 
0x2b081908), uvec2(0x08081908, 0x2b081908), + uvec2(0x0808192b, 0x2b081908), uvec2(0x08082b19, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x08192b08, 0x2b081908), uvec2(0x082b0819, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x1908082b, 0x2b081908), uvec2(0x19081919, 0x2b081908), uvec2(0x19082b08, 0x2b081908), uvec2(0x19190819, 0x2b081908), + uvec2(0x19191908, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b080819, 0x2b081908), uvec2(0x2b081908, 0x2b081908), + uvec2(0x2b190808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x0808082b, 0x2b081919), uvec2(0x08081919, 0x2b081919), + uvec2(0x08082b08, 0x2b081919), uvec2(0x08190819, 0x2b081919), uvec2(0x08191908, 0x2b081919), uvec2(0x082b0808, 0x2b081919), + uvec2(0x19080819, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x19190808, 0x2b081919), uvec2(0x2b080808, 0x2b081919), + uvec2(0x2b082b2b, 0x2b081919), uvec2(0x08080819, 0x2b08192b), uvec2(0x08081908, 0x2b08192b), uvec2(0x08190808, 0x2b08192b), + uvec2(0x082b2b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08081919, 0x2b082b08), + uvec2(0x08190819, 0x2b082b08), uvec2(0x08191908, 0x2b082b08), uvec2(0x19080819, 0x2b082b08), uvec2(0x19081908, 0x2b082b08), + uvec2(0x19190808, 0x2b082b08), uvec2(0x2b2b082b, 0x2b082b08), uvec2(0x08080819, 0x2b082b19), uvec2(0x08081908, 0x2b082b19), + uvec2(0x19080808, 0x2b082b19), uvec2(0x192b1919, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x19192b08, 0x2b082b2b), + uvec2(0x19192b2b, 0x2b082b2b), uvec2(0x2b08082b, 0x2b082b2b), uvec2(0x2b2b082b, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), + uvec2(0x08081908, 0x2b190808), uvec2(0x08082b19, 0x2b190808), uvec2(0x08190808, 0x2b190808), uvec2(0x0819082b, 0x2b190808), + uvec2(0x08191919, 0x2b190808), uvec2(0x08192b08, 0x2b190808), uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), + uvec2(0x1908082b, 0x2b190808), uvec2(0x19081919, 0x2b190808), uvec2(0x19082b08, 0x2b190808), uvec2(0x19190819, 0x2b190808), + uvec2(0x19191908, 0x2b190808), uvec2(0x192b0808, 0x2b190808), uvec2(0x2b080819, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x2b190808, 0x2b190808), uvec2(0x08080808, 0x2b190819), uvec2(0x08081919, 0x2b190819), uvec2(0x08190819, 0x2b190819), + uvec2(0x08191908, 0x2b190819), uvec2(0x19080819, 0x2b190819), uvec2(0x19081908, 0x2b190819), uvec2(0x19190808, 0x2b190819), + uvec2(0x19192b2b, 0x2b190819), uvec2(0x08080819, 0x2b19082b), uvec2(0x08081908, 0x2b19082b), uvec2(0x08190808, 0x2b19082b), + uvec2(0x19080808, 0x2b19082b), uvec2(0x2b2b192b, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x0808082b, 0x2b191908), + uvec2(0x08081919, 0x2b191908), uvec2(0x08082b08, 0x2b191908), uvec2(0x08190819, 0x2b191908), uvec2(0x08191908, 0x2b191908), + uvec2(0x082b0808, 0x2b191908), uvec2(0x19080819, 0x2b191908), uvec2(0x19081908, 0x2b191908), uvec2(0x19190808, 0x2b191908), + uvec2(0x2b080808, 0x2b191908), uvec2(0x2b19192b, 0x2b191908), uvec2(0x08080819, 0x2b191919), uvec2(0x08081908, 0x2b191919), + uvec2(0x08190808, 0x2b191919), uvec2(0x19080808, 0x2b191919), uvec2(0x2b192b08, 0x2b191919), uvec2(0x2b2b0819, 0x2b191919), + uvec2(0x08080808, 0x2b19192b), uvec2(0x1908192b, 0x2b19192b), uvec2(0x192b1908, 0x2b19192b), uvec2(0x08080819, 0x2b192b08), + uvec2(0x08081908, 0x2b192b08), uvec2(0x08190808, 0x2b192b08), uvec2(0x082b192b, 0x2b192b08), uvec2(0x19080808, 0x2b192b08), + uvec2(0x2b2b2b19, 0x2b192b08), uvec2(0x08080808, 0x2b192b19), uvec2(0x19082b19, 0x2b192b19), 
uvec2(0x1919082b, 0x2b192b19), + uvec2(0x2b190808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), uvec2(0x08081919, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), + uvec2(0x08191908, 0x2b2b0808), uvec2(0x082b082b, 0x2b2b0808), uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x19080819, 0x2b2b0808), + uvec2(0x19081908, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b2b082b, 0x2b2b0808), uvec2(0x2b2b2b2b, 0x2b2b0808), + uvec2(0x19080808, 0x2b2b0819), uvec2(0x192b1919, 0x2b2b0819), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b2b, 0x2b2b082b), + uvec2(0x082b082b, 0x2b2b082b), uvec2(0x082b2b08, 0x2b2b082b), uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b08082b, 0x2b2b082b), + uvec2(0x2b082b08, 0x2b2b082b), uvec2(0x2b082b2b, 0x2b2b082b), uvec2(0x2b2b2b08, 0x2b2b082b), uvec2(0x08080819, 0x2b2b1908), + uvec2(0x08081908, 0x2b2b1908), uvec2(0x08190808, 0x2b2b1908), uvec2(0x19080808, 0x2b2b1908), uvec2(0x2b082b19, 0x2b2b1908), + uvec2(0x2b2b1908, 0x2b2b1908), uvec2(0x08080808, 0x2b2b1919), uvec2(0x08192b19, 0x2b2b1919), uvec2(0x19190819, 0x2b2b192b), + uvec2(0x08082b2b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b082b, 0x2b2b2b08), uvec2(0x19191908, 0x2b2b2b19), + uvec2(0x2b08192b, 0x2b2b2b19), uvec2(0x08082b08, 0x2b2b2b2b), uvec2(0x08082b2b, 0x2b2b2b2b), uvec2(0x082b0808, 0x2b2b2b2b), + uvec2(0x082b082b, 0x2b2b2b2b), uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x2b082b08, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b) +}; + +shared uvec2 iq2s_grid[1024]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) { + iq2s_grid[i] = iq2s_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_S +#define QUANT_R QUANT_R_IQ2_S +#define A_TYPE block_iq2_s +#endif + +#define QUANT_K_IQ3_XXS 256 +#define QUANT_R_IQ3_XXS 1 + +struct block_iq3_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8]; +}; + +struct block_iq3_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16]; +}; + +#if defined(DATA_A_IQ3_XXS) + +const uint32_t iq3xxs_grid_const[256] = { + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 
0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +}; + +shared uint32_t iq3xxs_grid[256]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += wgsize.x) { + iq3xxs_grid[i] = iq3xxs_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_XXS +#define QUANT_R QUANT_R_IQ3_XXS +#define A_TYPE block_iq3_xxs +#define A_TYPE_PACKED16 block_iq3_xxs_packed16 +#endif + +#define QUANT_K_IQ3_S 256 +#define QUANT_R_IQ3_S 1 + +struct block_iq3_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_S/4]; + uint8_t qh[QUANT_K_IQ3_S/32]; + uint8_t signs[QUANT_K_IQ3_S/8]; + uint8_t scales[QUANT_K_IQ3_S/64]; +}; + +struct block_iq3_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_S/4/2]; + uint16_t qh[QUANT_K_IQ3_S/32/2]; + uint16_t signs[QUANT_K_IQ3_S/8/2]; + uint16_t scales[QUANT_K_IQ3_S/64/2]; +}; + +#if defined(DATA_A_IQ3_S) + +const uint32_t iq3s_grid_const[512] = { + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 
0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, +}; + +shared uint32_t iq3s_grid[512]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) { + iq3s_grid[i] = iq3s_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_S +#define QUANT_R QUANT_R_IQ3_S +#define A_TYPE block_iq3_s +#define A_TYPE_PACKED16 block_iq3_s_packed16 +#endif + +#define QUANT_K_IQ4_XS 256 +#define QUANT_R_IQ4_XS 1 + +struct block_iq4_xs +{ + float16_t d; + uint16_t scales_h; + uint8_t scales_l[QUANT_K_IQ4_XS/64]; + uint8_t qs[QUANT_K_IQ4_XS/2]; +}; + +#if defined(DATA_A_IQ4_XS) +#define QUANT_K QUANT_K_IQ4_XS +#define QUANT_R QUANT_R_IQ4_XS +#define A_TYPE block_iq4_xs +#endif + #define QUANT_K_IQ4_NL 32 #define QUANT_R_IQ4_NL 2 @@ -297,7 +1245,13 @@ struct block_iq4_nl_packed16 }; #if defined(DATA_A_IQ4_NL) +#define QUANT_K QUANT_K_IQ4_NL +#define QUANT_R QUANT_R_IQ4_NL +#define A_TYPE block_iq4_nl +#define A_TYPE_PACKED16 block_iq4_nl_packed16 +#endif +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) @@ -305,19 +1259,15 @@ const int8_t kvalues_iq4nl_const[16] = { shared FLOAT_TYPE kvalues_iq4nl[16]; -void init_iq4nl_shmem() +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - if (gl_LocalInvocationIndex.x < 16) { - kvalues_iq4nl[gl_LocalInvocationIndex.x] = 
FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]); + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) { + kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]); } barrier(); } - -#define QUANT_K QUANT_K_IQ4_NL -#define QUANT_R QUANT_R_IQ4_NL -#define A_TYPE block_iq4_nl -#define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif #endif // !defined(GGML_TYPES_COMP) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8111c0638..c5e0bba82 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -17,21 +17,19 @@ #include #include #include +#include #include #include #ifdef _WIN32 #include #include // For _mkdir on Windows - #include // For std::replace on w64devkit #else #include #include #include #endif -#include - #define ASYNCIO_CONCURRENCY 64 std::mutex lock; @@ -57,6 +55,14 @@ const std::vector type_names = { "q4_k", "q5_k", "q6_k", + "iq1_s", + "iq1_m", + "iq2_xxs", + "iq2_xs", + "iq2_s", + "iq3_xxs", + "iq3_s", + "iq4_xs", "iq4_nl" }; @@ -178,6 +184,13 @@ std::string to_uppercase(const std::string& input) { return result; } +bool string_starts_with(const std::string& str, const std::string& prefix) { + if (prefix.size() > str.size()) { + return false; + } + return std::equal(prefix.begin(), prefix.end(), str.begin()); +} + bool string_ends_with(const std::string& str, const std::string& suffix) { if (suffix.size() > str.size()) { return false; @@ -318,8 +331,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool // For aligned matmul loads std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? 
load_vec : "2"; - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + // don't generate f32 variants for coopmat2 + if (!coopmat2) { + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } if (tname != "f16" && tname != "f32") { string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); @@ -342,9 +358,11 @@ void process_shaders() { matmul_shaders(true, matmul_id, false, false, false); matmul_shaders(true, matmul_id, false, false, true); +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) // Coopmat, fp32acc and fp16acc matmul_shaders(true, matmul_id, true, false, false); matmul_shaders(true, matmul_id, true, false, true); +#endif #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) // Coopmat2, fp32acc and fp16acc @@ -378,7 +396,7 @@ void process_shaders() { for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); - std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_")) ? 
"mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); @@ -409,6 +427,7 @@ void process_shaders() { string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); @@ -417,9 +436,16 @@ void process_shaders() { string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + } + string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); @@ -429,6 +455,7 @@ void process_shaders() { string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -451,14 +478,17 @@ void process_shaders() { string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("tanh_f32", 
"tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -468,9 +498,19 @@ void process_shaders() { string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); @@ -482,6 +522,8 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + for (auto &c : compiles) { c.wait(); } @@ -494,6 +536,7 @@ void write_output_files() { fprintf(hdr, "#include \n\n"); fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + std::sort(shader_fnames.begin(), shader_fnames.end()); for (const auto& pair : shader_fnames) { const std::string& name = pair.first; #ifdef _WIN32 From 45dbd1464542a197f3b0af54f0ccd68b39fb2c1e Mon Sep 17 00:00:00 2001 From: Vadim Grinco Date: Sun, 23 Mar 2025 12:27:37 +0100 Subject: [PATCH 033/172] Merged latest ollama 0.6.2 and nasrally's Flash Attention patches (#5) * readme: add Ellama to list of community integrations (#9800) * readme: add screenpipe to community integrations (#9786) * Add support for ROCm gfx1151 (#9773) * conditionally enable parallel pipelines * sample: make mutations in transforms explicit (#9743) * updated 
minP to use early exit making use of sorted tokens * ml/backend/ggml: allocate memory with malloc when loading model (#9822) * runner: remove cache prompt flag from ollama runner (#9826) We do not need to bypass the prompt caching in the ollama runner yet, as only embedding models needed to bypass it. When embedding models are implemented they can skip initializing this cache completely. * ollamarunner: Check for minBatch of context space when shifting Models can specify that a group of inputs need to be handled in a single batch. However, context shifting didn't respect this and could trigger a break anyway. In this case, we should instead trigger a context shift earlier so that it occurs before the grouped batch. Note that there are still some corner cases: - A long prompt that exceeds the context window can get truncated in the middle of an image. With the current models, this will result in the model not recognizing the image at all, which is pretty much the expected result with truncation. - The context window is set to less than the minimum batch size. The only solution to this is to refuse to load the model with these settings. However, this can never occur with current models and default settings. Since users are unlikely to run into these scenarios, fixing them is left as a follow-up. * Applied latest patches from McBane87 See this for details: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2708820861 Signed-off-by: Vadim Grinco * Add ability to enable flash attention on vulkan (#4) * discover: add flash attention handling for vulkan * envconfig: fix typo in config.go As part of the process some code was refactored and I added a new field FlashAttention to GpuInfo, since the previous solution didn't allow for a granular check via Vulkan extensions.
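For illustration, a minimal self-contained Go sketch of the resulting list-level check, assuming a trimmed-down GpuInfo with only the fields used here (the actual change is in the discover/types.go hunk below):

package main

import "fmt"

// Trimmed stand-in for discover.GpuInfo; the real struct carries many more fields.
type GpuInfo struct {
	Library        string
	ID             string
	FlashAttention bool // populated per device during discovery
}

type GpuInfoList []GpuInfo

// FlashAttentionSupported reports whether every discovered GPU can use flash
// attention, relying on the per-device flag instead of per-library heuristics.
func (l GpuInfoList) FlashAttentionSupported() bool {
	for _, gpu := range l {
		if !gpu.FlashAttention {
			return false
		}
	}
	return true
}

func main() {
	gpus := GpuInfoList{
		{Library: "vulkan", ID: "0", FlashAttention: true},
		{Library: "vulkan", ID: "1", FlashAttention: false}, // e.g. required extension missing
	}
	fmt.Println(gpus.FlashAttentionSupported()) // false: one device disables it for the whole list
}

In this patch the flag is filled in during discovery (CUDA: driver major version >= 7, Vulkan: the vk_check_flash_attention extension query, oneAPI: false), although some of those assignments are reverted again by later patches in this series.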
As a side effect, this now allows for granular per-device FA support checking in other places --------- Signed-off-by: Vadim Grinco Co-authored-by: zeo <108888572+zeozeozeo@users.noreply.github.com> Co-authored-by: Louis Beaumont Co-authored-by: Daniel Hiltgen Co-authored-by: Michael Yang Co-authored-by: Parth Sareen Co-authored-by: Jeffrey Morgan Co-authored-by: Bruce MacDonald Co-authored-by: Jesse Gross Co-authored-by: Nikita <50599445+nasrally@users.noreply.github.com> --- CMakePresets.json | 2 +- README.md | 2 + discover/amd_linux.go | 13 +-- discover/gpu.go | 3 + discover/gpu_info_vulkan.c | 64 +++++++++++++- discover/gpu_info_vulkan.h | 1 + discover/types.go | 13 ++- envconfig/config.go | 2 +- ml/backend/ggml/ggml.go | 18 ++-- runner/ollamarunner/cache.go | 7 +- runner/ollamarunner/cache_test.go | 128 +++++++++++++++++++++++++++ runner/ollamarunner/runner.go | 30 ++++--- sample/samplers.go | 5 +- sample/transforms.go | 39 +++------ sample/transforms_test.go | 138 ++++++++++++++++++++---------- 15 files changed, 348 insertions(+), 117 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index 6181eb732..fd0fb9b3a 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -56,7 +56,7 @@ "name": "ROCm 6", "inherits": [ "ROCm" ], "cacheVariables": { - "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" + "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" } }, { diff --git a/README.md b/README.md index 60b23cded..47d0aebd9 100644 --- a/README.md +++ b/README.md @@ -392,6 +392,8 @@ See the [API documentation](./docs/api.md) for all endpoints. - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool) - [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration) - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) 
+- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance) +- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history ### Cloud diff --git a/discover/amd_linux.go b/discover/amd_linux.go index 830fa1df6..06e907391 100644 --- a/discover/amd_linux.go +++ b/discover/amd_linux.go @@ -279,12 +279,13 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { TotalMemory: totalMemory, FreeMemory: (totalMemory - usedMemory), }, - ID: ID, - Name: name, - Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), - MinimumMemory: rocmMinimumMemory, - DriverMajor: driverMajor, - DriverMinor: driverMinor, + ID: ID, + Name: name, + Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), + MinimumMemory: rocmMinimumMemory, + FlashAttention: true, // Supposedly ROCm supports it everywhere + DriverMajor: driverMajor, + DriverMinor: driverMinor, }, usedFilepath: usedFile, index: gpuID, diff --git a/discover/gpu.go b/discover/gpu.go index 2494469a7..c889c4833 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -310,6 +310,7 @@ func GetGPUInfo() GpuInfoList { C.free(unsafe.Pointer(memInfo.err)) continue } + gpuInfo.FlashAttention = driverMajor >= 7 gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) @@ -394,6 +395,7 @@ func GetGPUInfo() GpuInfoList { // TODO - convert this to MinimumMemory based on testing... var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. memInfo.free = C.uint64_t(totalFreeMem) + gpuInfo.FlashAttention = false gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) @@ -423,6 +425,7 @@ func GetGPUInfo() GpuInfoList { continue } + gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0) // 0 means supported gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index e868dcc1b..29eaaeb7f 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -24,18 +24,32 @@ int check_perfmon(vk_handle_t* rh) { return 0; } -int support_memory_budget(vk_handle_t* rh, VkPhysicalDevice device) { +int is_extension_supported(vk_handle_t* rh, VkPhysicalDevice device, char* extension) { VkPhysicalDeviceProperties properties; (*rh->vkGetPhysicalDeviceProperties)(device, &properties); + uint32_t extensionCount; (*rh->vkEnumerateDeviceExtensionProperties)(device, NULL, &extensionCount, NULL); + + if (extensionCount == 0) { + return 0; + } + VkExtensionProperties* extensions = malloc(extensionCount * sizeof(VkExtensionProperties)); + if (extensions == NULL) { + return 0; + } + (*rh->vkEnumerateDeviceExtensionProperties)(device, NULL, &extensionCount, extensions); + for (int j = 0; j < extensionCount; j++) { - if (strcmp(extensions[j].extensionName, VK_EXT_MEMORY_BUDGET_EXTENSION_NAME) == 0) { + if (strcmp(extensions[j].extensionName, extension) == 0) { + free(extensions); return 1; } } + + free(extensions); return 0; } @@ -125,6 +139,7 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { } VkInstance instance; + VkApplicationInfo appInfo = {}; appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; appInfo.pNext = NULL; @@ -133,6 +148,7 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, 
vk_init_resp_t *resp) { appInfo.pEngineName = "No Engine"; appInfo.engineVersion = VK_MAKE_VERSION(1, 0, 0); appInfo.apiVersion = VK_API_VERSION_1_2; + VkInstanceCreateInfo createInfo = {}; createInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; createInfo.pNext = NULL; @@ -141,6 +157,7 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { const char* extensions[] = { VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME }; createInfo.ppEnabledExtensionNames = extensions; createInfo.pApplicationInfo = &appInfo; + VkResult result = (*resp->ch.vkCreateInstance)(&createInfo, NULL, &instance); if (result != VK_SUCCESS) { resp->err = strdup("failed to create instance"); @@ -160,25 +177,63 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { resp->num_devices = deviceCount; } +int vk_check_flash_attention(vk_handle_t rh, int i) { + VkInstance instance = rh.vk; + uint32_t deviceCount = rh.num_devices; + + VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); + if (devices == NULL) { + return 0; + } + + VkResult result = (*rh.vkEnumeratePhysicalDevices)(instance, &deviceCount, devices); + if (result != VK_SUCCESS) { + free(devices); + return 0; + } + + VkPhysicalDeviceProperties properties; + (*rh.vkGetPhysicalDeviceProperties)(devices[i], &properties); + + int supports_nv_coopmat2 = is_extension_supported(&rh, devices[i], VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME); + if (!supports_nv_coopmat2) { + free(devices); + return 1; + } + + free(devices); + return 0; +} + void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { VkInstance instance = rh.vk; uint32_t deviceCount = rh.num_devices; VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); + if (devices == NULL) { + resp->err = strdup("memory allocation failed for devices array"); + return; + } + VkResult result = (*rh.vkEnumeratePhysicalDevices)(instance, &deviceCount, devices); if (result != VK_SUCCESS) { + free(devices); resp->err = strdup("failed to enumerate physical devices"); return; } VkPhysicalDeviceProperties properties; (*rh.vkGetPhysicalDeviceProperties)(devices[i], &properties); - int supports_budget = support_memory_budget(&rh, devices[i]); + + int supports_budget = is_extension_supported(&rh, devices[i], VK_EXT_MEMORY_BUDGET_EXTENSION_NAME); if (!supports_budget) { + free(devices); resp->err = strdup("device does not support memory budget"); return; } + if (properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_CPU) { + free(devices); resp->err = strdup("device is a CPU"); return; } @@ -204,6 +259,8 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { } } + free(devices); + resp->err = NULL; snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; @@ -220,6 +277,7 @@ void vk_release(vk_handle_t rh) { (*rh.vkDestroyInstance)(rh.vk, NULL); UNLOAD_LIBRARY(rh.vk_handle); rh.vk_handle = NULL; + #ifdef __linux__ LOG(rh.verbose, "releasing libcap library\n"); UNLOAD_LIBRARY(rh.cap_handle); diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index 6025f3e09..1f19be58e 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -60,6 +60,7 @@ typedef struct vk_init_resp void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp); void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp); +int vk_check_flash_attention(vk_handle_t rh, int i); void vk_release(vk_handle_t rh); #endif diff --git a/discover/types.go b/discover/types.go index 
11a3acec3..b096b9e2e 100644 --- a/discover/types.go +++ b/discover/types.go @@ -36,9 +36,10 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? UnreliableFreeMemory bool // GPU information - ID string `json:"gpu_id"` // string to use for selection of this specific GPU - Name string `json:"name"` // user friendly name if available - Compute string `json:"compute"` // Compute Capability or gfx + ID string `json:"gpu_id"` // string to use for selection of this specific GPU + Name string `json:"name"` // user friendly name if available + Compute string `json:"compute"` // Compute Capability or gfx + FlashAttention bool `json:"flash_attention"` // is flash attention supported // Driver Information - TODO no need to put this on each GPU DriverMajor int `json:"driver_major,omitempty"` @@ -178,11 +179,7 @@ func (si SystemInfo) GetOptimalThreadCount() int { // For each GPU, check if it does NOT support flash attention func (l GpuInfoList) FlashAttentionSupported() bool { for _, gpu := range l { - supportsFA := gpu.Library == "metal" || - (gpu.Library == "cuda" && gpu.DriverMajor >= 7) || - gpu.Library == "rocm" - - if !supportsFA { + if !gpu.FlashAttention { return false } } diff --git a/envconfig/config.go b/envconfig/config.go index cee40f6a8..53e358155 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -276,7 +276,7 @@ func AsMap() map[string]EnvVar { ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible by numeric ID"} ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible by UUID or numeric ID"} - ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which VK AMD devices are visible by numeric ID"} + ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"} ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"} ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"} ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"} diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 2237e7f51..7f2d61f09 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -314,18 +314,20 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { return fmt.Errorf("unassigned tensor: %s", t.Name) } - bts := make([]byte, t.Size()) - n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts) - if err != nil { - return err + bts := C.malloc(C.size_t(t.Size())) + if bts == nil { + return errors.New("failed to allocate tensor buffer") } + defer C.free(bts) - if n != len(bts) { - return errors.New("short read") + buf := unsafe.Slice((*byte)(bts), t.Size()) + n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf) + if err != nil || n != len(buf) { + return errors.New("read failed") } tensorSetMutex.Lock() - C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size())) + C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size())) tensorSetMutex.Unlock() return nil }) @@ -375,7 +377,7 @@ func New(r *os.File, 
params ml.BackendParams) (ml.Backend, error) { (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), C.int(len(schedBackends)), C.size_t(maxGraphNodes), - true, + C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)), ), input: deviceBufferTypes[input.d], output: deviceBufferTypes[output.d], diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index adcb3f738..cf5e6b911 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -89,7 +89,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) { +func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) { var slot *InputCacheSlot var numPast int32 var err error @@ -107,11 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp return nil, nil, err } - // TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved? - if !cachePrompt { - numPast = 0 - } - slot.InUse = true slot.lastUsed = time.Now() diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 0a1b73f5a..f8925d119 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) { }) } } + +func TestLoadCacheSlot(t *testing.T) { + tests := []struct { + name string + cache InputCache + prompt []input.Input + wantErr bool + expectedSlotId int + expectedPrompt int // expected length of remaining prompt + }{ + { + name: "Basic cache hit - single user", + cache: InputCache{ + multiUserCache: false, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input.Input{}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + wantErr: false, + expectedSlotId: 0, + expectedPrompt: 1, // Only token 3 remains + }, + { + name: "Basic cache hit - multi user", + cache: InputCache{ + multiUserCache: true, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input.Input{}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + wantErr: false, + expectedSlotId: 0, + expectedPrompt: 1, // Only token 3 remains + }, + { + name: "Exact match - leave one input", + cache: InputCache{ + multiUserCache: false, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}}, + wantErr: false, + expectedSlotId: 0, + expectedPrompt: 1, // Should leave 1 token for sampling + }, + { + name: "No available slots", + cache: InputCache{ + multiUserCache: false, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: true, + lastUsed: time.Now().Add(-time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + wantErr: true, + expectedSlotId: -1, + expectedPrompt: -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + slot, remainingPrompt, err := 
tt.cache.LoadCacheSlot(tt.prompt) + + // Check error state + if (err != nil) != tt.wantErr { + t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return // Skip further checks if we expected an error + } + + // Verify slot ID + if slot.Id != tt.expectedSlotId { + t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId) + } + + // Verify slot is now marked in use + if !slot.InUse { + t.Errorf("LoadCacheSlot() slot not marked InUse") + } + + // Verify remaining prompt length + if len(remainingPrompt) != tt.expectedPrompt { + t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v", + len(remainingPrompt), tt.expectedPrompt) + } + }) + } +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index d4c24556c..9a1a549cd 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe params.numKeep = int32(len(inputs)) } + // TODO(jessegross): We should ensure that we always leave minBatch of context space to shift, + // otherwise we might truncate or split the batch against the model's wishes + // Ensure that at least 1 input can be discarded during shift params.numKeep = min(params.numKeep, s.cache.numCtx-1) @@ -366,17 +369,6 @@ func (s *Server) processBatch() error { batchSize := s.batchSize for j, inp := range seq.inputs { - if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { - if len(seq.pendingInputs) == 0 { - err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) - if err != nil { - return err - } - } else { - break - } - } - // If we are required to put following inputs into a single batch then extend the // batch size. Since we are only extending the size the minimum amount possible, this // will cause a break if we have pending inputs. @@ -389,6 +381,20 @@ func (s *Server) processBatch() error { break } + // If the sum of our working set (already processed tokens, tokens we added to this + // batch, required following tokens) exceeds the context size, then trigger a shift + // now so we don't have to do one later when we can't break the batch. 
+ if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx { + if len(seq.pendingInputs) != 0 { + break + } + + err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) + if err != nil { + return err + } + } + options.Inputs = append(options.Inputs, inp.Token) if inp.Multimodal != nil { options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal}) @@ -590,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) diff --git a/sample/samplers.go b/sample/samplers.go index e302f9147..7c12da08b 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -87,8 +87,9 @@ func (s *Sampler) sample(tokens []token) (token, error) { // topK also sorts the tokens in descending order of logits tokens = topK(tokens, s.topK) - tokens = temperature(tokens, s.temperature) - tokens = softmax(tokens) + // scale and normalize the tokens in place + temperature(tokens, s.temperature) + softmax(tokens) tokens = topP(tokens, s.topP) tokens = minP(tokens, s.minP) diff --git a/sample/transforms.go b/sample/transforms.go index a5efa704e..3f677553f 100644 --- a/sample/transforms.go +++ b/sample/transforms.go @@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any { } // temperature applies scaling to the logits -func temperature(ts []token, temp float32) []token { +func temperature(ts []token, temp float32) { // Ensure temperature clipping near 0 to avoid numerical instability temp = max(temp, 1e-7) for i := range ts { ts[i].value = ts[i].value / temp } - return ts } // softmax applies normalization to the logits -func softmax(ts []token) []token { +func softmax(ts []token) { // Find max logit for numerical stability maxLogit := float32(math.Inf(-1)) for _, t := range ts { @@ -56,8 +55,6 @@ func softmax(ts []token) []token { for i := range ts { ts[i].value /= sum } - - return ts } // topK limits the number of tokens considered to the k highest logits @@ -99,6 +96,7 @@ func topK(ts []token, k int) []token { } // topP limits tokens to those with cumulative probability p +// requires ts to be sorted in descending order of probabilities func topP(ts []token, p float32) []token { if p == 1.0 { return ts @@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token { for i, t := range ts { sum += t.value if sum > float32(p) { - ts = ts[:i+1] - return ts + return ts[:i+1] } } return ts } -// minP limits tokens to those with cumulative probability p +// minP filters tokens with probabilities >= p * max_prob +// requires ts to be sorted in descending order of probabilities func minP(ts []token, p float32) []token { - if p == 1.0 { - return ts - } + maxProb := ts[0].value - maxProb := float32(math.Inf(-1)) - for _, token := range ts { - if token.value > maxProb { - maxProb = token.value + threshold := maxProb * p + + for i, t := range ts { + if t.value < threshold { + return ts[:i] } } - - threshold := maxProb * float32(p) - - // Filter tokens in-place - validTokens := ts[:0] - for i, token := range ts { - if token.value >= threshold { - validTokens = append(validTokens, ts[i]) - } - } - - ts = validTokens return ts } diff --git a/sample/transforms_test.go b/sample/transforms_test.go index 4880dd8f4..7faf30a55 
100644 --- a/sample/transforms_test.go +++ b/sample/transforms_test.go @@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) { func TestTemperature(t *testing.T) { input := []float32{1.0, 4.0, -2.0, 0.0} - got := temperature(toTokens(input), 0.5) + tokens := toTokens(input) + temperature(tokens, 0.5) want := []float32{2.0, 8.0, -4.0, 0.0} - compareLogits(t, "temperature(0.5)", want, got) + compareLogits(t, "temperature(0.5)", want, tokens) - got = temperature(toTokens(input), 1.0) + input = []float32{1.0, 4.0, -2.0, 0.0} + tokens = toTokens(input) + temperature(tokens, 1.0) want = []float32{1.0, 4.0, -2.0, 0.0} - compareLogits(t, "temperature(1)", want, got) + compareLogits(t, "temperature(1)", want, tokens) - got = temperature(toTokens(input), 0.0) + input = []float32{1.0, 4.0, -2.0, 0.0} + tokens = toTokens(input) + temperature(tokens, 0.0) want = []float32{1e7, 4e7, -2e7, 0.0} - compareLogits(t, "temperature(0)", want, got) + compareLogits(t, "temperature(0)", want, tokens) } func TestSoftmax(t *testing.T) { @@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := softmax(toTokens(tt.input)) + tokens := toTokens(tt.input) + softmax(tokens) if tt.expected != nil { - compareLogits(t, tt.name, tt.expected, got) + compareLogits(t, tt.name, tt.expected, tokens) return } // Check probabilities sum to 1 var sum float32 - for _, token := range got { + for _, token := range tokens { sum += token.value if token.value < 0 || token.value > 1 { t.Errorf("probability out of range [0,1]: got %f", token.value) @@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) { func TestTopK(t *testing.T) { input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} - - // Test k=5 - got := topK(toTokens(input), 5) - if len(got) != 5 { - t.Errorf("topK(5): wrong length: want 5, got %d", len(got)) + tokens := toTokens(input) + tokens = topK(tokens, 5) + if len(tokens) != 5 { + t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens)) } - // Should keep highest 3 values in descending order want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154} - compareLogits(t, "topK(3)", want, got) + compareLogits(t, "topK(3)", want, tokens) - got = topK(toTokens(input), 20) - if len(got) != len(input) { - t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got)) + tokens = toTokens(input) + tokens = topK(tokens, 20) + if len(tokens) != len(input) { + t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens)) } - // Test k=-1 input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} - got = topK(toTokens(input), -1) - if len(got) != len(input) { - t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got)) + tokens = toTokens(input) + tokens = topK(tokens, -1) + if len(tokens) != len(input) { + t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens)) } - compareLogits(t, "topK(-1)", want, got) + compareLogits(t, "topK(-1)", want, tokens) - // Test k=0 input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 
0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} - got = topK(toTokens(input), 0) - if len(got) != len(input) { - t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got)) + tokens = toTokens(input) + tokens = topK(tokens, 0) + if len(tokens) != len(input) { + t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens)) + } + compareLogits(t, "topK(-1)", want, tokens) + + input = []float32{-1e7, -2e7, -3e7, -4e7} + tokens = toTokens(input) + tokens = topK(tokens, 1) + if len(tokens) < 1 { + t.Error("topK should keep at least one token") } - compareLogits(t, "topK(-1)", want, got) } func TestTopP(t *testing.T) { @@ -153,16 +165,25 @@ func TestTopP(t *testing.T) { tokens := toTokens(input) // First apply temperature and softmax to get probabilities - tokens = softmax(tokens) + softmax(tokens) tokens = topK(tokens, 20) // Then apply topP - got := topP(tokens, 0.95) + tokens = topP(tokens, 0.95) // Should keep tokens until cumsum > 0.95 - if len(got) > 3 { - t.Errorf("topP(0.95): kept too many tokens: got %d", len(got)) - t.Logf("got: %v", got) + if len(tokens) > 3 { + t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens)) + t.Logf("got: %v", tokens) + } + + // Test edge case - ensure at least one token remains + input = []float32{-1e6, -1e6, -1e6} // One dominant token + tokens = toTokens(input) + softmax(tokens) + tokens = topP(tokens, 0.0) // Very small p + if len(tokens) < 1 { + t.Error("topP should keep at least one token") } } @@ -171,14 +192,45 @@ func TestMinP(t *testing.T) { tokens := toTokens(input) // First apply temperature and softmax - tokens = softmax(tokens) + tokens = topK(tokens, 20) + softmax(tokens) - // Then apply minP - got := minP(tokens, 0.2) + tokens = minP(tokens, 1.0) + + if len(tokens) != 1 { + t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens)) + } + + // Test with normal p value + tokens = toTokens(input) // Reset tokens + tokens = topK(tokens, 20) + softmax(tokens) + tokens = minP(tokens, 0.2) // Should keep tokens with prob >= 0.2 * max_prob - if len(got) > 3 { - t.Errorf("minP(0.2): kept too many tokens: got %d", len(got)) + if len(tokens) > 3 { + t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens)) + t.Logf("got: %v", tokens) + } + + // Test with zero p value + tokens = toTokens(input) // Reset tokens + tokens = topK(tokens, 20) + softmax(tokens) + tokens = minP(tokens, 0.0) + + // Should keep only the highest probability token + if len(tokens) != len(input) { + t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens)) + t.Logf("got: %v", tokens) + } + + input = []float32{1e-10, 1e-10, 1e-10} + tokens = toTokens(input) + softmax(tokens) + tokens = minP(tokens, 1.0) + if len(tokens) < 1 { + t.Error("minP should keep at least one token even with extreme probabilities") } } @@ -231,7 +283,7 @@ func BenchmarkTransforms(b *testing.B) { b.ResetTimer() for b.Loop() { copy(tokensCopy, tokens) - topK(tokensCopy, 10) + tokens = topK(tokensCopy, 10) } }) @@ -239,7 +291,7 @@ func BenchmarkTransforms(b *testing.B) { b.ResetTimer() for b.Loop() { copy(tokensCopy, tokens) - topP(tokensCopy, 0.9) + tokens = topP(tokensCopy, 0.9) } }) @@ -247,7 +299,7 @@ func BenchmarkTransforms(b *testing.B) { b.ResetTimer() for b.Loop() { copy(tokensCopy, tokens) - minP(tokensCopy, 0.2) + tokens = 
minP(tokensCopy, 0.2) } }) @@ -255,7 +307,7 @@ func BenchmarkTransforms(b *testing.B) { b.ResetTimer() for b.Loop() { copy(tokensCopy, tokens) - topK(tokensCopy, 200000) + tokens = topK(tokensCopy, 200000) } }) } From 643b1c505ef9600b801d1f32d35d8cb46d1550b1 Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:14:54 +0200 Subject: [PATCH 034/172] Revert Readme changes --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6d6da080b..e821efcde 100644 --- a/README.md +++ b/README.md @@ -396,11 +396,11 @@ See the [API documentation](./docs/api.md) for all endpoints. - [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models) - [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms) - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool) -- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration) +- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)- - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) +- [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API. Built with React, TypeScript, and Material-UI.) - [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance) - [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history -- [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API. Built with React, TypeScript, and Material-UI.) - [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).) - [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama) - [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable) From 47bff3e532419228d5b621f7fbca2ccc6536f242 Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:15:54 +0200 Subject: [PATCH 035/172] Revert --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e821efcde..d5049d3eb 100644 --- a/README.md +++ b/README.md @@ -396,7 +396,7 @@ See the [API documentation](./docs/api.md) for all endpoints. 
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models) - [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms) - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool) -- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)- +- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration) - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) - [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API. Built with React, TypeScript, and Material-UI.) - [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance) From 89ac91099de2a9150aff926cc0e35a44d9b07f50 Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:23:00 +0200 Subject: [PATCH 036/172] Revert changes in amd_linux.go --- discover/amd_linux.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/discover/amd_linux.go b/discover/amd_linux.go index 187184b81..105acbd79 100644 --- a/discover/amd_linux.go +++ b/discover/amd_linux.go @@ -282,8 +282,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { ID: ID, Name: name, Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), - MinimumMemory: rocmMinimumMemory, - FlashAttention: true, // Supposedly ROCm supports it everywhere + MinimumMemory: rocmMinimumMemory, DriverMajor: driverMajor, DriverMinor: driverMinor, }, From 42463fbb7f2cbff92873a176b22b42d682337d6c Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:24:33 +0200 Subject: [PATCH 037/172] Revert changes in amd_linux.go --- discover/amd_linux.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/discover/amd_linux.go b/discover/amd_linux.go index 105acbd79..dc9a4e185 100644 --- a/discover/amd_linux.go +++ b/discover/amd_linux.go @@ -279,12 +279,12 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { TotalMemory: totalMemory, FreeMemory: (totalMemory - usedMemory), }, - ID: ID, - Name: name, - Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), - MinimumMemory: rocmMinimumMemory, - DriverMajor: driverMajor, - DriverMinor: driverMinor, + ID: ID, + Name: name, + Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), + MinimumMemory: rocmMinimumMemory, + DriverMajor: driverMajor, + DriverMinor: driverMinor, }, usedFilepath: usedFile, index: gpuID, From 57270767ac5d4e6686a34c01e7045c397dc95d1c Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:26:54 +0200 Subject: [PATCH 038/172] Remove flashattention setting gpu.go --- discover/gpu.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/discover/gpu.go b/discover/gpu.go index 842e817c6..f76e2abf1 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -309,8 +309,7 @@ func GetGPUInfo() GpuInfoList { slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) C.free(unsafe.Pointer(memInfo.err)) continue - } - gpuInfo.FlashAttention = driverMajor >= 7 + } gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = 
C.GoString(&memInfo.gpu_id[0]) @@ -394,8 +393,7 @@ func GetGPUInfo() GpuInfoList { C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) // TODO - convert this to MinimumMemory based on testing... var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. - memInfo.free = C.uint64_t(totalFreeMem) - gpuInfo.FlashAttention = false + memInfo.free = C.uint64_t(totalFreeMem) gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) @@ -424,8 +422,7 @@ func GetGPUInfo() GpuInfoList { C.free(unsafe.Pointer(memInfo.err)) continue } - - gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0) // 0 means supported + gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) From 29b1ed00774dac5b26004a6bc057c4d270fe1943 Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:30:13 +0200 Subject: [PATCH 039/172] Revert whitespace changes in gpu.go --- discover/gpu.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/discover/gpu.go b/discover/gpu.go index f76e2abf1..9048631a8 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -309,7 +309,7 @@ func GetGPUInfo() GpuInfoList { slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) C.free(unsafe.Pointer(memInfo.err)) continue - } + } gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) @@ -393,7 +393,7 @@ func GetGPUInfo() GpuInfoList { C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) // TODO - convert this to MinimumMemory based on testing... var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. 
- memInfo.free = C.uint64_t(totalFreeMem) + memInfo.free = C.uint64_t(totalFreeMem) gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) From 0ddb64db1f3b461969eee4911e7332abdae63d67 Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:33:42 +0200 Subject: [PATCH 040/172] Revert changes in transforms_test.go --- sample/transforms_test.go | 78 +++++++-------------------------------- 1 file changed, 13 insertions(+), 65 deletions(-) diff --git a/sample/transforms_test.go b/sample/transforms_test.go index 92e57a987..5307c5f8a 100644 --- a/sample/transforms_test.go +++ b/sample/transforms_test.go @@ -178,16 +178,8 @@ func TestTopP(t *testing.T) { // Test with normal p value got = topP(tokens, 0.95) - // Should keep tokens until cumsum > 0.95 + if len(got) > 3 { - t.Errorf("topP(0.95): kept too many tokens: got %d", len(got)) - t.Logf("got: %v", got) - } - - // Test with normal p value - got = topP(tokens, 0.95) - - if len(tokens) > 3 { t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens)) t.Logf("got: %v", got) } @@ -216,17 +208,8 @@ func TestTopP(t *testing.T) { softmax(tokens) got = topP(tokens, 1e-10) if len(got) == 0 { - t.Errorf("topP(1e-10): should keep at least one token, got %d", len(tokens)) - t.Logf("got: %v", tokens) - } - - // Test edge case - ensure at least one token remains - input = []float32{-1e6, -1e6, -1e6} // One dominant token - tokens = toTokens(input) - softmax(tokens) - tokens = topP(tokens, 0.0) // Very small p - if len(tokens) < 1 { - t.Error("topP should keep at least one token") + t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got)) + t.Logf("got: %v", got) } } @@ -268,27 +251,6 @@ func TestMinP(t *testing.T) { t.Logf("got: %v", tokens) } - tokens = topK(tokens, 20) - softmax(tokens) - - tokens = minP(tokens, 1.0) - - if len(tokens) != 1 { - t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens)) - } - - // Test with normal p value - tokens = toTokens(input) // Reset tokens - tokens = topK(tokens, 20) - softmax(tokens) - tokens = minP(tokens, 0.2) - - // Should keep tokens with prob >= 0.2 * max_prob - if len(tokens) > 3 { - t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens)) - t.Logf("got: %v", tokens) - } - // Test with single token tokens = toTokens(input[:1]) tokens = topK(tokens, 20) @@ -307,32 +269,18 @@ func TestMinP(t *testing.T) { tokens = minP(tokens, 1.0) if len(tokens) < 1 { t.Error("minP should keep at least one token even with extreme probabilities") - } + got := minP(tokens, 1.0) + + if len(got) != 1 { + t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens)) + } + + // Test with normal p value + got = minP(tokens, 0.2) // Should keep tokens with prob >= 0.2 * max_prob - if len(tokens) > 3 { - t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens)) - t.Logf("got: %v", tokens) - } - - // Test with zero p value - tokens = toTokens(input) // Reset tokens - tokens = topK(tokens, 20) - softmax(tokens) - tokens = minP(tokens, 0.0) - - // Should keep only the highest probability token - if len(tokens) != len(input) { - t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens)) - t.Logf("got: %v", tokens) - } - - input = []float32{1e-10, 1e-10, 1e-10} - tokens = toTokens(input) - softmax(tokens) - tokens = minP(tokens, 1.0) - if len(tokens) < 1 { - t.Error("minP should keep at least one token even with extreme probabilities") + if 
len(got) > 3 { + t.Errorf("minP(0.2): kept too many tokens: got %d", len(got)) t.Logf("got: %v", got) } From a6d0d6c6ff57fdb145424a29eed4fbcd6127a55e Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:35:20 +0200 Subject: [PATCH 041/172] Revert changes in runner.go --- runner/ollamarunner/runner.go | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 2b80dbd3f..cebe30def 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -118,9 +118,6 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe params.numKeep = int32(len(inputs)) } - // TODO(jessegross): We should ensure that we always leave minBatch of context space to shift, - // otherwise we might truncate or split the batch against the model's wishes - // Ensure that at least 1 input can be discarded during shift params.numKeep = min(params.numKeep, s.cache.numCtx-1) @@ -426,14 +423,6 @@ func (s *Server) processBatch() error { break } - // If the sum of our working set (already processed tokens, tokens we added to this - // batch, required following tokens) exceeds the context size, then trigger a shift - // now so we don't have to do one later when we can't break the batch. - if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx { - if len(seq.pendingInputs) != 0 { - break - } - // If the sum of our working set (already processed tokens, tokens we added to this // batch, required following tokens) exceeds the context size, then trigger a shift // now so we don't have to do one later when we can't break the batch. @@ -456,7 +445,7 @@ func (s *Server) processBatch() error { } } - options.Inputs = append(options.Inputs, inp.Token) + batchInputs = append(batchInputs, inp.Token) if inp.Multimodal != nil { mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false) if err != nil { From d03fc13d3624c054fdd392f17f3dfad8c27d5e02 Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:38:37 +0200 Subject: [PATCH 042/172] Revert changes in Makefile.sync --- Makefile.sync | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile.sync b/Makefile.sync index 167f6bba7..711667c98 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -39,7 +39,7 @@ PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATC apply-patches: $(PATCHED) llama/patches/.%.patched: llama/patches/%.patch - @if git -c commit.gpgSign=false -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi + @if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi .PHONY: checkout checkout: $(WORKDIR) From fa13b8de450fe2b119fdfc424c5787360b8eddaf Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:43:12 +0200 Subject: [PATCH 043/172] Revert some unintented changes in Dockerfile --- Dockerfile | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 68fa5fa04..4c9e1bf72 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ RUN yum install -y yum-utils \ && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \ && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \ && dnf install -y ccache \ 
- && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo + && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH ARG VULKANVERSION RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ @@ -108,7 +108,6 @@ RUN --mount=type=cache,target=/root/.cache/go-build \ go build -trimpath -buildmode=pie -o /bin/ollama . FROM --platform=linux/amd64 scratch AS amd64 -COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11 COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12 COPY --from=vulkan dist/lib/ollama/vulkan /lib/ollama/vulkan @@ -131,12 +130,12 @@ RUN apt-get update \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* COPY --from=archive /bin /usr/bin -COPY --from=archive /lib/ollama /usr/lib/ollama ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +COPY --from=archive /lib/ollama /usr/lib/ollama ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility ENV NVIDIA_VISIBLE_DEVICES=all ENV OLLAMA_HOST=0.0.0.0:11434 EXPOSE 11434 ENTRYPOINT ["/bin/ollama"] -CMD ["serve"] \ No newline at end of file +CMD ["serve"] From bc5c3fb213e1ebb8dfa4734316c85123a9f80f66 Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:45:52 +0200 Subject: [PATCH 044/172] Revert vulkan copy changes in Dockerfile --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 4c9e1bf72..416e1bb0a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -108,7 +108,7 @@ RUN --mount=type=cache,target=/root/.cache/go-build \ go build -trimpath -buildmode=pie -o /bin/ollama . 
FROM --platform=linux/amd64 scratch AS amd64 -COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12 +COPY --from=cuda-12 dist/lib/ollama /lib/ollama COPY --from=vulkan dist/lib/ollama/vulkan /lib/ollama/vulkan FROM --platform=linux/arm64 scratch AS arm64 From 2e7452be718999c738adadafc9d8e05ff1bd7576 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 10 Aug 2025 16:01:07 +0200 Subject: [PATCH 045/172] Update Vulkan Code to de4c07f93783a1a96456a44dc16b9db538ee1618 --- .../ggml/ggml/src/ggml-vulkan/CMakeLists.txt | 40 + .../ggml-vulkan/cmake/host-toolchain.cmake.in | 15 + .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2007 +++++++++++++---- .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 21 +- .../vulkan-shaders/contig_copy.comp | 11 +- .../ggml-vulkan/vulkan-shaders/conv2d_dw.comp | 105 + .../src/ggml-vulkan/vulkan-shaders/copy.comp | 5 +- .../vulkan-shaders/copy_to_quant.comp | 5 + .../vulkan-shaders/dequant_funcs.comp | 14 +- .../vulkan-shaders/dequant_funcs_cm2.comp | 148 +- .../vulkan-shaders/flash_attn.comp | 483 ++++ .../vulkan-shaders/flash_attn_cm2.comp | 126 +- .../flash_attn_split_k_reduce.comp | 59 + .../ggml-vulkan/vulkan-shaders/get_rows.comp | 11 +- .../vulkan-shaders/get_rows_quant.comp | 2 + .../ggml-vulkan/vulkan-shaders/im2col.comp | 53 +- .../ggml-vulkan/vulkan-shaders/l2_norm.comp | 41 + .../vulkan-shaders/mul_mat_vec.comp | 22 +- .../vulkan-shaders/mul_mat_vec_iq2_s.comp | 90 + .../vulkan-shaders/mul_mat_vec_iq2_xs.comp | 87 + .../vulkan-shaders/mul_mat_vec_iq2_xxs.comp | 87 + .../vulkan-shaders/mul_mat_vec_iq3_s.comp | 90 + .../vulkan-shaders/mul_mat_vec_iq3_xxs.comp | 88 + .../vulkan-shaders/mul_mat_vec_nc.comp | 75 +- .../vulkan-shaders/mul_mat_vec_p021.comp | 137 +- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 47 +- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 26 +- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 12 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 187 +- .../vulkan-shaders/mul_mm_cm2.comp | 262 ++- .../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 442 ++++ .../vulkan-shaders/mul_mmq_funcs.comp | 99 + .../vulkan-shaders/quantize_q8_1.comp | 77 + .../src/ggml-vulkan/vulkan-shaders/relu.comp | 2 +- .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 32 +- .../ggml-vulkan/vulkan-shaders/sigmoid.comp | 2 +- .../src/ggml-vulkan/vulkan-shaders/tanh.comp | 2 +- .../vulkan-shaders/test_bfloat16_support.comp | 7 + .../test_integer_dot_support.comp | 7 + .../src/ggml-vulkan/vulkan-shaders/types.comp | 132 +- .../vulkan-shaders/vulkan-shaders-gen.cpp | 167 +- .../src/ggml-vulkan/vulkan-shaders/wkv7.comp | 91 + 42 files changed, 4607 insertions(+), 809 deletions(-) create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp create mode 100644 
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt index d970f7e20..31816219c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt @@ -32,8 +32,10 @@ if (Vulkan_FOUND) if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") + set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF) else() message(STATUS "GL_KHR_cooperative_matrix supported by glslc") + set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON) add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) endif() @@ -46,11 +48,45 @@ if (Vulkan_FOUND) if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") + set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF) else() message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") + set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) endif() + # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported. + # If it's not, there will be an error to stderr. + # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*") + message(STATUS "GL_EXT_integer_dot_product not supported by glslc") + set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF) + else() + message(STATUS "GL_EXT_integer_dot_product supported by glslc") + set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON) + add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + endif() + + # Compile a test shader to determine whether GL_EXT_bfloat16 is supported. + # If it's not, there will be an error to stderr. 
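A note on scope: these glslc probes (cooperative matrix above, integer dot product, and the bfloat16 check that continues just below) only establish that the SDK's shader compiler accepts the extension; whether a given GPU exposes it is still decided at runtime, which is why the C++ side also gates on flags such as device->integer_dot_product. A minimal sketch of such a runtime query, using the standard VK_KHR_shader_integer_dot_product structures rather than anything defined in this patch:

    #include <vulkan/vulkan.h>

    // Ask the physical device whether shader integer dot product is available.
    static bool device_supports_integer_dot(VkPhysicalDevice dev) {
        VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR dot = {};
        dot.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;

        VkPhysicalDeviceFeatures2 features2 = {};
        features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
        features2.pNext = &dot;

        vkGetPhysicalDeviceFeatures2(dev, &features2);
        return dot.shaderIntegerDotProduct == VK_TRUE;
    }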
+ # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*") + message(STATUS "GL_EXT_bfloat16 not supported by glslc") + set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF) + else() + message(STATUS "GL_EXT_bfloat16 supported by glslc") + set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON) + add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + endif() + target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -119,6 +155,10 @@ if (Vulkan_FOUND) SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT} + -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT} + -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT} + -DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT} BUILD_COMMAND ${CMAKE_COMMAND} --build . INSTALL_COMMAND ${CMAKE_COMMAND} --install . INSTALL_DIR ${CMAKE_BINARY_DIR} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in b/ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in new file mode 100644 index 000000000..2d8a85696 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in @@ -0,0 +1,15 @@ +set(CMAKE_BUILD_TYPE Release) +set(CMAKE_C_FLAGS -O2) +set(CMAKE_CXX_FLAGS -O2) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) +set(CMAKE_C_COMPILER "@HOST_C_COMPILER@") +set(CMAKE_CXX_COMPILER "@HOST_CXX_COMPILER@") +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@) + +if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC") + foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() +endif() diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index abe3e7908..e2b357fdc 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -24,12 +24,54 @@ #include #include +#if defined(_MSC_VER) +# define NOMINMAX 1 +# include +# define YIELD() YieldProcessor() +#elif defined(__clang__) || defined(__GNUC__) +# if defined(__x86_64__) ||defined(__i386__) +# include +# define YIELD() _mm_pause() +# elif defined(__arm__) || defined(__aarch64__) +# if defined(__clang__) +# include +# define YIELD() __yield() +# else +# define YIELD() asm volatile("yield") +# endif +# endif +#endif + +#if !defined(YIELD) +#define YIELD() +#endif + #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-vulkan-shaders.hpp" +// remove this once it's more widely available in the SDK +#if !defined(VK_KHR_shader_bfloat16) + +#define VK_KHR_shader_bfloat16 1 +#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1 +#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR 
((VkStructureType)1000141000) +#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000) + +typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { + VkStructureType sType; + void* pNext; + VkBool32 shaderBFloat16Type; + VkBool32 shaderBFloat16DotProduct; + VkBool32 shaderBFloat16CooperativeMatrix; +} VkPhysicalDeviceShaderBfloat16FeaturesKHR; +#endif + +#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_AMD 0x1002 #define VK_VENDOR_ID_APPLE 0x106b @@ -148,6 +190,67 @@ class vk_perf_logger; static void ggml_vk_destroy_buffer(vk_buffer& buf); static constexpr uint32_t mul_mat_vec_max_cols = 8; +static constexpr uint32_t p021_max_gqa_ratio = 8; + +enum vk_device_architecture { + OTHER, + AMD_GCN, + AMD_RDNA1, + AMD_RDNA2, + AMD_RDNA3, +}; + +static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { + vk::PhysicalDeviceProperties props = device.getProperties(); + + if (props.vendorID == VK_VENDOR_ID_AMD) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool amd_shader_core_properties = false; + bool integer_dot_product = false; + bool subgroup_size_control = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) { + amd_shader_core_properties = true; + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) { + integer_dot_product = true; + } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + subgroup_size_control = true; + } + } + + if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) { + return vk_device_architecture::OTHER; + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &shader_core_props_amd; + shader_core_props_amd.pNext = &integer_dot_props; + integer_dot_props.pNext = &subgroup_size_control_props; + + device.getProperties2(&props2); + + if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) { + return vk_device_architecture::AMD_GCN; + } + if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) { + // RDNA + if (shader_core_props_amd.wavefrontsPerSimd == 20) { + return vk_device_architecture::AMD_RDNA1; + } + if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) { + return vk_device_architecture::AMD_RDNA3; + } + return vk_device_architecture::AMD_RDNA2; + } + } + return vk_device_architecture::OTHER; +} struct vk_device_struct { std::mutex mutex; @@ -161,6 +264,8 @@ struct vk_device_struct { bool pipeline_robustness; vk::Device device; uint32_t vendor_id; + vk::DriverId driver_id; + vk_device_architecture architecture; vk_queue compute_queue; vk_queue transfer_queue; bool single_queue; @@ -169,6 +274,10 @@ struct vk_device_struct { bool uma; bool prefer_host_memory; bool float_controls_rte_fp16; + bool subgroup_add; + bool subgroup_shuffle; + + bool integer_dot_product; bool subgroup_size_control; uint32_t subgroup_min_size; @@ -176,11 +285,18 @@ struct vk_device_struct { bool subgroup_require_full_support; bool coopmat_support; 
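    // (Illustrative aside, not part of this patch.) Quick sanity checks for the two
    // helpers added above: ROUNDUP_POW2 rounds M up to a multiple of the power-of-two
    // N, and is_pow2 deliberately rejects 1 via the x > 1 test.
    static_assert(ROUNDUP_POW2(13, 8) == 16, "13 rounds up to the next multiple of 8");
    static_assert(ROUNDUP_POW2(16, 8) == 16, "already-aligned values are unchanged");
    // is_pow2(1) == false, is_pow2(2) == true, is_pow2(64) == true
    //
    // For context, get_device_architecture() above distinguishes AMD generations from
    // the reported subgroup-size range (min == max == 64 -> GCN, 32..64 -> RDNA), then
    // separates RDNA1 (wavefrontsPerSimd == 20) from RDNA3 (packed mixed-signedness
    // integer dot product accelerated), defaulting to RDNA2.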
- bool coopmat_acc_f32_support; - bool coopmat_acc_f16_support; + bool coopmat_acc_f32_support {}; + bool coopmat_acc_f16_support {}; + bool coopmat_bf16_support {}; uint32_t coopmat_m; uint32_t coopmat_n; uint32_t coopmat_k; + + bool coopmat_int_support; + uint32_t coopmat_int_m; + uint32_t coopmat_int_n; + uint32_t coopmat_int_k; + bool coopmat2; size_t idx; @@ -197,34 +313,45 @@ struct vk_device_struct { vk_matmul_pipeline pipeline_matmul_f32 {}; vk_matmul_pipeline pipeline_matmul_f32_f16 {}; + vk_matmul_pipeline pipeline_matmul_bf16 {}; vk_matmul_pipeline2 pipeline_matmul_f16; vk_matmul_pipeline2 pipeline_matmul_f16_f32; - vk_pipeline pipeline_matmul_split_k_reduce; - vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT]; vk_matmul_pipeline pipeline_matmul_id_f32 {}; + vk_matmul_pipeline pipeline_matmul_id_bf16 {}; vk_matmul_pipeline2 pipeline_matmul_id_f16; vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + vk_pipeline pipeline_matmul_split_k_reduce; + vk_pipeline pipeline_quantize_q8_1; + vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; - vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; - vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; - vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; - vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat; - vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; - vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; + + // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] + vk_pipeline pipeline_add[2][2][2]; + vk_pipeline pipeline_add_norepeat[2][2][2]; + vk_pipeline pipeline_sub[2][2][2]; + vk_pipeline pipeline_sub_norepeat[2][2][2]; + vk_pipeline pipeline_mul[2][2][2]; + vk_pipeline pipeline_mul_norepeat[2][2][2]; + vk_pipeline pipeline_div[2][2][2]; + vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_scale_f32; @@ -234,22 +361,26 @@ struct vk_device_struct { vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; 
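    // (Illustrative aside, not part of this patch.) The binary-op pipelines above are
    // now [src0][src1][dst] tables with index 0 = fp32 and 1 = fp16, plus *_norepeat
    // variants. A dispatch site could index them roughly like this, where f16_src0,
    // f16_src1, f16_dst and use_norepeat are placeholders for the real tensor checks:
    //
    //     vk_pipeline p = use_norepeat
    //         ? device->pipeline_add_norepeat[f16_src0][f16_src1][f16_dst]
    //         : device->pipeline_add[f16_src0][f16_src1][f16_dst];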
vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_rms_norm_back_f32; - vk_pipeline pipeline_gelu_f32; - vk_pipeline pipeline_gelu_quick_f32; - vk_pipeline pipeline_silu_f32; - vk_pipeline pipeline_silu_back_f32; - vk_pipeline pipeline_relu_f32; + vk_pipeline pipeline_l2_norm_f32; + + // [src/dst 0=fp32,1=fp16] + vk_pipeline pipeline_gelu[2]; + vk_pipeline pipeline_gelu_quick[2]; + vk_pipeline pipeline_silu[2]; + vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_tanh[2]; + vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_leaky_relu_f32; - vk_pipeline pipeline_tanh_f32; - vk_pipeline pipeline_sigmoid_f32; + vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; @@ -266,9 +397,19 @@ struct vk_device_struct { vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; + vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_conv2d_dw_whcn_f32; + vk_pipeline pipeline_conv2d_dw_cwhn_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; @@ -276,6 +417,8 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_split_k_reduce; + std::unordered_map pipelines; std::unordered_map pipeline_descriptor_set_requirements; @@ -368,6 +511,7 @@ struct vk_mat_mat_push_constants { uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t k_split; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; + uint32_t padded_N; }; struct vk_mat_vec_push_constants { uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; @@ -380,6 +524,7 @@ struct vk_mat_mat_id_push_constants { uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; + uint32_t padded_N; }; struct vk_mat_vec_id_push_constants { uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; @@ -422,6 +567,10 @@ struct vk_flash_attn_push_constants { uint32_t n_head_log2; float m0; float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; }; struct vk_op_push_constants { @@ -565,13 +714,29 @@ struct vk_op_rwkv_wkv6_push_constants { uint32_t H; }; -// Allow pre-recording command buffers -struct vk_staging_memcpy { - vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} +struct vk_op_rwkv_wkv7_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; - void * dst; - const void 
* src; - size_t n; +struct vk_op_conv2d_dw_push_constants { + uint32_t ne; + uint32_t batches; + uint32_t channels; + uint32_t dst_w; + uint32_t dst_h; + uint32_t src_w; + uint32_t src_h; + uint32_t knl_w; + uint32_t knl_h; + int32_t stride_x; + int32_t stride_y; + int32_t pad_x; + int32_t pad_y; + int32_t dilation_x; + int32_t dilation_y; }; struct vk_op_upscale_push_constants { @@ -581,6 +746,15 @@ struct vk_op_upscale_push_constants { float sf0; float sf1; float sf2; float sf3; }; +// Allow pre-recording command buffers +struct vk_staging_memcpy { + vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} + + void * dst; + const void * src; + size_t n; +}; + struct vk_context_struct { vk_submission * s; std::vector seqs; @@ -695,7 +869,8 @@ struct ggml_backend_vk_context { ggml_vk_garbage_collector gc; size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; vk_buffer prealloc_x, prealloc_y, prealloc_split_k; - vk::Fence fence; + vk::Fence fence, almost_ready_fence; + bool almost_ready_fence_pending {}; vk_buffer buffer_pool[MAX_VK_BUFFERS]; @@ -786,6 +961,39 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_backend_vk_free(ggml_backend_t backend); +// Wait for ctx->fence to be signaled. +static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { + // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep + // during this wait. + if (ctx->almost_ready_fence_pending) { + VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence"); + ctx->device->device.resetFences({ ctx->almost_ready_fence }); + ctx->almost_ready_fence_pending = false; + } + + // Spin (w/pause) waiting for the graph to finish executing. + vk::Result result; + while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) { + if (result != vk::Result::eNotReady) { + fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__); + exit(1); + } + for (uint32_t i = 0; i < 100; ++i) { + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + } + } + ctx->device->device.resetFences({ ctx->fence }); +} + // variables to track number of compiles in progress static uint32_t compile_count = 0; static std::mutex compile_count_mutex; @@ -1382,13 +1590,29 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; -static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { +static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; +static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; + +static uint32_t get_fa_num_small_rows(bool scalar) { + return scalar ? 
scalar_flash_attention_num_small_rows : flash_attention_num_small_rows; +} + +static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { GGML_UNUSED(clamp); + if (scalar) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, 64}; + } else { + return {scalar_flash_attention_num_large_rows, 32}; + } + } + // small rows, large cols if (small_rows) { - return {flash_attention_num_small_rows, 128}; + return {get_fa_num_small_rows(scalar), 32}; } + // small cols to reduce register count if (ggml_is_quantized(type) || D == 256) { return {64, 32}; @@ -1433,7 +1657,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t warps = warptile[0] / warptile[10]; const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; - const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; + const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; @@ -1445,6 +1669,73 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec return supported; } +struct GpuPipelineConfig { + // GPU architecture identifier. + // Example: vk_device_architecture::AMD_GCN + vk_device_architecture arch; + + // Mapping of pipeline names to their specific subgroup sizes. + // Example: {"soft_max_f32", 64} + std::unordered_map pipelines; + + // Default subgroup size for this GPU. + // Defaults to 0 if not explicitly provided. + uint32_t default_subgroup_size = 0; +}; + +// Pipeline configuration for RDNA1 GPUs. +static const std::unordered_map rdna1_pipelines = { + {"soft_max", 64}, {"im2col", 64}, + {"argmax", 64}, {"mul_mat_vec", 64}, + {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32} +}; + +// Pipeline configuration for RDNA2 GPUs. +static const std::unordered_map rdna2_pipelines = { + {"soft_max", 64}, {"im2col", 64}, +}; + +static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32; + +// Define configurations for different GPUs. 
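    // (Illustrative aside, not part of this patch.) get_subgroup_size() below resolves a
    // pipeline name against these tables in three steps: exact key, then the longest key
    // that is a substring of the name, then the architecture default. Worked through the
    // RDNA1 table with pipeline names that appear later in this file:
    //   "mul_mat_vec_f16_f32_f32_1"  -> longest hit "mul_mat_vec_f16"  -> 32
    //   "mul_mat_vec_q4_0_f32_f32_1" -> falls back to "mul_mat_vec"    -> 64
    //   "soft_max_f32"               -> substring hit "soft_max"       -> 64
    //   "rms_norm_f32"               -> no hit, RDNA default           -> 32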
+static std::vector gpu_pipeline_configs = { + { + vk_device_architecture::AMD_RDNA1, + { + rdna1_pipelines, + }, + RDNA_DEFAULT_SUBGROUP_SIZE + }, + { + vk_device_architecture::AMD_RDNA2, + { + rdna2_pipelines, + }, + RDNA_DEFAULT_SUBGROUP_SIZE + }, +}; + +static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) { + for (const auto &config : gpu_pipeline_configs) { + if (config.arch == arch) { + auto pipIt = config.pipelines.find(pipeline_name); + if (pipIt != config.pipelines.end()) { + return pipIt->second; + } + std::vector> sorted_pipelines(config.pipelines.begin(), config.pipelines.end()); + std::sort(sorted_pipelines.begin(), sorted_pipelines.end(), + [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); }); + for (const auto &entry : sorted_pipelines) { + if (pipeline_name.find(entry.first) != std::string::npos) { + return entry.second; + } + } + return config.default_subgroup_size; + } + } + return 0; // If no matching configuration is found +} + static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); @@ -1456,6 +1747,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // mulmat std::vector l_warptile, m_warptile, s_warptile, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, + l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, @@ -1466,36 +1758,36 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t l_align, m_align, s_align; if (device->coopmat2) { // spec constants and tile sizes for non-quant matmul/matmul_id - l_warptile = { 256, 128, 256, 64 }; - m_warptile = { 256, 128, 128, 64 }; - s_warptile = { 128, 64, 64, 64 }; + l_warptile = { 256, 128, 256, 64, 1 }; + m_warptile = { 256, 128, 128, 64, 0 }; + s_warptile = { 128, 64, 64, 64, 0 }; l_wg_denoms = {128, 256, 1 }; m_wg_denoms = {128, 128, 1 }; s_wg_denoms = { 64, 64, 1 }; // spec constants and tile sizes for quant matmul (non-Qi_K) - l_warptile_mmq = { 256, 128, 256, 64 }; - m_warptile_mmq = { 256, 128, 128, 64 }; - s_warptile_mmq = { 256, 128, 128, 64 }; + l_warptile_mmq = { 256, 128, 256, 64, 1 }; + m_warptile_mmq = { 256, 128, 128, 64, 1 }; + s_warptile_mmq = { 256, 32, 64, 128, 0 }; l_mmq_wg_denoms = { 128, 256, 1 }; m_mmq_wg_denoms = { 128, 128, 1 }; - s_mmq_wg_denoms = { 128, 128, 1 }; + s_mmq_wg_denoms = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul (Qi_K) - l_warptile_mmq_k = { 256, 128, 512, 16 }; - m_warptile_mmq_k = { 256, 128, 256, 16 }; - s_warptile_mmq_k = { 256, 32, 128, 64 }; - l_mmq_wg_denoms_k = { 128, 512, 1 }; - m_mmq_wg_denoms_k = { 128, 256, 1 }; - s_mmq_wg_denoms_k = { 32, 128, 1 }; + l_warptile_mmq_k = { 256, 64, 128, 64, 1 }; + m_warptile_mmq_k = { 256, 32, 64, 64, 0 }; + s_warptile_mmq_k = { 256, 32, 32, 128, 0 }; + l_mmq_wg_denoms_k = { 64, 128, 1 }; + m_mmq_wg_denoms_k = { 32, 64, 1 }; + s_mmq_wg_denoms_k = { 32, 32, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 128, 16 }; - m_warptile_mmqid = { 256, 128, 64, 16 }; - s_warptile_mmqid = { 256, 64, 64, 16 }; - l_mmqid_wg_denoms = { 128, 128, 1 }; + l_warptile_mmqid = { 256, 128, 64, 16, 0 }; + m_warptile_mmqid = { 256, 128, 64, 16, 0 }; + s_warptile_mmqid = { 256, 128, 64, 16, 0 }; + l_mmqid_wg_denoms = { 128, 64, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; - s_mmqid_wg_denoms = { 
64, 64, 1 }; + s_mmqid_wg_denoms = { 128, 64, 1 }; l_align = 128; m_align = 64; @@ -1520,6 +1812,15 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + + // chip specific tuning + if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { + m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; @@ -1565,12 +1866,22 @@ static void ggml_vk_load_shaders(vk_device& device) { if (!device->pipeline_matmul_id_f32) { device->pipeline_matmul_id_f32 = std::make_shared(); } + if (!device->pipeline_matmul_bf16) { + device->pipeline_matmul_bf16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_bf16) { + device->pipeline_matmul_id_bf16 = std::make_shared(); + } std::vector> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + if (!require_full_subgroups && required_subgroup_size == 0) { + required_subgroup_size = get_subgroup_size(name, device->architecture); + } + if (!pipeline) { pipeline = std::make_shared(); pipeline->name = name; @@ -1596,63 +1907,66 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; + auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1}; + }; + + auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + // For scalar, use 128 (arbitrary) + uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows); + + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. + // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. 
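    // (Illustrative aside, not part of this patch.) Worked values of the D_split rule
    // below for the head sizes built in this file, assuming device->subgroup_size >= 8
    // so that min(subgroup_size, 8) == 8:
    //   D = 64  -> lowest set bit 64  -> D_split = min(8, 64/4)  = 8
    //   D = 80  -> lowest set bit 16  -> D_split = min(8, 16/4)  = 4
    //   D = 96  -> lowest set bit 32  -> D_split = min(8, 32/4)  = 8
    //   D = 112 -> lowest set bit 16  -> D_split = min(8, 16/4)  = 4
    //   D = 128 -> lowest set bit 128 -> D_split = 8;   D = 256  -> D_split = 8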
+ const uint32_t D_lsb = D ^ (D & (D-1)); + uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); + + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); + return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; + }; + +#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D 
"_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ + +#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256) + + CREATE_FA(GGML_TYPE_F16, f16, true, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, ) #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - - auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; - }; - - auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { - // For large number of rows, 128 invocations seems to work best. - // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we - // can't use 256 for D==80. - uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128; - auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); - return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; - }; - -#define CREATE_FA2(TYPE, NAMELC, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ - ggml_vk_create_pipeline(device, 
device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - -#define CREATE_FA(TYPE, NAMELC) \ - CREATE_FA2(TYPE, NAMELC, 64) \ - CREATE_FA2(TYPE, NAMELC, 80) \ - CREATE_FA2(TYPE, NAMELC, 96) \ - CREATE_FA2(TYPE, NAMELC, 112) \ - CREATE_FA2(TYPE, NAMELC, 128) \ - CREATE_FA2(TYPE, NAMELC, 256) - - CREATE_FA(GGML_TYPE_F16, f16) - CREATE_FA(GGML_TYPE_Q4_0, q4_0) - CREATE_FA(GGML_TYPE_Q4_1, q4_1) - CREATE_FA(GGML_TYPE_Q5_0, q5_0) - CREATE_FA(GGML_TYPE_Q5_1, q5_1) - CREATE_FA(GGML_TYPE_Q8_0, q8_0) - // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently - //CREATE_FA(GGML_TYPE_Q2_K, q2_k) - //CREATE_FA(GGML_TYPE_Q3_K, q3_k) - //CREATE_FA(GGML_TYPE_Q4_K, q4_k) - //CREATE_FA(GGML_TYPE_Q5_K, q5_k) - //CREATE_FA(GGML_TYPE_Q6_K, q6_k) - //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s) - //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m) - //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs) - //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs) - //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s) - //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs) - //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s) - //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) + CREATE_FA(GGML_TYPE_F16, f16, false, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2) + } +#endif +#undef CREATE_FA2 #undef CREATE_FA +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ @@ -1668,6 +1982,11 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) + } +#endif CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -1689,6 +2008,11 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + } +#endif CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) @@ -1742,6 +2066,11 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ) + } +#endif if (device->coopmat_acc_f16_support) { CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1790,6 +2119,11 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + } +#endif if (device->coopmat_acc_f16_support) { CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, 
vk_mat_mat_id_push_constants, 4, _id); @@ -1854,6 +2188,14 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -1864,6 +2206,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1885,10 +2229,22 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + 
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -1910,6 +2266,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM2 +#undef CREATE_MMQ #undef CREATE_MM } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} @@ -1927,11 +2284,21 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F16, 
pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1953,10 +2320,22 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -1977,8 +2356,26 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -#undef CREATE_MM } + // reusing CREATE_MM from the fp32 path + if ((device->coopmat2 || device->coopmat_support) +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + && !device->coopmat_bf16_support +#endif + ) { + // use scalar tile sizes + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 }; + + l_wg_denoms = {128, 128, 1 }; + m_wg_denoms = { 64, 64, 1 }; + s_wg_denoms = { 32, 32, 1 }; + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM // mul mat vec @@ -1986,16 +2383,18 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t rm_stdq = 1; uint32_t rm_kq = 2; if (device->vendor_id == VK_VENDOR_ID_AMD) { - if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN + if (device->architecture == AMD_GCN) { rm_stdq = 2; rm_kq = 4; } } else if (device->vendor_id == VK_VENDOR_ID_INTEL) rm_stdq = 2; + uint32_t rm_iq = 2 * rm_kq; for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); @@ 
-2006,18 +2405,19 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, 
true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); @@ -2028,19 +2428,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); @@ -2051,15 +2452,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, 
{subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -2086,6 +2487,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2103,6 +2505,7 @@ static void ggml_vk_load_shaders(vk_device& 
device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2119,29 +2522,51 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { + if (device->subgroup_add && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); + } + } + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 
* sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); - 
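One structural change worth calling out from the hunk above: the single mul_mat_vec_p021_f16_f32 pipeline becomes an array indexed by the grouped-query-attention ratio, and the subgroupAdd shader variant is only selected when the device reports subgroup arithmetic with full-subgroup support. A hedged sketch of that planning step, assuming illustrative names (PipelineDesc, plan_p021_pipelines) that do not appear in the patch:

#include <cstdint>
#include <string>
#include <vector>

// Describes one pipeline to create; the real code instead calls
// ggml_vk_create_pipeline with the matching SPIR-V blob and the
// specialization constants {subgroup_size, gqa_ratio}.
struct PipelineDesc {
    std::string name;
    bool        subgroup_add;   // which shader variant to load
    uint32_t    gqa_ratio;      // baked in as a specialization constant
};

static std::vector<PipelineDesc> plan_p021_pipelines(uint32_t max_gqa_ratio,
                                                     bool device_subgroup_add,
                                                     bool require_full_support) {
    const bool use_subgroup_add = device_subgroup_add && require_full_support;
    std::vector<PipelineDesc> out;
    out.reserve(max_gqa_ratio);
    for (uint32_t i = 0; i < max_gqa_ratio; ++i) {
        out.push_back({ "mul_mat_vec_p021_f16_f32" + std::to_string(i + 1),
                        use_subgroup_add, i + 1 });
    }
    return out;
}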
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, 
device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + } ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); @@ -2150,20 +2575,32 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? 
"_f16" : "_f32"); + return s; + }; + +#define CREATE_BINARY(name, namemod, spec) \ + for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ + "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); + + CREATE_BINARY(add, , {0}) + CREATE_BINARY(add, _norepeat, {1}) + CREATE_BINARY(sub, , {0}) + CREATE_BINARY(sub, _norepeat, {1}) + CREATE_BINARY(mul, , {0}) + CREATE_BINARY(mul, _norepeat, {1}) + CREATE_BINARY(div, , {0}) + CREATE_BINARY(div, _norepeat, {1}) +#undef CREATE_BINARY ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -2183,14 +2620,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, 
sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); +#define CREATE_UNARY(name) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + CREATE_UNARY(gelu) + CREATE_UNARY(gelu_quick) + CREATE_UNARY(silu) + CREATE_UNARY(relu) + CREATE_UNARY(tanh) + CREATE_UNARY(sigmoid) +#undef CREATE_UNARY + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); @@ -2238,15 +2681,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + for (auto &c : compiles) { c.wait(); } device->need_compiles = false; } -static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -2275,6 +2723,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device = physical_devices[dev_num]; const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + 
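A brief note on the CREATE_UNARY macro introduced above: it relies on stringizing (#name) for the pipeline name and on token pasting (name ## _f32_len) for the generated SPIR-V symbol names, so each unary op declares its f32 and f16 pipelines from a single identifier. A toy, compilable illustration of the same preprocessor pattern, using dummy _len values instead of real SPIR-V data:

#include <cstdio>

// Dummy stand-ins for the generated <shader>_len symbols the real macro pastes together.
static const int gelu_f32_len = 123;
static const int gelu_f16_len = 456;
static const int silu_f32_len = 789;
static const int silu_f16_len = 1011;

// Same shape as the patch's CREATE_UNARY: one identifier expands to the f32 and f16 entries.
#define CREATE_UNARY(name) \
    std::printf("%-12s -> %d bytes of (dummy) SPIR-V\n", #name "_f32", name ## _f32_len); \
    std::printf("%-12s -> %d bytes of (dummy) SPIR-V\n", #name "_f16", name ## _f16_len);

int main() {
    CREATE_UNARY(gelu)
    CREATE_UNARY(silu)
#undef CREATE_UNARY
    return 0;
}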
device->architecture = get_device_architecture(device->physical_device); + const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; @@ -2286,8 +2736,9 @@ static vk_device ggml_vk_get_device(size_t idx) { bool pipeline_robustness = false; bool coopmat2_support = false; device->coopmat_support = false; + device->integer_dot_product = false; + bool bfloat16_support = false; - // Check if maintenance4 is supported for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { maintenance4_support = true; @@ -2312,6 +2763,14 @@ static vk_device ggml_vk_get_device(size_t idx) { } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + device->integer_dot_product = true; +#endif + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; } } @@ -2322,13 +2781,16 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceDriverProperties driver_props; vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; + vk::PhysicalDeviceVulkan11Properties vk11_props; vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; props2.pNext = &props3; props3.pNext = &subgroup_props; subgroup_props.pNext = &driver_props; - driver_props.pNext = &vk12_props; + driver_props.pNext = &vk11_props; + vk11_props.pNext = &vk12_props; VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; @@ -2357,9 +2819,15 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + device->physical_device.getProperties2(&props2); device->properties = props2.properties; device->vendor_id = device->properties.vendorID; + device->driver_id = driver_props.driverID; const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); @@ -2375,13 +2843,9 @@ static vk_device ggml_vk_get_device(size_t idx) { if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE); -#if defined(_WIN32) - } else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) { + } else { // Limit batching of allocations to 1GB by default to avoid fragmentation issues device->suballocation_block_size = 1024*1024*1024; -#endif - } else { - device->suballocation_block_size = device->max_memory_allocation_size; } device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); @@ -2396,14 +2860,22 @@ static vk_device ggml_vk_get_device(size_t idx) { } device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & 
vk::SubgroupFeatureFlagBits::eArithmetic); + + device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; - if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) { + if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) { device->coopmat_support = false; } + device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues @@ -2488,6 +2960,17 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.pNext = nullptr; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif + VkPhysicalDeviceMaintenance4Features maint4_features {}; maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; if (maintenance4_support) { @@ -2496,6 +2979,14 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_maintenance4"); } + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + device_extensions.push_back("VK_KHR_shader_integer_dot_product"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; @@ -2665,7 +3156,37 @@ static vk_device ggml_vk_get_device(size_t idx) { device->coopmat_acc_f16_support = true; } } + } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup && + device->coopmat_int_m == 0 + ) { + device->coopmat_int_support = true; + device->coopmat_int_m = prop.MSize; + device->coopmat_int_n = prop.NSize; + device->coopmat_int_k = prop.KSize; } +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_bf16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = 
prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_bf16_support = true; + } + } +#endif } if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { @@ -2673,11 +3194,19 @@ static vk_device ggml_vk_get_device(size_t idx) { GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n"); device->coopmat_support = false; } + if (getenv("GGML_VK_DISABLE_BFLOAT16")) { + device->coopmat_bf16_support = false; + } } if (device->coopmat_support) { device_extensions.push_back("VK_KHR_cooperative_matrix"); } +#if defined(VK_KHR_shader_bfloat16) + if (device->coopmat_bf16_support) { + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif #endif device->name = GGML_VK_NAME + std::to_string(idx); @@ -2769,22 +3298,11 @@ static void ggml_vk_print_gpu_info(size_t idx) { vk::PhysicalDevice physical_device = devices[dev_num]; std::vector ext_props = physical_device.enumerateDeviceExtensionProperties(); - vk::PhysicalDeviceProperties2 props2; - vk::PhysicalDeviceMaintenance3Properties props3; - vk::PhysicalDeviceSubgroupProperties subgroup_props; - vk::PhysicalDeviceDriverProperties driver_props; - props2.pNext = &props3; - props3.pNext = &subgroup_props; - subgroup_props.pNext = &driver_props; - physical_device.getProperties2(&props2); - - const size_t subgroup_size = subgroup_props.subgroupSize; - const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; - bool fp16_storage = false; bool fp16_compute = false; bool coopmat_support = false; bool coopmat2_support = false; + bool integer_dot_product = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { @@ -2800,25 +3318,44 @@ static void ggml_vk_print_gpu_info(size_t idx) { } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; +#endif +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + integer_dot_product = true; #endif } } - if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) { - coopmat_support = false; - } + const vk_device_architecture device_architecture = get_device_architecture(physical_device); const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; - vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures(); + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + + // Pointer to the last chain element + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props; + + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + + physical_device.getProperties2(&props2); VkPhysicalDeviceFeatures2 
device_features2; device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; device_features2.pNext = nullptr; - device_features2.features = (VkPhysicalDeviceFeatures)device_features; VkPhysicalDeviceVulkan11Features vk11_features; vk11_features.pNext = nullptr; @@ -2831,7 +3368,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { vk11_features.pNext = &vk12_features; // Pointer to the last chain element - VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; + last_struct = (VkBaseOutStructure *)&vk12_features; #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; @@ -2843,20 +3380,39 @@ static void ggml_vk_print_gpu_info(size_t idx) { last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; last_struct = (VkBaseOutStructure *)&coopmat_features; } +#endif + + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + } vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; - coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix; + uint32_t default_subgroup_size = get_subgroup_size("", device_architecture); + const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize; + const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + + integer_dot_product = integer_dot_product + && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated + && shader_integer_dot_product_features.shaderIntegerDotProduct; + + coopmat_support = coopmat_support +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + && coopmat_features.cooperativeMatrix #endif + && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n", + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, - props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str()); + props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. 
This is probably not the device you want.\n"); @@ -3058,6 +3614,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->prealloc_size_split_k = 0; ctx->fence = ctx->device->device.createFence({}); + ctx->almost_ready_fence = ctx->device->device.createFence({}); #ifdef GGML_VULKAN_CHECK_RESULTS const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); @@ -3106,6 +3663,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { return ctx->device->pipeline_matmul_f32_f16; } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_bf16; + } if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_f16_f32.f16acc; @@ -3122,6 +3682,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte } } + // MMQ + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc; + + if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + return nullptr; + } + + return pipelines; + } + if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { return nullptr; } @@ -3166,6 +3737,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * switch (a_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3198,6 +3770,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_id_f32; } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_id_bf16; + } if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_id_f16_f32.f16acc; @@ -3251,6 +3826,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context switch (a_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3414,8 +3990,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo return s; } - - static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array elements) { const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); @@ -3839,20 +4413,30 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int if (split_k == 3) { split_k = 2; } + if (ctx->device->coopmat2) { + // coopmat2 shader expects splits to be aligned to 256 + while (split_k > 1 && ((k / split_k) % 256) != 0) { + split_k /= 2; + } + } } } return split_k; } -static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", 
" << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); if (ctx->device->coopmat2) { - if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + // Use large shader when the N dimension is greater than the medium shader's tile size + uint32_t crossover_large = mmp->m->wg_denoms[1]; + if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { return aligned ? mmp->a_l : mmp->l; } - if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) { + // Use medium shader when the N dimension is greater than the small shader's tile size + uint32_t crossover_medium = mmp->s->wg_denoms[1]; + if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; @@ -3867,9 +4451,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); - return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align; +static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; } static void ggml_vk_matmul( @@ -3877,18 +4461,19 @@ static void ggml_vk_matmul( vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, - uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) { - VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? 
split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")"); + uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, + uint32_t padded_n) { + VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); ggml_vk_sync_buffers(subctx); if (split_k == 1) { - const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 }; + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); return; } GGML_ASSERT(batch_stride_d == m * n); - const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 }; + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n }; // Make sure enough workgroups get assigned for split k to work ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); ggml_vk_sync_buffers(subctx); @@ -3896,14 +4481,18 @@ static void ggml_vk_matmul( ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << 
m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); if (ctx->device->coopmat2) { - if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + // Use large shader when the N dimension is greater than the medium shader's tile size + uint32_t crossover_large = mmp->m->wg_denoms[1]; + if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { return aligned ? mmp->a_l : mmp->l; } - if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) { + // Use medium shader when the N dimension is greater than the small shader's tile size + uint32_t crossover_medium = mmp->s->wg_denoms[1]; + if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; @@ -3928,14 +4517,15 @@ static void ggml_vk_matmul_id( vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, - uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) { + uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11, + uint32_t padded_n) { VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); ggml_vk_sync_buffers(subctx); const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, - nei0, nei1, nbi1, ne11 }; + nei0, nei1, nbi1, ne11, padded_n }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as }); } @@ -3972,6 +4562,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f16_f16; } } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f32; + } else { + return ctx->device->pipeline_cpy_f16_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_bf16; + } else { + return ctx->device->pipeline_cpy_f32_bf16; + } + } if (src->type == GGML_TYPE_F32) { switch (to) { case GGML_TYPE_Q4_0: @@ -4033,6 +4637,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); } +static vk_pipeline 
ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { + switch(type) { + case GGML_TYPE_Q8_1: + return ctx->device->pipeline_quantize_q8_1; + default: + std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl; + GGML_ABORT("fatal error"); + } +} + +static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) { + VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")"); + + vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 }); +} + static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -4080,56 +4703,76 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || !ggml_vk_dim01_contiguous(src1); + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; + + // Check for mmq first + vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; + + if (mmp == nullptr) { + // Fall back to f16 dequant mul mat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + quantize_y = false; + } const bool qx_needs_dequant = mmp == nullptr || x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat - mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? 
GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); } // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; const int x_ne = ne01 * ne00; - const int y_ne = ne11 * ne10; + const int y_ne = padded_n * ne10; const int d_ne = ne11 * ne01; - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); - const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; - - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); - const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; - const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t d_sz = sizeof(float) * d_ne; vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; + vk_pipeline to_q8_1 = nullptr; if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + } + if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t y_sz_upd = y_sz * ne12 * ne13; @@ -4143,7 +4786,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ctx->prealloc_size_x = x_sz_upd; } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { ctx->prealloc_size_y = y_sz_upd; } if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { @@ -4158,6 +4801,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1); + } if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); } @@ -4193,6 +4839,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (qy_needs_dequant) { d_Y = ctx->prealloc_y; GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -4209,6 +4858,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (y_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); } + if (quantize_y) { + ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13); + } uint32_t stride_batch_x = ne00*ne01; uint32_t stride_batch_y = ne10*ne11; @@ -4217,7 +4869,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); } - if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } @@ -4228,7 +4880,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, - split_k, ne12*ne13, ne02, ne12, r2, r3 + split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n ); // NOLINT } @@ -4465,9 +5117,15 @@ static void 
ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t d_sz = sizeof(float) * d_ne; + // With grouped query attention there are > 1 Q matrices per K, V matrix. + uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02; + if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) { + gqa_ratio = 1; + } + if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1); + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); return; } @@ -4491,8 +5149,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c // compute const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + + uint32_t workgroups_z = (uint32_t)ne12; + // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups + if (gqa_ratio > 1) { + workgroups_z /= gqa_ratio; + } + ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z }); } static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -4514,6 +5179,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t nb01 = src0->nb[1]; const uint64_t nb02 = src0->nb[2]; + const uint64_t nb12 = src1->nb[2]; + // const uint64_t ne10 = src1->ne[0]; const uint64_t ne11 = src1->ne[1]; const uint64_t ne12 = src1->ne[2]; @@ -4539,6 +5206,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); + const uint32_t channel_stride_y = nb12 / sizeof(float); const uint64_t qx_sz = ggml_nbytes(src0); const uint64_t qy_sz = ggml_nbytes(src1); @@ -4569,7 +5237,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; // compute - const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, 
ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); @@ -4594,7 +5262,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) // when ne12 and ne13 are one. } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && - (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); } else { ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); @@ -4621,7 +5289,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - GGML_ASSERT(nei0 * nei1 <= 3072); + GGML_ASSERT(nei0 * nei1 <= 4096); const uint32_t nbi1 = ids->nb[1]; const uint32_t nbi2 = ids->nb[2]; @@ -4662,31 +5330,37 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || !ggml_vk_dim01_contiguous(src1); + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); const bool qx_needs_dequant = mmp == nullptr || x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig; if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat - mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); } // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint64_t x_ne = ne01 * ne00; - const uint64_t y_ne = ne11 * ne10; - const uint64_t d_ne = ne21 * ne20; - - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? 
GGML_TYPE_F16 : src0->type); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = padded_n * ne10; + const uint64_t d_ne = ne21 * ne20; const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); @@ -4699,12 +5373,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline to_fp16_vk_1 = nullptr; if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -4806,7 +5480,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, ne01, ne21, ne10, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, - n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11 + n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n ); // NOLINT } @@ -5034,7 +5708,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const uint32_t nbm1 = mask ? mask->nb[1] : 0; const uint32_t D = neq0; - const uint32_t N = neq1; + uint32_t N = neq1; const uint32_t KV = nek1; GGML_ASSERT(ne0 == D); @@ -5064,20 +5738,57 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); + bool scalar = !ctx->device->coopmat2; + + uint32_t gqa_ratio = 1; + uint32_t qk_ratio = neq2 / nek2; + uint32_t workgroups_x = (uint32_t)neq1; + uint32_t workgroups_y = (uint32_t)neq2; + uint32_t workgroups_z = (uint32_t)neq3; + + // For scalar FA, we can use the "large" size to accommodate qga. + // For coopmat FA, we always use the small size (which is still pretty large for gqa). + const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false); + + if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + // grouped query attention - make the N dimension equal to gqa_ratio, reduce + // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 + // and change addressing calculations to index Q's dimension 2. 
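+ // Illustrative numbers (not from the patch): with neq2 = 32 Q heads sharing nek2 = 8 K/V heads,
+ // qk_ratio = 4, so N becomes 4 and workgroups_y drops from 32 to 8 -- one workgroup then covers
+ // the four Q heads that share a single K/V head.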
+ gqa_ratio = qk_ratio; + N = gqa_ratio; + workgroups_y /= N; + } + vk_pipeline *pipelines; // XXX TODO other backends may be changing accumulator precision to default to f32 soon - bool f32acc = dst->op_params[3] == GGML_PREC_F32; - bool small_rows = N <= flash_attention_num_small_rows; - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; - default: - assert(!"unsupported D value"); - return; + bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32; + bool small_rows = N <= get_fa_num_small_rows(scalar); + + if (scalar) { + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + } else { + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } } assert(pipelines); @@ -5089,12 +5800,47 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0); + vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); + uint32_t split_kv = KV; + uint32_t split_k = 1; + + // Use a placeholder core count if one isn't available. split_k is a big help for perf. + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + + // Try to use split_k when KV is large enough to be worth the overhead + if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { + // Try to run two workgroups per SM. 
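+ // Sketch of the intent (illustrative numbers): a device reporting 32 shader cores with
+ // workgroups_y == 8 starts at split_k = 8; KV is then carved into align-rounded chunks,
+ // and the partial results are combined by the flash_attn split_k reduce dispatch below.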
+ split_k = ctx->device->shader_core_count * 2 / workgroups_y; + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align); + split_k = CEIL_DIV(KV, split_kv); + workgroups_x = split_k; + } + } + + // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). + const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0; + if (split_k_size > ctx->device->max_memory_allocation_size) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + if (dryrun) { // Request descriptor sets ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + } return; } @@ -5115,8 +5861,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_sync_buffers(subctx); - vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; @@ -5181,16 +5925,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, - mask != nullptr, n_head_log2, m0, m1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - { - vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, - }, - sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); + mask != nullptr, n_head_log2, m0, m1, + gqa_ratio, split_kv, split_k }; + + ggml_vk_sync_buffers(subctx); + + if (split_k > 1) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + }, + // We only use split_k when group query attention is enabled, which means + // there's no more than one tile of rows (i.e. workgroups_x would have been + // one). We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. 
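+ // (ggml_vk_dispatch_pipeline divides elements[0] by wg_denoms[0] again, so the product
+ // below should yield exactly workgroups_x = split_k workgroups in x.)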
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + + ggml_vk_sync_buffers(subctx); + const std::array pc2 = { D, (uint32_t)ne1, split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, + { + vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 }); + } else { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z }); + } } static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { @@ -5210,26 +5983,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_ADD: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32; - } - if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; - } - return nullptr; case GGML_OP_SUB: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32; - } - return nullptr; case GGML_OP_MUL: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; - } - return nullptr; case GGML_OP_DIV: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32; + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) { + return nullptr; + } + switch (op) { + case GGML_OP_ADD: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_SUB: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_MUL: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_DIV: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + default: + break; } return nullptr; case GGML_OP_CONCAT: @@ -5244,7 +6028,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_UPSCALE: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) { return ctx->device->pipeline_upscale_f32; } return nullptr; @@ -5317,38 +6101,31 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rms_norm_back_f32; } return nullptr; + case GGML_OP_L2_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_l2_norm_f32; + } + return nullptr; case GGML_OP_UNARY: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_SILU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_silu_f32; - } - break; + return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_gelu_f32; - } - break; + return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU_QUICK: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_gelu_quick_f32; - } - break; + return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_RELU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_relu_f32; - } - break; + return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_TANH: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_tanh_f32; - } - break; + return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SIGMOID: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sigmoid_f32; - } - break; + return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16]; default: break; } @@ -5456,6 +6233,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv6_f32; } return nullptr; + case GGML_OP_RWKV_WKV7: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv7_f32; + } + return nullptr; case GGML_OP_OPT_STEP_ADAMW: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_opt_step_adamw_f32; @@ -5466,6 +6248,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_leaky_relu_f32; } return nullptr; + case GGML_OP_CONV_2D_DW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f32; + } + } + return nullptr; default: return nullptr; } @@ -5491,6 +6282,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case 
GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_ROPE: + case GGML_OP_RMS_NORM: + case GGML_OP_CONV_2D_DW: return true; default: return false; @@ -5701,8 +6494,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co switch (op) { case GGML_OP_NORM: - case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: @@ -5717,6 +6510,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_RMS_NORM: + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + break; + case GGML_OP_SUM: // We use GGML_OP_SUM_ROWS with 1 row. elements = { 1, 1, 1 }; @@ -5783,6 +6580,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_UNARY: + case GGML_OP_CONV_2D_DW: { const uint32_t ne = ggml_nelements(dst); if (ne > 262144) { @@ -5952,23 +6750,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } -static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) { - const ggml_tensor * k = dst->src[0]; - const ggml_tensor * v = dst->src[1]; - const ggml_tensor * r = dst->src[2]; - const ggml_tensor * tf = dst->src[3]; - const ggml_tensor * td = dst->src[4]; - const ggml_tensor * state = dst->src[5]; +static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) { + GGML_ASSERT(version == 6 || version == 7); + int num_srcs = version == 6 ? 
6 : 7; + + for (int i = 0; i < num_srcs; i++) { + GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); + } - GGML_ASSERT(!ggml_is_quantized(k->type)); - GGML_ASSERT(!ggml_is_quantized(v->type)); - GGML_ASSERT(!ggml_is_quantized(r->type)); - GGML_ASSERT(!ggml_is_quantized(tf->type)); - GGML_ASSERT(!ggml_is_quantized(td->type)); - GGML_ASSERT(!ggml_is_quantized(state->type)); GGML_ASSERT(dst->buffer != nullptr); - vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); GGML_ASSERT(pipeline != nullptr); if (dryrun) { @@ -5977,89 +6769,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc } ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; - ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; - ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; - ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; - ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; - ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; + ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + for (int i = 0; i < num_srcs; i++) { + src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; + } ggml_vk_sync_buffers(subctx); - vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr; - size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0; - bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; + vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 }; + bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false }; if (ctx->device->uma) { - ggml_vk_host_get(ctx->device, k->data, d_K, k_offset); - ggml_vk_host_get(ctx->device, v->data, d_V, v_offset); - ggml_vk_host_get(ctx->device, r->data, d_R, r_offset); - ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset); - ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset); - ggml_vk_host_get(ctx->device, state->data, d_State, state_offset); + for (int i = 0; i < num_srcs; i++) { + ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]); + srcs_uma[i] = d_srcs[i] != nullptr; + } + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); - - K_uma = d_K != nullptr; - V_uma = d_V != nullptr; - R_uma = d_R != nullptr; - TF_uma = d_TF != nullptr; - TD_uma = d_TD != nullptr; - STATE_uma = d_State != nullptr; - DST_uma = d_D != nullptr; + dst_uma = d_D != nullptr; } - if (!K_uma) { - d_K = k_buf_ctx->dev_buffer; - k_offset = vk_tensor_offset(k) + k->view_offs; + uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 }; + for (int i = 0; i < num_srcs; i++) { + src_sizes[i] = ggml_nbytes(dst->src[i]); + if (!srcs_uma[i]) { + d_srcs[i] = src_buf_ctxs[i]->dev_buffer; + src_offsets[i] = 
vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs; + } } - if (!V_uma) { - d_V = v_buf_ctx->dev_buffer; - v_offset = vk_tensor_offset(v) + v->view_offs; - } - if (!R_uma) { - d_R = r_buf_ctx->dev_buffer; - r_offset = vk_tensor_offset(r) + r->view_offs; - } - if (!TF_uma) { - d_TF = tf_buf_ctx->dev_buffer; - tf_offset = vk_tensor_offset(tf) + tf->view_offs; - } - if (!TD_uma) { - d_TD = td_buf_ctx->dev_buffer; - td_offset = vk_tensor_offset(td) + td->view_offs; - } - if (!STATE_uma) { - d_State = state_buf_ctx->dev_buffer; - state_offset = vk_tensor_offset(state) + state->view_offs; - } - if (!DST_uma) { + + const uint64_t dst_size = ggml_nbytes(dst); + if (!dst_uma) { d_D = dst_buf_ctx->dev_buffer; dst_offset = vk_tensor_offset(dst) + dst->view_offs; } - const uint64_t k_size = ggml_nbytes(k); - const uint64_t v_size = ggml_nbytes(v); - const uint64_t r_size = ggml_nbytes(r); - const uint64_t tf_size = ggml_nbytes(tf); - const uint64_t td_size = ggml_nbytes(td); - const uint64_t state_size = ggml_nbytes(state); - const uint64_t dst_size = ggml_nbytes(dst); - std::array elements = { (uint32_t)(pc.B * pc.H), 1, 1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{ d_K, k_offset, k_size }, - vk_subbuffer{ d_V, v_offset, v_size }, - vk_subbuffer{ d_R, r_offset, r_size }, - vk_subbuffer{ d_TF, tf_offset, tf_size }, - vk_subbuffer{ d_TD, td_offset, td_size }, - vk_subbuffer{ d_State, state_offset, state_size }, - vk_subbuffer{ d_D, dst_offset, dst_size } - }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); + if (version == 6) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); + } else if (version == 7) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements); + } else { + // shouldn't happen + GGML_ASSERT(false); + } } static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { @@ -6068,7 +6844,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, const size_t n_heads = dst->src[0]->ne[1]; const size_t n_seqs = dst->src[5]->ne[1]; - ggml_vk_op_f32_rwkv6( + ggml_vk_op_f32_wkv( ctx, subctx, dst, { (uint32_t)n_seqs, @@ -6076,6 +6852,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, (uint32_t)n_embed, (uint32_t)n_heads, }, + 6, + dryrun + ); +} + +static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = 
dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[6]->ne[1]; + + ggml_vk_op_f32_wkv( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + 7, dryrun ); } @@ -6369,7 +7165,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -6377,6 +7183,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } +static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } @@ -6556,6 +7367,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + vk_op_conv2d_dw_push_constants p{}; + p.ne = ggml_nelements(dst); + p.channels = dst->ne[2]; + p.batches = dst->ne[3]; + p.dst_w = dst->ne[0]; + p.dst_h = dst->ne[1]; + p.src_w = src1->ne[0]; + p.src_h = src1->ne[1]; + p.knl_w = src0->ne[0]; + p.knl_h = src0->ne[1]; + p.stride_x = dst->op_params[0]; + p.stride_y = dst->op_params[1]; + p.pad_x = dst->op_params[2]; + p.pad_y = dst->op_params[3]; + p.dilation_x = dst->op_params[4]; + p.dilation_y = dst->op_params[5]; + + GGML_ASSERT(src0->ne[3] == p.channels); + GGML_ASSERT(src1->ne[3] == p.batches); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); +} + static void 
ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const float * op_params = (const float *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); @@ -6717,6 +7552,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + ggml_pipeline_allocate_descriptor_sets(ctx->device); vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -6766,7 +7605,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, k*m, k*n, m*n, - split_k, batch, batch, batch, 1, 1 + split_k, batch, batch, batch, 1, 1, n ); } ggml_vk_ctx_end(subctx); @@ -6965,6 +7804,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + ggml_pipeline_allocate_descriptor_sets(ctx->device); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); @@ -7024,66 +7867,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ free(x_chk); } -static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) { +// This does not work without ggml q8_1 quantization support +// +// typedef uint16_t ggml_half; +// typedef uint32_t ggml_half2; +// +// #define QK8_1 32 +// typedef struct { +// union { +// struct { +// ggml_half d; // delta +// ggml_half s; // d * sum(qs[i]) +// } GGML_COMMON_AGGR_S; +// ggml_half2 ds; +// } GGML_COMMON_AGGR_U; +// int8_t qs[QK8_1]; // quants +// } block_q8_1; +// +// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { +// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")"); +// GGML_ASSERT(quant == GGML_TYPE_Q8_1); +// +// const size_t x_sz = sizeof(float) * ne; +// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); +// float * x = (float *) malloc(x_sz); +// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz); +// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz); +// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); +// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); +// +// for (size_t i = 0; i < ne; i++) { +// x[i] = rand() / (float)RAND_MAX; +// } +// +// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant); +// +// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); +// +// if (ctx->device->need_compiles) { +// ggml_vk_load_shaders(ctx->device); +// } +// +// ggml_pipeline_allocate_descriptor_sets(ctx->device); +// +// ggml_vk_buffer_write(x_buf, 0, x, x_sz); +// +// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); +// ggml_vk_ctx_begin(ctx->device, subctx); +// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne); +// ggml_vk_ctx_end(subctx); +// +// auto begin = std::chrono::high_resolution_clock::now(); +// +// 
ggml_vk_submit(subctx, ctx->fence); +// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences"); +// ctx->device->device.resetFences({ ctx->fence }); +// +// auto end = std::chrono::high_resolution_clock::now(); +// +// double ms_quant = std::chrono::duration_cast(end-begin).count() / 1000.0; +// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz); +// +// ggml_vk_quantize_data(x, qx_res, ne, quant); +// +// int first_err = -1; +// +// for (size_t i = 0; i < ne / 32; i++) { +// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// for (size_t j = 0; j < 32; j++) { +// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]); +// +// if (first_err < 0 && error > 1) { +// first_err = i; +// } +// } +// } +// +// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl; +// +// if (first_err != -1) { +// std::cerr << "first_error = " << first_err << std::endl; +// std::cerr << "Actual result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " "; +// } +// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " "; +// } +// std::cerr << std::endl; +// } +// +// ggml_vk_destroy_buffer(x_buf); +// ggml_vk_destroy_buffer(qx_buf); +// +// free(x); +// free(qx); +// free(qx_res); +// } + +static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) { VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); const size_t x_ne = m * k * batch; const size_t y_ne = k * n * batch; const size_t d_ne = m * n * batch; + vk_matmul_pipeline2 * pipelines; + + if (mmq) { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1; + } else { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat; + } + + const bool fp16acc = ctx->device->fp16; + vk_pipeline p; std::string shname; if (shader_size == 0) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s; + p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; } else if (shader_size == 1) { - p = ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m; + p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; } else if (shader_size == 2) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l; + p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; } else { GGML_ASSERT(0); } - const size_t kpad = ggml_vk_align_size(k, p->align); + const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align); - if (k != kpad) { + if (mmq || k != kpad) { if (shader_size == 0) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s; + p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s; shname = std::string(ggml_type_name(quant)) + "_S"; } else if (shader_size == 1) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m; + p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m; shname = std::string(ggml_type_name(quant)) + "_M"; } else if (shader_size == 2) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l; + p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l; shname = std::string(ggml_type_name(quant)) + "_L"; } else { GGML_ASSERT(0); } } + if (p == nullptr) { + std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl; + return; + } + const size_t x_sz = sizeof(float) * x_ne; const size_t y_sz = sizeof(float) * y_ne; const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); + const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz; const size_t d_sz = sizeof(float) * d_ne; float * x = (float *) malloc(x_sz); float * y = (float *) malloc(y_sz); void * qx = malloc(qx_sz); vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); float * d = (float *) malloc(d_sz); float * d_chk = (float *) malloc(d_sz); for (size_t i = 0; i < x_ne; i++) { x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + // x[i] = i % k; } ggml_vk_quantize_data(x, qx, x_ne, quant); for (size_t i = 0; i < y_ne; i++) { - // y[i] = rand() / (float)RAND_MAX; - y[i] = (i % k == i / k) ? 1.0f : 0.0f; + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 
1.0f : 0.0f; + // y[i] = i % k; } ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); @@ -7098,6 +8073,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); } } + if (mmq) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it); + } + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } ggml_pipeline_allocate_descriptor_sets(ctx->device); @@ -7106,13 +8088,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ggml_vk_ctx_begin(ctx->device, subctx); - for (size_t i = 0; i < num_it; i++) { - ggml_vk_matmul( - ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), - m, n, k, - k, k, m, k*m, k*n, m*n, - split_k, batch, batch, batch, 1, 1 - ); + if (mmq) { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne); + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + } else { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } } ggml_vk_ctx_end(subctx); @@ -7170,7 +8164,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); - std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + std::cerr << "TEST dequant matmul " << shname; + if (mmq) { + std::cerr << " mmq"; + } + std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; if (avg_err > 0.01 || std::isnan(avg_err)) { std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; @@ -7180,6 +8178,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, std::cerr << "Expected result: " << std::endl << std::endl; ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << "src0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "src1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b); + if (split_k > 1) { float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); @@ -7202,6 +8206,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ggml_vk_destroy_buffer(qx_buf); 
ggml_vk_destroy_buffer(y_buf); + ggml_vk_destroy_buffer(qy_buf); ggml_vk_destroy_buffer(d_buf); free(x); @@ -7236,6 +8241,24 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { }; const size_t num_it = 100; + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true); + + abort(); + for (size_t i = 0; i < vals.size(); i += 3) { ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); @@ -7310,11 +8333,11 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence); +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); // Returns true if node has enqueued work into the queue, false otherwise // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. 
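// Roughly, the caller batches nodes and flips `submit` once enough work has accumulated, so that
// command-buffer recording overlaps GPU execution. An illustrative sketch of that driving loop
// (the real version, including the matmul-byte heuristic, is in ggml_backend_vk_graph_compute
// later in this patch; `enough_work_accumulated` is a stand-in for its node/byte counters):
//
//   for (int i = 0; i < cgraph->n_nodes; i++) {
//       if (first_node_in_batch) submit_node_idx = i;
//       bool submit = enough_work_accumulated || i == last_node;
//       ggml_vk_build_graph(ctx, cgraph->nodes[i], i,
//                           cgraph->nodes[submit_node_idx], submit_node_idx,
//                           /*dryrun=*/false, i == last_node, almost_ready, submit);
//       if (submit) first_node_in_batch = true;   // start a new batch
//   }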
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){ +static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ if (ggml_is_empty(node) || !node->buffer) { return false; } @@ -7372,6 +8395,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -7387,7 +8411,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_OPT_STEP_ADAMW: @@ -7434,6 +8460,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_UNARY: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -7448,6 +8475,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: { // These operations all go through ggml_vk_op_f32, so short-circuit and @@ -7551,6 +8579,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RMS_NORM_BACK: ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_L2_NORM: + ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { @@ -7617,6 +8649,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_POOL_2D: ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_2D_DW: + ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_LEAKY_RELU: ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); @@ -7641,6 +8677,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod break; + case GGML_OP_RWKV_WKV7: + ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun); + + break; + case GGML_OP_OPT_STEP_ADAMW: ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); @@ -7674,7 +8715,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ctx->compute_ctx.reset(); - bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false); + bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready); if (!ok) { if (node->op == GGML_OP_UNARY) { std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; @@ -7688,7 +8729,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){ +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool 
almost_ready = false) { ggml_backend_buffer * buf = nullptr; switch (tensor->op) { @@ -7714,6 +8755,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -7732,7 +8774,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: @@ -7789,12 +8833,15 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); + if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) { + ggml_vk_submit(subctx, ctx->almost_ready_fence); + ctx->almost_ready_fence_pending = true; + } else { + ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); + } if (use_fence) { - VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); - - ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_wait_for_fence(ctx); } #ifdef GGML_VULKAN_CHECK_RESULTS ggml_vk_check_results_1(tensor); @@ -7880,6 +8927,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->gc.events.clear(); ctx->device->device.destroyFence(ctx->fence); + ctx->device->device.destroyFence(ctx->almost_ready_fence); } static int ggml_vk_get_device_count() { @@ -7922,11 +8970,12 @@ static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { UNUSED(buffer); } -static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); } + return GGML_STATUS_SUCCESS; } static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { @@ -8225,8 +9274,7 @@ static void ggml_backend_vk_synchronize(ggml_backend_t backend) { } ggml_vk_submit(transfer_ctx, ctx->fence); - VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences"); - ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_wait_for_fence(ctx); for (auto& cpy : transfer_ctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); @@ -8243,8 +9291,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false); + if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { + total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } } if (ctx->device->need_compiles) { 
ggml_vk_load_shaders(ctx->device); @@ -8265,19 +9317,32 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool first_node_in_batch = true; // true if next node will be first node in a batch int submit_node_idx = 0; // index to first node in a batch - // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution. - // Start with a smaller count to get work submitted right away, and increase it after each submit. - int nodes_per_submit = 20; + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. + // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB + // (and scaled down based on model size, so smaller models submit earlier). + // Also submit at least every 100 nodes, in case there are workloads without as much matmul. + int nodes_per_submit = 100; int submitted_nodes = 0; int submit_count = 0; + uint64_t mul_mat_bytes = 0; + uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u); for (int i = 0; i < cgraph->n_nodes; i++) { if (first_node_in_batch) { submit_node_idx = i; } - bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node); + if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { + mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } - bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) + bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; + bool submit = (submitted_nodes >= nodes_per_submit) || + (mul_mat_bytes >= mul_mat_bytes_per_submit) || + (i == last_node) || + (almost_ready && !ctx->almost_ready_fence_pending); + + bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit); if (enqueued) { ++submitted_nodes; @@ -8289,16 +9354,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg #endif } - if (submit) { + if (submit && enqueued) { first_node_in_batch = true; submitted_nodes = 0; - switch (submit_count) { - case 0: - nodes_per_submit = 50; - break; - default: - nodes_per_submit = 100; - break; + mul_mat_bytes = 0; + if (submit_count < 3) { + mul_mat_bytes_per_submit *= 2; } submit_count++; } @@ -8450,7 +9511,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); default: return false; } @@ -8468,6 +9532,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (src0_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8503,19 +9568,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (a->ne[3] != b->ne[3]) { return false; } - if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) || + if 
(!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) || !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { return false; } + if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) { + // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader. + // So don't support this combination for now. + return false; + } return true; } break; case GGML_OP_FLASH_ATTN_EXT: { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - if (!ggml_vk_get_device(ctx->device)->coopmat2) { - return false; - } + auto device = ggml_vk_get_device(ctx->device); + bool coopmat2 = device->coopmat2; switch (op->src[0]->ne[0]) { case 64: case 80: @@ -8527,6 +9596,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } if (op->src[0]->type != GGML_TYPE_F32) { return false; } @@ -8544,10 +9617,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (op->src[1]->type) { case GGML_TYPE_F16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + // supported in scalar and coopmat2 paths + break; case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently //case GGML_TYPE_Q2_K: //case GGML_TYPE_Q3_K: @@ -8563,10 +9638,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_IQ3_S: //case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + // currently supported only in coopmat2 path + if (!coopmat2) { + return false; + } break; default: return false; } + if (!coopmat2 && !device->subgroup_shuffle) { + // scalar FA uses subgroupShuffle + return false; + } return true; } case GGML_OP_GET_ROWS: @@ -8574,6 +9657,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8604,6 +9688,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (src1_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8617,6 +9702,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } if (src1_type == GGML_TYPE_F32) { switch (src0_type) { + case GGML_TYPE_F16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8645,25 +9731,31 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: + case GGML_OP_RMS_NORM: return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: - case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: - case GGML_OP_ACC: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - case GGML_OP_CONCAT: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SILU_BACK: case 
GGML_OP_RMS_NORM_BACK: - case GGML_OP_UPSCALE: - case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_UPSCALE: + return op->op_params[0] == GGML_SCALE_MODE_NEAREST; + case GGML_OP_ACC: + case GGML_OP_CONCAT: + case GGML_OP_SCALE: case GGML_OP_PAD: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -8675,8 +9767,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_2D_DW: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_OPT_STEP_ADAMW: return true; @@ -8823,7 +9917,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve UNUSED(instance_extensions); } -static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) { +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { switch (props.vendorID) { case VK_VENDOR_ID_INTEL: // Intel drivers don't support coopmat properly yet @@ -8831,10 +9925,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope case VK_VENDOR_ID_AMD: if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { // Workaround for AMD proprietary driver reporting support on all GPUs - const std::string name = props.deviceName; - return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs - name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs - name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs + return arch == vk_device_architecture::AMD_RDNA3; } return true; default: @@ -9017,7 +10108,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { - const float *params = (const float *)tensor->op_params; + const float * params = (const float *)tensor->op_params; tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); } else if (tensor->op == GGML_OP_MUL_MAT) { tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); @@ -9032,9 +10123,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_CONCAT) { tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]); } else if (tensor->op == GGML_OP_SCALE) { - tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]); + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SIN) { @@ 
-9042,7 +10134,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_COS) { tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { - tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_PAD) { tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); } else if (tensor->op == GGML_OP_REPEAT) { @@ -9056,7 +10149,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { - tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); + const float * float_params = (const float *)tensor->op_params; + tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { @@ -9064,16 +10158,20 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); } else if (tensor->op == GGML_OP_SILU_BACK) { tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_L2_NORM) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); } else if (tensor->op == GGML_OP_SOFT_MAX) { if (src1 != nullptr) { - tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]); } else { tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); } } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { - tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params); + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]); } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; @@ -9183,6 +10281,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_RWKV_WKV6) { tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + } else if (tensor->op == GGML_OP_RWKV_WKV7) { + tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], + src_clone[4], src_clone[5], src_clone[6]); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = src0->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index 074031087..ad13f69b3 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -1,9 +1,20 @@ -find_package (Threads REQUIRED) -find_program(GLSLC_EXECUTABLE glslc) -if(NOT GLSLC_EXECUTABLE) - message(FATAL_ERROR "glslc not found.") -endif() +cmake_minimum_required(VERSION 3.19) +project("vulkan-shaders-gen" C CXX) +find_package (Threads REQUIRED) + +if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) +endif() +if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +endif() +if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) +endif() +if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) +endif() set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index dd828c232..6567a8c54 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -18,7 +18,11 @@ void main() { // fast path for when all four iterations are in-bounds if (idx + (num_iter-1)*num_threads < p.ne) { [[unroll]] for (uint i = 0; i < num_iter; ++i) { -#ifndef OPTIMIZATION_ERROR_WORKAROUND + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); #else data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; @@ -31,7 +35,10 @@ void main() { continue; } -#ifndef OPTIMIZATION_ERROR_WORKAROUND +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); #else data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp new file mode 100644 index 000000000..938c74da5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp @@ -0,0 +1,105 @@ +#version 450 + +#include "types.comp" + +layout (push_constant) uniform parameter +{ + uint ne; + uint batches; + uint channels; + uint dst_w; + uint dst_h; + uint src_w; + uint src_h; + uint knl_w; + uint knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; +layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE conv_2d_dw_whcn(uint idx) { + uint i0 = idx / p.dst_w; + uint dst_x = idx - i0 * p.dst_w; + uint i1 = i0 / p.dst_h; + uint dst_y = i0 - i1 * p.dst_h; + uint n = i1 / p.channels; + uint c = i1 - n * p.channels; + + uint src_i = n * 
p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w; + uint knl_i = c * p.knl_h * p.knl_w; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]); + sum = fma(v, k, sum); + } + } + return sum; +} + +FLOAT_TYPE conv_2d_dw_cwhn(uint idx) { + uint i0 = idx / p.channels; + uint c = idx - i0 * p.channels; + uint i1 = i0 / p.dst_w; + uint dst_x = i0 - i1 * p.dst_w; + uint n = i1 / p.dst_h; + uint dst_y = i1 - n * p.dst_h; + + uint src_i = n * p.channels * p.src_h * p.src_w; + uint src_row = p.src_w * p.channels; + uint knl_row = p.knl_w * p.channels; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]); + sum = fma(v, k, sum); + } + } + return sum; +} + +void main() { + uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + FLOAT_TYPE result = +#ifdef WHCN + conv_2d_dw_whcn(idx); +#else + conv_2d_dw_cwhn(idx); +#endif + dst_data[idx] = D_TYPE(result); +} + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index 29c906494..f476a2e3d 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -12,7 +12,10 @@ void main() { return; } -#ifndef OPTIMIZATION_ERROR_WORKAROUND +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); #else data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index c813f1404..9c76437d9 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,5 +1,10 @@ #version 450 +#if RTE16 +#extension GL_EXT_spirv_intrinsics : enable +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif // RTE16 + #include "types.comp" #include "generic_unary_head.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 
10318e876..0d9739d40 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_BF16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); +} +#endif + #if defined(DATA_A_Q4_0) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); @@ -82,9 +88,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2]; - uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1]; - return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8)); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy; + return vec4(v0.x, v0.y, v1.x, v1.y); } #endif @@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif -#if defined(DATA_A_F32) || defined(DATA_A_F16) +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 4770469ed..9cb7da2da 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -92,7 +92,7 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2 const uint iqs = idx; // Load 16b and select the byte for this element - int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1]; + int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1]; float16_t ret = float16_t(qs) * d; return ret; } @@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 block_q4_K_packed128 block; }; +#if defined(IS_MUL_MM2) + +// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales +// into shared memory and then process the whole tile using those scales. +// There is a fetch function that loads into private variables and then a store +// function that stores into shared memory. +// Q4_K and Q5_K have the same encoding of scales, so everything is shared except +// the part that fetches from the structure (which has a different block layout). 
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +const uint shAscales_stride = (BM + 2); +// 1 scale per 32 elements -> 8 scales per block, per row +shared vec2 shAscales[8 * shAscales_stride]; +uvec4 row_v; +#endif + +#if defined(DATA_A_Q4_K) +layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];}; + +void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q4_k_packed128[block_index].q4k[0]; + } +} +#endif +#if defined(DATA_A_Q5_K) +layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];}; + +void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q5_k_packed128[block_index].q5k[0]; + } +} +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +void store_scalesQ4_K(uint tid) +{ + barrier(); + + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) { + uint is = idx + is_start; + uvec4 v = row_v; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? 
mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); + shAscales[is * shAscales_stride + tid_row] = vec2(d,m); + } + + barrier(); +} +#endif + +#endif + float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); @@ -176,9 +271,13 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 +#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else uvec4 v = bl128.block.q4k[0]; - - const f16vec2 loadd = unpackFloat2x16(v.x); + const vec2 loadd = vec2(unpackFloat2x16(v.x)); uint32_t sc; uint32_t mbyte; @@ -199,15 +298,16 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 sc &= 0x3F; mbyte &= 0x3F; - const float16_t d = loadd.x * float16_t(sc); - const float16_t m = loadd.y * float16_t(mbyte); + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); +#endif uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; - float16_t ret = d * float16_t(qs) - m; + float ret = d * float(qs) - m; - return ret; + return float16_t(ret); } layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { @@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 +#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else uvec4 v = bl128.block.q5k[0]; const f16vec2 loadd = unpackFloat2x16(v.x); @@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 const float16_t d = loadd.x * float16_t(sc); const float16_t m = loadd.y * float16_t(mbyte); +#endif uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); qh = ((qh >> is) & 0x101) << 4; @@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 qs = (qs >> (b * 4)) & 0x0F0F; qs = unpack8(qs | qh)[idx & 1]; - float16_t ret = d * (float16_t(qs)) - m; + float ret = d * float(qs) - m; - return ret; + return float16_t(ret); } layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { @@ -311,8 +417,8 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords const float16_t d = bl.block.d; const uint idx = coordInBlock[1]; - const uint ib32 = idx / 32; - const uint ib8 = idx / 8; + const uint ib32 = (idx & 0xE0) >> 5; + const uint ib8 = (idx & 0xF8) >> 3; const uint qh = bl.block.qh[ib32]; const uint qs = bl.block.qs[ib8]; @@ -330,14 +436,20 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1 block_iq1_m block; }; +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 { + block_iq1_m_packed64 block; +}; + float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { - const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12; - const 
float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12)); + decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl); const uint idx = coordInBlock[1]; - const uint ib8 = idx / 8; - const uint ib16 = idx / 16; + uvec2 scales = unpack32(bl64.block.scales); + const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16))); + + const uint ib8 = (idx & 0xF8) >> 3; + const uint ib16 = (idx & 0xF0) >> 4; const int i8 = int(idx % 8); const uint sc = bl.block.scales[ib8 / 8]; const uint qs = bl.block.qs[ib8]; @@ -370,7 +482,7 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo const uint ib8 = (idx & 0x18) >> 3; // 0..3 const uint iqs = 8 * ib32 + ib8; - const uint8_t qs = bl.block.qs[iqs]; + const uint qs = bl.block.qs[iqs]; const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); @@ -558,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #define dequantFuncA dequantFuncQ3_K #elif defined(DATA_A_Q4_K) #define dequantFuncA dequantFuncQ4_K +#define fetch_scales fetch_scalesQ4_K +#define store_scales store_scalesQ4_K #elif defined(DATA_A_Q5_K) #define dequantFuncA dequantFuncQ5_K +#define fetch_scales fetch_scalesQ5_K +#define store_scales store_scalesQ4_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K #elif defined(DATA_A_IQ1_S) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp new file mode 100644 index 000000000..e6545160d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -0,0 +1,483 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; + +layout (constant_id = 5) const uint32_t D_split = 16; +const uint32_t D_per_thread = D / D_split; + +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout 
(binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared vec4 tmpshv4[gl_WorkGroupSize.x]; + +shared float masksh[Bc][Br]; +shared vec4 Qf[Br][D / 4]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t N = p.N; + const uint32_t KV = p.KV; + + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + + uint32_t i = gl_WorkGroupID.x; + uint32_t split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + const uint32_t Tr = CEIL_DIV(N, Br); + + const uint32_t start_j = split_k_index * p.split_kv / Bc; + const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; + const uint32_t iq3 = gl_WorkGroupID.z; + + // broadcast factors + const uint32_t rk2 = p.neq2/p.nek2; + const uint32_t rk3 = p.neq3/p.nek3; + + const uint32_t rv2 = p.neq2/p.nev2; + const uint32_t rv3 = p.neq3/p.nev3; + + // k indices + const uint32_t ik3 = iq3 / rk3; + const uint32_t ik2 = iq2 / rk2; + + // v indices + const uint32_t iv3 = iq3 / rv3; + const uint32_t iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + uint32_t k_stride = p.nb11; + uint32_t v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + } + } + barrier(); + + vec4 Of[Br][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = vec4(0.0); + } + } + + float Lf[Br], Mf[Br]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. 
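    // (0xFEFFFFFF is the IEEE-754 bit pattern of -FLT_MAX/2, roughly -1.7e38: sign bit set,
    //  exponent 0xFD, mantissa all ones. With this initial value an empty row gives
    //  Mold - M == 0 and exp(0) == 1, whereas initializing with -inf would produce
    //  (-inf) - (-inf) == NaN.)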
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + float slope[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = 1.0; + } + + // ALiBi + if (p.max_bias > 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + float Sf[Br][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = 0.0; + } + } + + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + } + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + // Compute sum across the D_split + [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); + } + } + } + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + } + } + } + + if (p.mask != 0) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + } + } + barrier(); + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float mvf = masksh[c * cols_per_iter + col_tid][r]; + + Sf[r][c] += slope[r]*mvf; + } + } + barrier(); + } + + float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + rowmaxf[r] = Sf[r][0]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + } + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Pf[r][c] = exp(Sf[r][c] - Mf[r]); + } + eMf[r] = exp(Moldf[r] - Mf[r]); + + // Compute sum across row of P + rowsumf[r] = 0.0; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowsumf[r] += Pf[r][c]; + } + + Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = eMf[r] * Of[r][d]; + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { 
+#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] += Pf[r][c] * Vf; + } + } + } + + barrier(); + } + + // reduce across threads + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float rowmaxf, eMf; + + tmpsh[tid] = Mf[r]; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + } + barrier(); + } + rowmaxf = tmpsh[d_tid]; + barrier(); + + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf = exp(Moldf - Mf[r]); + + Lf[r] = eMf*Lf[r]; + + tmpsh[tid] = Lf[r]; + + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + } + barrier(); + } + Lf[r] = tmpsh[d_tid]; + barrier(); + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = eMf * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + Of[r][d] += tmpshv4[tid + s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[d_tid]; + barrier(); + } + } + + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. 
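As the comment above says, with split_k each workgroup stores its unnormalized O together with the per-row m and L values, and a separate pass combines them. A rough host-side C sketch of that combination for a single output element (the function name is hypothetical; the real work is done per row by the new flash_attn_split_k_reduce.comp further down in this patch):

/* Combine k_num split_k partials: each split k leaves an unnormalized
 * output O[k] plus its per-row softmax sum L[k] and max m[k]. They are
 * rescaled against the global row max before the single division by L. */
#include <math.h>

float resolve_split_k(const float *O, const float *L, const float *m, int k_num) {
    float m_max = -INFINITY;
    for (int k = 0; k < k_num; ++k) m_max = fmaxf(m_max, m[k]);

    float L_total = 0.0f, acc = 0.0f;
    for (int k = 0; k < k_num; ++k) {
        float w = expf(m[k] - m_max);   /* rescale each split to the common max */
        L_total += w * L[k];
        acc     += w * O[k];
    }
    return acc / L_total;               /* the deferred division by L */
}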
+ if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] *= Lfrcp[r]; + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (i * Br + r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index df30355f6..b926a578a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -61,6 +61,10 @@ layout (push_constant) uniform parameter { uint32_t n_head_log2; float m0; float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; } p; layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; @@ -103,6 +107,38 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele #define DECODEFUNC #endif +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c < D) { + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -111,12 +147,22 @@ void main() { const uint32_t N = p.N; const uint32_t KV = p.KV; + uint32_t i = gl_WorkGroupID.x; + uint32_t split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + const uint32_t Tr = CEIL_DIV(N, Br); - const uint32_t Tc = CEIL_DIV(KV, Bc); - const uint32_t i = gl_WorkGroupID.x; + const uint32_t start_j = split_k_index * p.split_kv / Bc; + const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - const uint32_t iq2 = gl_WorkGroupID.y; + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; const uint32_t iq3 = gl_WorkGroupID.z; // broadcast factors @@ -149,10 +195,17 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); - // nb?1 are already divided by the type size and are in units of elements - uint32_t q_stride = p.nb01; + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; uint32_t k_stride = p.nb11; uint32_t v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { @@ -161,6 +214,7 @@ void main() { k_stride &= ~7; v_stride &= ~7; #endif + m_stride &= ~7; } tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); @@ -179,23 +233,21 @@ void main() { coopmat L, M; - L = coopmat(0); - M = coopmat(-1.0/0.0); + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); - ACC_TYPE slope = ACC_TYPE(1.0); + L = coopmat(0); + M = coopmat(NEG_FLT_MAX_OVER_2); + + coopmat slopeMat = coopmat(1.0); // ALiBi if (p.max_bias > 0.0f) { - const uint32_t h = iq2; - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); - - slope = pow(base, ACC_TYPE(exph)); + coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } [[dont_unroll]] - for (uint32_t j = 0; j < Tc; ++j) { + for (uint32_t j = start_j; j < end_j; ++j) { coopmat S = coopmat(0); @@ -213,14 +265,15 @@ void main() { } if (p.mask != 0) { - tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); coopmat mv; coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - S += slope*coopmat(mv); + S += slopeMat*coopmat(mv); } // Clear padding elements to -inf, so they don't contribute to rowmax @@ -231,7 +284,7 @@ void main() { uint R = ((i + 1) * Br > N) ? (N % Br) : Br; uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; - coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C); } coopmat rowmax, P, rowsum, eM; @@ -280,9 +333,25 @@ void main() { // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); - O = eMdiag * O; + // multiply with fp16 accumulation, then add to O. + coopmat PV = coopmat(0); + PV = coopMatMulAdd(P_A, V, PV); - O = coopMatMulAdd(P_A, V, O); + O = eMdiag * O + coopmat(PV); + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + coopmat O_D = coopmat(O); + + uint32_t o_offset = D * p.ne1 * split_k_index; + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + return; } coopmat Ldiag; @@ -297,13 +366,18 @@ void main() { O = Ldiag*O; - tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); - - // permute dimensions - tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); uint32_t o_offset = iq3*p.ne2*p.ne1; coopmat O_D = coopmat(O); - coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); + if (p.gqa_ratio > 1) { + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + } else { + tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + + // permute dimensions + tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); + + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute); + } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp new file mode 100644 index 000000000..a7e395685 --- /dev/null +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -0,0 +1,59 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#define BLOCK_SIZE 32 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; + +layout (push_constant) uniform parameter { + uint D; + uint N; + uint k_num; +} p; + +void main() { + // Each workgroup handles a row + const uint n = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + uint D = p.D; + uint N = p.N; + uint k_num = p.k_num; + + uint l_offset = D * N * k_num + n; + uint m_offset = D * N * k_num + N + n; + uint lm_stride = N * 2; + + // Compute the max m value for the row + float m_max = -1.0/0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + float m = data_a[m_offset + k * lm_stride]; + m_max = max(m_max, m); + } + + // Compute L based on m_max + float L = 0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + float l = data_a[l_offset + k * lm_stride]; + float m = data_a[m_offset + k * lm_stride]; + L += exp(m - m_max) * l; + } + + L = 1.0 / L; + + // Scale and sum the O contributions based on m_max and store the result to memory + for (uint d = tid; d < D; d += BLOCK_SIZE) { + float O = 0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + uint o_offset = D * N * k + D * n + d; + float m = data_a[m_offset + k * lm_stride]; + O += exp(m - m_max) * data_a[o_offset]; + } + O *= L; + data_d[D * n + d] = O; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index e877ed779..ee6b86a18 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -20,9 +20,14 @@ void main() { const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; -#ifndef OPTIMIZATION_ERROR_WORKAROUND - data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]); +#if defined(DATA_A_BF16) + FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); #else - data_d[d_offset + i00] = data_a[a_offset + i00]; + FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); +#endif +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[d_offset + i00] = D_TYPE(v); +#else + data_d[d_offset + i00] = D_TYPE(v); #endif } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index c9f855687..cfd645a38 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -1,5 +1,7 @@ #version 450 +#extension GL_EXT_control_flow_attributes : enable + #include "types.comp" #include "generic_binary_head.comp" #include "dequant_funcs.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 122b1e93f..09aa849e8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -40,6 +40,20 @@ void main() { const uint batch = gl_GlobalInvocationID.z / p.IC; const uint ic = gl_GlobalInvocationID.z % p.IC; + const uint src_base = ic * p.offset_delta + batch * 
p.batch_offset; + const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH); + const int oh_s1 = int(oh) * p.s1; + const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); + + const uint base_linear_idx = gidx * NUM_ITER; + + const uint max_ky = ksize / p.OW; + + uint current_kx = base_linear_idx / ksize; + const uint rem = base_linear_idx - (current_kx * ksize); + uint current_ky = rem / p.OW; + uint current_ix = rem % p.OW; + A_TYPE values[NUM_ITER]; uint offset_dst[NUM_ITER]; [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { @@ -48,36 +62,35 @@ void main() { [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - const uint i = gidx * NUM_ITER + idx; + const uint linear_idx = base_linear_idx + idx; - const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); - const uint kx = i / ksize; - const uint kd = kx * ksize; - const uint ky = (i - kd) / p.OW; - const uint ix = i % p.OW; - - const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; - const uint iih = oh * p.s1 + ky * p.d1 - p.p1; - - offset_dst[idx] = - ((batch * p.OH + oh) * p.OW + ix) * p.CHW + - (ic * (p.KW * p.KH) + ky * p.KW + kx); - - if (i >= p.pelements) { + if (linear_idx >= p.pelements) { continue; } - if (iih < p.IH && iiw < p.IW) { - const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; - values[idx] = data_a[offset_src + iih * p.IW + iiw]; + const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; + const uint iih = oh_s1 + current_ky * p.d1 - p.p1; + + offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx; + + if ((iih < p.IH) && (iiw < p.IW)) { + values[idx] = data_a[src_base + iih * p.IW + iiw]; + } + + if (++current_ix == p.OW) { + current_ix = 0; + if (++current_ky == max_ky) { + current_ky = 0; + current_kx++; + } } } [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - const uint i = gidx * NUM_ITER + idx; + const uint linear_idx = base_linear_idx + idx; - if (i >= p.pelements) { + if (linear_idx >= p.pelements) { continue; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp new file mode 100644 index 000000000..deba8c398 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 31ecd9f81..bb429dd59 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -6,7 +6,7 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) #define K_PER_ITER 8 #else #define K_PER_ITER 2 @@ -105,6 +105,16 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { int unroll_count = 4; uint unrolled_iters = num_iters & ~(unroll_count - 1); +#if K_PER_ITER == 2 + // If the K dimension is odd, we need lastiter==true on the last iteration + // so OOB is computed correctly. Skip some unrolling to make that happen. + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + uint i = 0; while (i < unrolled_iters) { // Manually partially unroll the loop @@ -113,8 +123,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } + unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp new file mode 100644 index 000000000..8d01536fa --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + const uint qh = data_a[ibi].qh[ib32]; + const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147 + const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy; + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint8_t sign = sign16[l]; + const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300); + const uvec2 grid = iq2s_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? 
-grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp new file mode 100644 index 000000000..c49604324 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[2 * itid + l]; + const uint sign = qs >> 9; + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x)); + const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? 
-grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp new file mode 100644 index 000000000..94d4b92e1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[4 * ib32 + 2], + data_a_packed16[ibi].qs[4 * ib32 + 3])); + const float db = d * 0.25 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x)); + const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = 
vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp new file mode 100644 index 000000000..f021e4047 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const float dscale = d * (1 + 2 * scale); + const uint qh = data_a[ibi].qh[ib32]; + FLOAT_TYPE sum[NUM_COLS]; + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + sum[j] = 0.0; + } + [[unroll]] for (uint l = 0; l < 4; ++l) { + const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147 + const uint sign = data_a[ibi].signs[4 * ib32 + l]; + const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)])); + const vec4 grid1 = 
vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + sum[j] = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + sum[j])))))))); + } + } + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + temp[j][n] = fma(dscale, sum[j], temp[j][n]); + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp new file mode 100644 index 000000000..3fe9dc3a4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -0,0 +1,88 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32], + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32 + 1])); + const float db = d * 0.5 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs0 = data_a[ibi].qs[8 * 
ib32 + 4 * (itid & 1) + 2 * l]; + const uint qs1 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l + 1]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq3xxs_grid[qs0])); + const vec4 grid1 = vec4(unpack8(iq3xxs_grid[qs1])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp index 1cc4996d3..bc633369f 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -12,13 +12,18 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + layout (push_constant) uniform parameter { uint ncols_x; uint nrows_x; uint row_stride_x; uint channel_stride_x; + uint channel_stride_y; uint channel_x_divisor; + uint ne12; uint b_offset; uint d_offset; } p; @@ -30,6 +35,7 @@ void main() { const uint row_x = gl_GlobalInvocationID.y; const uint channel = gl_GlobalInvocationID.z; const uint channel_x 
= channel / p.channel_x_divisor; + const uint channel_y = channel % p.ne12; const uint nrows_y = p.ncols_x; const uint nrows_dst = p.nrows_x; @@ -37,25 +43,66 @@ void main() { const uint idst = channel*nrows_dst + row_dst; - tmp[tid] = 0.0f; + FLOAT_TYPE temp = 0.0f; - for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { - const uint col_x = col_x0 + tid; + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0; - if (col_x >= p.ncols_x) { - break; + for (uint col_x0 = 0; col_x0 < p.ncols_x;) { + + // Unroll 2x and do vec4 loads if aligned + const uint unroll_count = 2; + if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) { + [[unroll]] for (uint i = 0; i < unroll_count; ++i) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel_y*p.channel_stride_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } + // do vec4 loads if aligned + } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel_y*p.channel_stride_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + if (col_x >= p.ncols_x) { + break; + } + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel_y*p.channel_stride_y + row_y; + + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp); + col_x0 += BLOCK_SIZE; } - - const uint row_y = col_x; - - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel*nrows_y + row_y; - - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); - - tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); } + tmp[tid] = temp; + // sum up partial sums and write back result barrier(); [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp index 9b443807d..7aa070eeb 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp @@ -2,16 +2,25 @@ #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif -#define BLOCK_SIZE 32 #define FLOAT_TYPE float -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout(constant_id = 0) const int BLOCK_SIZE = 32; +// gqa_ratio is in the range [1,8] 
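The gqa_ratio specialization constant declared on the next line lets one workgroup reuse a single row of the (transposed, permuted) A matrix for gqa_ratio consecutive rows of B, covering the grouped-query-attention case where several query heads share one KV head. A rough host-side C sketch of that access pattern (names are illustrative, not the shader's):

/* Each A element is loaded once and reused across gqa_ratio dot products. */
#include <stddef.h>

void mat_vec_gqa(const float *a_row, const float *b, size_t ncols,
                 size_t b_stride, int gqa_ratio, float *out /* [gqa_ratio] */) {
    for (int c = 0; c < gqa_ratio; ++c) out[c] = 0.0f;
    for (size_t i = 0; i < ncols; ++i) {
        float xi = a_row[i];                            /* loaded once */
        for (int c = 0; c < gqa_ratio; ++c)
            out[c] += xi * b[(size_t)c * b_stride + i]; /* reused gqa_ratio times */
    }
}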
+layout(constant_id = 1) const uint gqa_ratio = 1; + layout (push_constant) uniform parameter { uint ncols_x; @@ -22,52 +31,124 @@ layout (push_constant) uniform parameter uint d_offset; } p; -shared FLOAT_TYPE tmp[BLOCK_SIZE]; +#if !USE_SUBGROUP_ADD +shared FLOAT_TYPE tmp[8][BLOCK_SIZE]; +#endif void main() { const uint tid = gl_LocalInvocationID.x; const uint row_x = gl_GlobalInvocationID.y; - const uint channel = gl_GlobalInvocationID.z; - const uint channel_x = channel / (p.nchannels_y / p.nchannels_x); + + uint channel, channel_x; + + // When gqa_ratio > 1, each invocation does multiple rows. + // The row in the A matrix is starting from channel / gqa_ratio and the + // rows in the B matrix are [channel, channel+gqa_ratio). + // When gpa_ratio is 1, each invocation does one row. + if (gqa_ratio > 1) { + channel_x = gl_GlobalInvocationID.z; + channel = channel_x * gqa_ratio; + } else { + channel = gl_GlobalInvocationID.z; + channel_x = channel / (p.nchannels_y / p.nchannels_x);; + } const uint nrows_y = p.ncols_x; const uint nrows_dst = p.nrows_x; const uint row_dst = row_x; - tmp[tid] = FLOAT_TYPE(0.0f); - - for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { - const uint col_x = col_x0 + tid; - - if (col_x >= p.ncols_x) { - break; - } - - // x is transposed and permuted - const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); - - const uint row_y = col_x; - - // y is not transposed but permuted - const uint iy = channel*nrows_y + row_y; - - tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); + FLOAT_TYPE temp[8]; + [[unroll]] for (uint i = 0; i < 8; ++i) { + temp[i] = FLOAT_TYPE(0.0f); } - // dst is not transposed and not permuted - const uint idst = channel*nrows_dst + row_dst; + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0; + for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { + + // Use vec4 loads if aligned + if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + + uint col_x = col_x0 + 4*tid; + const uint row_y = col_x; + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const vec4 av4 = vec4(data_a_v4[ix / 4]); + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + vec4 bv4 = data_b_v4[iy / 4]; + temp[c] += dot(av4, bv4); + } + + col_x0 += 3*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + const uint row_y = col_x; + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]); + } + } + } + +#if USE_SUBGROUP_ADD + // reduce vec4 at a time + vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]); + t = subgroupAdd(t); + temp[0] = t[0]; + temp[1] = t[1]; + temp[2] = t[2]; + temp[3] = t[3]; + if (gqa_ratio > 4) { + t = vec4(temp[4], temp[5], temp[6], temp[7]); + t = subgroupAdd(t); + temp[4] = t[0]; + temp[5] = t[1]; + temp[6] = t[2]; + temp[7] = t[3]; + } +#else + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + tmp[c][tid] = temp[c]; + } // sum up partial sums and write back result barrier(); 
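When subgroup arithmetic is unavailable, the partial sums written to tmp just above are combined by the shared-memory tree reduction that follows. A minimal serial C sketch of that pattern, with the barriers between steps modeled away:

/* Halve the active range each step; tmp[0] ends up holding the full sum. */
void tree_reduce(float *tmp, int block_size) {
    for (int s = block_size / 2; s > 0; s >>= 1) {
        for (int tid = 0; tid < s; ++tid)
            tmp[tid] += tmp[tid + s];
        /* in the shader, barrier() separates the steps */
    }
}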
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { if (tid < s) { - tmp[tid] += tmp[tid + s]; + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] += tmp[c][tid + s]; + tmp[c][tid] = temp[c]; + } } barrier(); } + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] = tmp[c][tid]; + } +#endif if (tid == 0) { - dst[idst] = tmp[0]; + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // dst is not transposed and not permuted + const uint idst = (channel + c)*nrows_dst + row_dst; + dst[idst] = temp[c]; + } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 8cdc640e8..423ceb8a3 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -5,23 +5,24 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16]; -shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16]; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; - barrier(); if (!all_threads) { // when we don't have enough blocks to use all threads if (i < num_blocks_per_row) { const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); - sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); - sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); } barrier(); @@ -29,8 +30,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, continue; } else { const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); - sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); - sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); barrier(); } @@ -57,22 +58,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ], - fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2], - fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ], - fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2], - fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ], - fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2], - fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ], - fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); - sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im], - fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im], - fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im], - fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im], - fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im], - 
fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im], - fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im], - fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2)))))))); + sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im], + fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im], + fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im], + fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im], + fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im], + fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im], + fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2)))))))); } temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 3116fad16..e91724a28 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -5,20 +5,21 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8]; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; if (!all_threads) { // when we don't have enough blocks to use all threads - barrier(); if (i < num_blocks_per_row) - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); barrier(); if (i >= num_blocks_per_row) @@ -40,8 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); if (all_threads) { - barrier(); - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); barrier(); } @@ -59,14 +59,14 @@ void calc_superblock(const uint a_offset, const uint b_offset, 
const uint ix, co FLOAT_TYPE sum = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], - fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], - fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], - fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], - fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], - fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], - fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], - fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); } temp[j][n] = fma(d, sum, temp[j][n]); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index f05f96b5e..d53d9ee0a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -6,20 +6,21 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16]; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; if (!all_threads) { // when we don't have enough blocks to use all threads - barrier(); if (i < num_blocks_per_row) - sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); barrier(); if (i >= num_blocks_per_row) @@ -51,8 +52,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const vec4 q3 = vec4(unpack8(q3_u32)) - 32; if (all_threads) { - barrier(); - sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); barrier(); } @@ -71,7 +71,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); } - temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); + temp[j][n] = fma(fma(sum[0], 
sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]); } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 39657195c..7859a1a60 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -10,6 +10,10 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif +#if defined(DATA_A_BF16) && defined(COOPMAT) +#extension GL_EXT_bfloat16 : enable +#endif + #ifdef COOPMAT #extension GL_KHR_cooperative_matrix : enable #extension GL_KHR_memory_scope_semantics : enable @@ -29,9 +33,20 @@ #define LOAD_VEC_B 1 #endif +#if !defined(TO_FLOAT_TYPE) +#define TO_FLOAT_TYPE FLOAT_TYPE +#endif + layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; @@ -88,7 +103,7 @@ shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID -shared u16vec2 row_ids[3072]; +shared u16vec2 row_ids[4096]; #endif // MUL_MAT_ID #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -195,8 +210,8 @@ void main() { #endif #ifdef COOPMAT - coopmat cache_a; - coopmat cache_b; + coopmat cache_a; + coopmat cache_b; coopmat sums[cms_per_row * cms_per_col]; [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { @@ -205,7 +220,7 @@ void main() { #else ACC_TYPE sums[WMITER * TM * WNITER * TN]; FLOAT_TYPE cache_a[WMITER * TM]; - FLOAT_TYPE cache_b[WNITER * TN]; + FLOAT_TYPE cache_b[TN]; [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { sums[i] = ACC_TYPE(0.0f); @@ -241,76 +256,117 @@ void main() { buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); } #endif +#elif defined(DATA_A_BF16) +#if LOAD_VEC_A == 4 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x); + buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y); + buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z); + buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w); +#else + if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); + } else { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0)); + } +#endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 4; + const uint iqs = idx & 0x03; - const float d = float(data_a[ib].d); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; + const float d = float(data_a_packed16[ib].d); + const uint vui = 
uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; + const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE(v0.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); + buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); + buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); + buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); + buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); + buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); + buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 4; + const uint iqs = idx & 0x03; - const float d = float(data_a[ib].d); - const float m = float(data_a[ib].m); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE(v0.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); + buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); + buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); + buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); + buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); + buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); + buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; + const float d = float(data_a_packed16[ib].d); + const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint 
ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const float m = float(data_a[ib].m); - const uint uint_qh = data_a[ib].qh; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint uint_qh = data_a_packed16[ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 16; - const uint iqs = (idx & 0xF) * 2; + const uint ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; + const float d = float(data_a_packed16[ib].d); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; + const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -511,7 +567,7 @@ void main() { const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -531,7 +587,7 @@ void main() { const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -553,7 +609,7 @@ void main() { const float db = d * 0.25 * (0.5 + scale); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -578,7 +634,7 @@ void main() 
{ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -598,7 +654,7 @@ void main() { const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -623,17 +679,18 @@ void main() { buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; + const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[iqs]); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d; + buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; + buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; + buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; #endif } [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { @@ -661,13 +718,13 @@ void main() { const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); + buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x); + buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y); + buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z); + buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w); #elif !MUL_MAT_ID if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); } else { buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); } @@ -675,7 +732,7 @@ void main() { const uint row_i = ic * BN + loadc_b + l; if (row_i < _ne1) { const u16vec2 row_idx = row_ids[row_i]; - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); } else { buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = 
FLOAT_TYPE(0.0f); } @@ -710,16 +767,14 @@ void main() { } [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint j = 0; j < TN; j++) { - cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; + cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; } - } - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) { const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]); + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]); } } } @@ -743,7 +798,7 @@ void main() { [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - [[unroll]] for (uint col = 0; col < BN; col += storestride) { + [[unroll]] for (uint col = 0; col < TN; col += storestride) { const uint row_i = dc + cm_col * TN + col + store_c; if (row_i >= _ne1) break; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 66dd2c860..918465757 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -14,15 +14,25 @@ #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable +#ifdef DATA_A_BF16 +#extension GL_EXT_bfloat16 : enable +#endif #include "types.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +#define IS_MUL_MM2 1 + +layout (constant_id = 0) const uint BLOCK_SIZE = 256; layout (constant_id = 1) const uint BM = 64; layout (constant_id = 2) const uint BN = 64; layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +layout (constant_id = 4) const bool enable_smaller_matrices = false; +const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN; +const uint BNover4 = enable_smaller_matrices ? 
(BN / 4) : BN; + layout (push_constant) uniform parameter { uint M; @@ -48,6 +58,8 @@ layout (push_constant) uniform parameter uint broadcast2; uint broadcast3; #endif + // N dimension for the B matrix can be >= p.N + uint padded_N; } p; @@ -64,10 +76,23 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #define DECODEFUNCA #endif +#if !defined(fetch_scales) +#define fetch_scales(a, b, c, d, e, f) +#endif +#if !defined(store_scales) +#define store_scales(a) +#endif + +#if defined(DATA_A_BF16) +#define MAT_TYPE bfloat16_t +#else +#define MAT_TYPE FLOAT_TYPE +#endif + #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; -shared u16vec4 row_ids[3072]; +shared u16vec4 row_ids[4096]; layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { B_TYPE b[]; @@ -110,6 +135,8 @@ void main() { init_iq_shmem(gl_WorkGroupSize); #endif + const uint tid = gl_LocalInvocationIndex; + #ifdef MUL_MAT_ID const uint expert_idx = gl_GlobalInvocationID.z; #else @@ -166,15 +193,13 @@ void main() { const uint end_k = min(p.K, (ik + 1) * p.k_split); #endif - coopmat sum; - sum = coopmat(0.0); - #ifdef MUL_MAT_ID uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; uint pos_b = 0; #else uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; uint pos_b = batch_idx * p.batch_stride_b; + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif uint stride_a = p.stride_a / QUANT_K; @@ -195,6 +220,7 @@ void main() { tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); #if QUANT_K > 1 tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); @@ -202,24 +228,32 @@ void main() { #endif // Use end_k rather than p.K as the dimension because that's what - // we need to bound check against when using split_k + // we need to bound check against when using split_k. + // Bounds check B against padded_N, but bounds check D against N. tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); - tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k); + tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k); tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); - tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k); + tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k); tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); #if !defined(MUL_MAT_ID) + + const uint START_ALIGN_K = 256; + // For Qi_K (block size 256), unroll whole 256 element tiles. + // For legacy quants (block size 32), unroll 8x. + const uint UNROLL_K = (QUANT_K == 256) ? 
256 : (BK * 8); + const uint unroll_count = UNROLL_K / BK; + // Detect a fast path where all loads are entirely in bounds and no clamping is required - if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 && + if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 && #if QUANT_K == 1 (stride_a % 8) == 0 && #endif - (stride_b % 8) == 0 && (start_k % 8) == 0) { + (stride_b % 8) == 0) { // Hint to the compiler that values are aligned (want 16B alignment) - start_k &= ~7; + start_k &= ~(START_ALIGN_K-1); stride_b &= ~7; #if QUANT_K == 1 stride_a &= ~7; @@ -228,17 +262,131 @@ void main() { tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); - uint k_iters = (end_k - start_k + BK - 1) / BK; + uint k_iters = (end_k - start_k) / UNROLL_K; + uint block_k = start_k; - for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + // fetch scale values for a tile of quants. These will be copied into shared memory. + // The fetches and stores are pipelined to hide the latency. + fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true); - coopmat mat_a; - coopmat mat_b; + if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint i = 0; i < k_iters; ++i) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } - sum = coopMatMulAdd(mat_a, mat_b, sum); + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); + return; + } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), 
tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); + return; + } else { + coopmat sum = coopmat(0.0); + + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); + return; } } else #endif // !defined(MUL_MAT_ID) @@ -251,61 +399,43 @@ void main() { tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); + coopmat sum; + sum = coopmat(0.0); + + uint k_iters = (end_k - start_k + BK - 1) / BK; + + fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false); + [[dont_unroll]] - for (uint block_k = start_k; block_k < end_k; block_k += BK) { + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { - coopmat mat_a; - coopmat mat_b; - - // Clamping is expensive, so detect different code paths for each combination - // of A and B needing clamping. 
- bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0; -#ifdef MUL_MAT_ID - bool unclampedB = true; -#else - bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0; -#endif - if (unclampedA && unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); -#ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); -#else - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); -#endif - sum = coopMatMulAdd(mat_a, mat_b, sum); - } else if (unclampedA && !unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - - sum = coopMatMulAdd(mat_a, mat_b, sum); - } else if (!unclampedA && unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); -#ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); -#else - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); -#endif - sum = coopMatMulAdd(mat_a, mat_b, sum); - } else if (!unclampedA && !unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - - sum = coopMatMulAdd(mat_a, mat_b, sum); + store_scales(tid); + if (block_k + BK < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); } - } - } - // Convert from ACC_TYPE to D_TYPE - coopmat mat_d; - mat_d = coopmat(sum); + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); #ifdef MUL_MAT_ID - // Call callback to store each element, remapping row through shared memory - coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); #else - tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); - - uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; - coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); #endif + } } diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp new file mode 100644 index 000000000..83de90eb7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -0,0 +1,442 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#extension GL_EXT_integer_dot_product : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif +layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +// layout (constant_id = 3) const uint BK = 32; +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#define BK 32 + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK / 4 + 4) +#else +#define SHMEM_STRIDE (BK / 4 + 1) +#endif + +shared int32_t buf_a_qs[BM * SHMEM_STRIDE]; + +#ifndef COOPMAT +#if QUANT_AUXF == 1 +shared FLOAT_TYPE buf_a_dm[BM]; +#else +shared FLOAT_TYPE_VEC2 buf_a_dm[BM]; +#endif +#endif + +shared int32_t buf_b_qs[BN * SHMEM_STRIDE]; +#ifndef COOPMAT +shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; +#endif + +#define LOAD_VEC_A (4 * QUANT_R) +#define LOAD_VEC_B 4 + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[4096]; +#endif // MUL_MAT_ID + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +#include "mul_mmq_funcs.comp" + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + 
const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + + const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK; + const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK; + +#ifdef MUL_MAT_ID + uint _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + row_ids[_ne1] = u16vec2(ii0, ii1); + _ne1++; + } + } + } + + barrier(); + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a_ib = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / BK; +#ifdef MUL_MAT_ID + uint pos_b_ib = 0; +#else + uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; +#endif + +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat cm_result; + + coopmat factors[cms_per_row * cms_per_col]; + + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + int32_t cache_a_qs[WMITER * TM * BK / 4]; + + int32_t cache_b_qs[TN * BK / 4]; + + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + +#if QUANT_AUXF == 1 + FLOAT_TYPE cache_a_dm[WMITER * TM]; +#else + FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; +#endif + + FLOAT_TYPE_VEC2 cache_b_ds[TN]; + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { + const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK; + const uint iqs = loadr_a; + const uint buf_ib = loadc_a + l; + + if (iqs == 0) { +#if QUANT_AUXF == 1 + buf_a_dm[buf_ib] = get_d(ib); +#else + buf_a_dm[buf_ib] = get_dm(ib); +#endif + } +#if QUANT_R == 1 + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs); +#else + const i32vec2 vals = repack(ib, iqs); + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x; + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y; +#endif + } + [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) { +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; + const uint ib 
= idx / 8; + const uint iqs = idx & 0x7; +#else + const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; + const uint iqs = loadr_b; +#endif + + const uint buf_ib = loadc_b + l; + + if (iqs == 0) { + buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds); + } + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs]; + } + + barrier(); + + pos_a_ib += 1; + pos_b_ib += 1; + +#ifdef COOPMAT + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint ib_a = warp_r * WM + cm_row * TM; + // Load from shared into cache + coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { + cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; + } + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint ib_b = warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { + cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; + } + + cm_result = coopmat(0); + cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); + } + + coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); + } + } +#else + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; + cache_a_dm[wsir * TM + cr] = buf_a_dm[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + } + + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; + cache_b_ds[cc] = buf_b_ds[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint cache_a_idx = wsir * TM + cr; + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + int32_t q_sum = 0; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k], + cache_b_qs[cc * (BK / 4) + idx_k]); + } + + sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]); + } + } + } + } +#endif + + barrier(); + } + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + 
[[unroll]] for (uint col = 0; col < BN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; + + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp new file mode 100644 index 000000000..63b15471b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp @@ -0,0 +1,99 @@ +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#include "types.comp" + +// Each iqs value maps to a 32-bit integer + 
+#if defined(DATA_A_Q4_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y)); +} +#endif + +#if defined(DATA_A_Q4_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +} +#endif + +#if defined(DATA_A_Q5_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y)); +} +#endif + +#if defined(DATA_A_Q5_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +} +#endif + +#if defined(DATA_A_Q8_0) +int32_t repack(uint ib, uint iqs) { + // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4 + return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1])); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(float(q_sum) * da * dsb.x); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +FLOAT_TYPE get_d(uint ib) { + return FLOAT_TYPE(data_a[ib].d); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp new file mode 100644 index 000000000..e2e020fec --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -0,0 +1,77 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require 
+#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint ne; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint GROUP_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {vec4 data_a[];}; +layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];}; + +shared float shmem[GROUP_SIZE]; + +void quantize() { + const uint wgid = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + // Each thread handles a vec4, so 8 threads handle a block + const uint blocks_per_group = GROUP_SIZE / 8; + + const uint block_in_wg = tid / 8; + + const uint ib = wgid * blocks_per_group + block_in_wg; + const uint iqs = tid % 8; + + if (ib >= gl_NumWorkGroups.x * blocks_per_group) { + return; + } + + const uint a_idx = ib * 8 + iqs; + + vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f); + const vec4 abs_vals = abs(vals); + + // Find absolute max for each block + shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] = max(shmem[tid], shmem[tid + s]); + } + barrier(); + } + + const float amax = shmem[block_in_wg * 8]; + const float d = amax / 127.0; + const float d_inv = d != 0.0 ? 1.0 / d : 0.0; + vals = round(vals * d_inv); + data_b[ib].qs[iqs] = pack32(i8vec4(round(vals))); + barrier(); + + // Calculate the sum for each block + shmem[tid] = vals.x + vals.y + vals.z + vals.w; + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] += shmem[tid + s]; + } + barrier(); + } + if (iqs == 0) { + const float sum = shmem[tid]; + + data_b[ib].ds = f16vec2(vec2(d, sum * d)); + } +} + +void main() { + quantize(); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp index 52a19b62a..4f806270c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -17,5 +17,5 @@ void main() { return; } - data_d[i] = max(float(data_a[i]), 0); + data_d[i] = D_TYPE(max(float(data_a[i]), 0)); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index b554400ba..deb8ee996 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.comp" +#include "generic_unary_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable @@ -8,19 +8,29 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared FLOAT_TYPE sum[BLOCK_SIZE]; void main() { - const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint tid = gl_LocalInvocationID.x; + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = gl_WorkGroupID.x; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + const uint tid = gl_LocalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t 
a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]); sum[tid] += xi * xi; } @@ -33,10 +43,10 @@ void main() { barrier(); } - const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp index 776581e2c..5c9e5c350 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -16,5 +16,5 @@ void main() { if (i >= p.KX) { return; } - data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i]))); + data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i])))); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp index 495f966bd..8a6f868f5 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -16,5 +16,5 @@ void main() { if (i >= p.KX) { return; } - data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.)); + data_d[i] = D_TYPE(1. - 2. 
/ (exp(2.*float(data_a[i])) + 1.)); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp new file mode 100644 index 000000000..fd0ba401f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_EXT_bfloat16 : require + +void main() +{ +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp new file mode 100644 index 000000000..470e3074d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_EXT_integer_dot_product : require + +void main() +{ +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index dfa16cda5..3bde71783 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -1,7 +1,7 @@ - #if !defined(GGML_TYPES_COMP) #define GGML_TYPES_COMP +#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require @@ -33,6 +33,19 @@ #endif #endif +#if defined(DATA_A_BF16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE uint16_t +#elif LOAD_VEC_A == 4 +#define A_TYPE u16vec4 +#elif LOAD_VEC_A == 8 +#error unsupported +#endif +#endif + #define QUANT_K_Q4_0 32 #define QUANT_R_Q4_0 2 @@ -50,6 +63,7 @@ struct block_q4_0_packed16 #if defined(DATA_A_Q4_0) #define QUANT_K QUANT_K_Q4_0 #define QUANT_R QUANT_R_Q4_0 +#define QUANT_AUXF 1 #define A_TYPE block_q4_0 #define A_TYPE_PACKED16 block_q4_0_packed16 #endif @@ -71,11 +85,19 @@ struct block_q4_1_packed16 uint16_t qs[16/2]; }; +struct block_q4_1_packed32 +{ + f16vec2 dm; + uint32_t qs[16/4]; +}; + #if defined(DATA_A_Q4_1) #define QUANT_K QUANT_K_Q4_1 #define QUANT_R QUANT_R_Q4_1 +#define QUANT_AUXF 2 #define A_TYPE block_q4_1 #define A_TYPE_PACKED16 block_q4_1_packed16 +#define A_TYPE_PACKED32 block_q4_1_packed32 #endif #define QUANT_K_Q5_0 32 @@ -98,6 +120,7 @@ struct block_q5_0_packed16 #if defined(DATA_A_Q5_0) #define QUANT_K QUANT_K_Q5_0 #define QUANT_R QUANT_R_Q5_0 +#define QUANT_AUXF 1 #define A_TYPE block_q5_0 #define A_TYPE_PACKED16 block_q5_0_packed16 #endif @@ -121,11 +144,20 @@ struct block_q5_1_packed16 uint16_t qs[16/2]; }; +struct block_q5_1_packed32 +{ + f16vec2 dm; + uint qh; + uint32_t qs[16/4]; +}; + #if defined(DATA_A_Q5_1) #define QUANT_K QUANT_K_Q5_1 #define QUANT_R QUANT_R_Q5_1 +#define QUANT_AUXF 2 #define A_TYPE block_q5_1 #define A_TYPE_PACKED16 block_q5_1_packed16 +#define A_TYPE_PACKED32 block_q5_1_packed32 #endif #define QUANT_K_Q8_0 32 @@ -139,16 +171,42 @@ struct block_q8_0 struct block_q8_0_packed16 { float16_t d; - uint16_t qs[32/2]; + int16_t qs[32/2]; +}; +struct block_q8_0_packed32 +{ + float16_t d; + int32_t qs[32/4]; }; #if defined(DATA_A_Q8_0) #define QUANT_K QUANT_K_Q8_0 #define QUANT_R QUANT_R_Q8_0 +#define QUANT_AUXF 1 #define A_TYPE block_q8_0 #define A_TYPE_PACKED16 block_q8_0_packed16 +#define A_TYPE_PACKED32 block_q8_0_packed32 #endif +#define QUANT_K_Q8_1 
32 +#define QUANT_R_Q8_1 1 + +struct block_q8_1 +{ + f16vec2 ds; + int8_t qs[32]; +}; +struct block_q8_1_packed16 +{ + f16vec2 ds; + int16_t qs[16]; +}; +struct block_q8_1_packed32 +{ + f16vec2 ds; + int32_t qs[8]; +}; + // K-quants #define QUANT_K_Q2_K 256 @@ -312,6 +370,12 @@ struct block_iq1_m { uint16_t scales[QUANT_K_IQ1_M/64]; }; +struct block_iq1_m_packed64 { + uint64_t qs[QUANT_K_IQ1_M/8/8]; + uint64_t qh[QUANT_K_IQ1_M/16/8]; + uint64_t scales; +}; + #if defined(DATA_A_IQ1_S) #define QUANT_K QUANT_K_IQ1_S #define QUANT_R QUANT_R_IQ1_S @@ -466,10 +530,13 @@ shared uint16_t iq1s_grid[2048]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq1s_grid_const.length(); i += wgsize.x) { - u16vec2 g = unpack16(iq1s_grid_const[i]); - iq1s_grid[2*i+0] = g.x; - iq1s_grid[2*i+1] = g.y; + [[unroll]] for (uint i = 0; i < iq1s_grid_const.length(); i += wgsize.x) { + uint idx = i + gl_LocalInvocationIndex.x; + if (iq1s_grid_const.length() % wgsize.x == 0 || idx < iq1s_grid_const.length()) { + u16vec2 g = unpack16(iq1s_grid_const[idx]); + iq1s_grid[2*idx+0] = g.x; + iq1s_grid[2*idx+1] = g.y; + } } barrier(); } @@ -565,8 +632,10 @@ shared uvec2 iq2xxs_grid[256]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) { - iq2xxs_grid[i] = iq2xxs_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq2xxs_grid.length(); i += wgsize.x) { + if (iq2xxs_grid_const.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xxs_grid_const.length()) { + iq2xxs_grid[i + gl_LocalInvocationIndex.x] = iq2xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -733,8 +802,10 @@ shared uvec2 iq2xs_grid[512]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) { - iq2xs_grid[i] = iq2xs_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq2xs_grid.length(); i += wgsize.x) { + if (iq2xs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xs_grid_const.length()) { + iq2xs_grid[i + gl_LocalInvocationIndex.x] = iq2xs_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -756,6 +827,14 @@ struct block_iq2_s uint8_t scales[QUANT_K_IQ2_S/32]; }; +struct block_iq2_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_S/8]; + uint16_t qh[QUANT_K_IQ2_S/64]; + uint16_t scales[QUANT_K_IQ2_S/64]; +}; + #if defined(DATA_A_IQ2_S) const uvec2 iq2s_grid_const[1024] = { @@ -1023,8 +1102,10 @@ shared uvec2 iq2s_grid[1024]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) { - iq2s_grid[i] = iq2s_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq2s_grid.length(); i += wgsize.x) { + if (iq2s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2s_grid_const.length()) { + iq2s_grid[i + gl_LocalInvocationIndex.x] = iq2s_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -1032,6 +1113,7 @@ void init_iq_shmem(uvec3 wgsize) #define QUANT_K QUANT_K_IQ2_S #define QUANT_R QUANT_R_IQ2_S #define A_TYPE block_iq2_s +#define A_TYPE_PACKED16 block_iq2_s_packed16 #endif #define QUANT_K_IQ3_XXS 256 @@ -1092,8 +1174,10 @@ shared uint32_t iq3xxs_grid[256]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < 
iq3xxs_grid.length(); i += wgsize.x) { - iq3xxs_grid[i] = iq3xxs_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq3xxs_grid.length(); i += wgsize.x) { + if (iq3xxs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3xxs_grid.length()) { + iq3xxs_grid[i + gl_LocalInvocationIndex.x] = iq3xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -1200,8 +1284,10 @@ shared uint32_t iq3s_grid[512]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) { - iq3s_grid[i] = iq3s_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq3s_grid.length(); i += wgsize.x) { + if (iq3s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3s_grid.length()) { + iq3s_grid[i + gl_LocalInvocationIndex.x] = iq3s_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -1270,4 +1356,18 @@ void init_iq_shmem(uvec3 wgsize) } #endif +// returns the bfloat value in the low 16b. +// See ggml_compute_fp32_to_bf16 +uint32_t fp32_to_bf16(float f) +{ + uint32_t u = floatBitsToUint(f); + u = (u + (0x7fff + ((u >> 16) & 1))) >> 16; + return u; +} + +float bf16_to_fp32(uint32_t u) +{ + return uintBitsToFloat(u << 16); +} + #endif // !defined(GGML_TYPES_COMP) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index c5e0bba82..d196137eb 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -63,7 +63,8 @@ const std::vector type_names = { "iq3_xxs", "iq3_s", "iq4_xs", - "iq4_nl" + "iq4_nl", + "bf16", }; namespace { @@ -295,7 +296,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; - std::map base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}}; + std::map base_dict = { + {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"}, + }; std::string shader_name = "matmul"; if (matmul_id) { @@ -313,34 +316,81 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool base_dict["COOPMAT"] = "1"; } - base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; - std::string source_name = coopmat2 ? 
"mul_mm_cm2.comp" : "mul_mm.comp"; + auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string { + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "float"; + } + return "bfloat16_t"; + } + if (coopmat2 || fp16) { + return "float16_t"; + } + return "float"; + }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + + // bf16 + { + std::string load_vec_a_unaligned = "1"; + // For aligned matmul loads + std::string load_vec_a = coopmat2 ? "1" : "4"; + + // scalar path promotes to float + std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; + + // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader +#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!(coopmat || coopmat2)) +#endif + { + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? 
"bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + } for (const auto& tname : type_names) { + std::string load_vec_quant = "2"; + if ((tname == "q4_0") || (tname == "q4_1")) + load_vec_quant = "8"; + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + load_vec_quant = "4"; + + if (tname == "bf16") { + continue; + } + std::string data_a_key = "DATA_A_" + to_uppercase(tname); // For unaligned, load one at a time for f32/f16, or two at a time for quants - std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2"; + std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; // For aligned matmul loads - std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; + std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) { 
+ string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + } +#endif } } @@ -371,7 +421,6 @@ void process_shaders() { #endif } -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) // flash attention for (const auto& f16acc : {false, true}) { std::string acctype = f16acc ? "float16_t" : "float"; @@ -380,7 +429,9 @@ void process_shaders() { if (tname == "f32") { continue; } + if (tname == "bf16") continue; +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); @@ -389,14 +440,22 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } +#endif + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } } } -#endif for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); - std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); @@ -404,12 +463,12 @@ void process_shaders() { string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); // Dequant shaders - if (tname != "f16") { + if (tname != "f16" && tname != "bf16") { string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); } if (!string_ends_with(tname, "_k")) { - shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp"; + shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? 
"get_rows.comp" : "get_rows_quant.comp"; if (tname == "f16") { string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); @@ -420,35 +479,62 @@ void process_shaders() { } } - string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } - string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, 
{"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + auto get_type_str = [](bool f16) { + return f16 ? "float16_t" : "float"; + }; + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + for (std::string op : {"add", "sub", "mul", "div"}) { + for (auto src0_f16 : {false, true}) { + for (auto src1_f16 : {false, true}) { + for (auto dst_f16 : {false, true}) { + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); + string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}}); + } + } + } + } string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); + string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -475,14 +561,21 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", 
"float"}, {"D_TYPE", "float"}}); + + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -522,8 +615,13 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + for (auto &c : compiles) { c.wait(); } @@ -578,7 +676,12 @@ void write_output_files() { std::remove(path.c_str()); } } - + for (const char *op : {"add", "sub", "mul", "div"}) { + fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op); + fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op); + fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op); + fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op); + } fclose(hdr); fclose(src); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp new file mode 100644 index 000000000..88c1c02b3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp @@ -0,0 +1,91 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; }; +layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; }; +layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; }; +layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 7) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i]; + } + + const uint start_t = batch_id * 
n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + barrier(); + + A_TYPE sa = 0.0; + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + sa += dot(s_vec, a_vec); + } + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + s_vec = s_vec * w_vec + kv + sa * b_vec; + y += dot(r_vec, s_vec); + + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i] = state[i]; + } +} From fd4480a8480471d214d0b02b979286c9cd5b236d Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 10 Aug 2025 16:05:09 +0200 Subject: [PATCH 046/172] Fixed duplicate sync in ggml.go --- ml/backend/ggml/ggml.go | 1 - 1 file changed, 1 deletion(-) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 591e29ccb..d1d3c40e6 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -23,7 +23,6 @@ import ( "sync/atomic" "unicode" "unsafe" - "sync" "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs" From 1edbfd0559ed203a82e2596ef8e5917bde3d8584 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 10 Aug 2025 16:07:24 +0200 Subject: [PATCH 047/172] Revert changes in ggml.go --- ml/backend/ggml/ggml.go | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index d1d3c40e6..ee653df8c 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -404,42 +404,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } } - // concurrently read in tensor data. 
uses a section reader which is safe for concurrent reads - sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset)) - var tensorSetMutex sync.Mutex - var g errgroup.Group - for _, t := range meta.Tensors().Items() { - for _, target := range targets[t.Name] { - g.Go(func() error { - if target == "" { - target = t.Name - } - - tt, ok := tensors[target] - if !ok { - return fmt.Errorf("unassigned tensor: %s", t.Name) - } - - bts := C.malloc(C.size_t(t.Size())) - if bts == nil { - return errors.New("failed to allocate tensor buffer") - } - defer C.free(bts) - - buf := unsafe.Slice((*byte)(bts), t.Size()) - n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf) - if err != nil || n != len(buf) { - return errors.New("read failed") - } - - tensorSetMutex.Lock() - C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size())) - tensorSetMutex.Unlock() - return nil - }) - } - } - - if g.Wait() != nil { return nil, err } From 60a015e8c3c0b51095eda6043c21cc30341f46b7 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 10 Aug 2025 16:09:44 +0200 Subject: [PATCH 048/172] Revert changes in ggml.go --- ml/backend/ggml/ggml.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index ee653df8c..aa241e9b6 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -404,9 +404,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } } - if g.Wait() != nil { - return nil, err - } // map devices to backend buffer types so new tensors can be assigned to the correct device deviceBufferTypes := make(map[C.ggml_backend_dev_t]C.ggml_backend_buffer_type_t) From 5270c4c5f77e8226aa2f3fa1461a22e11fe1fd8f Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 10 Aug 2025 16:53:13 +0200 Subject: [PATCH 049/172] enable flash attention on vulkan --- discover/gpu.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index 9048631a8..53b3f75cc 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -422,7 +422,8 @@ func GetGPUInfo() GpuInfoList { C.free(unsafe.Pointer(memInfo.err)) continue } - + + gpuInfo.FlashAttention = C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0 // 0 means supported gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) From a1393414ced0aa6b6413a65a92219e664035faee Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 10 Aug 2025 17:54:13 +0200 Subject: [PATCH 050/172] revert remove parenthesis --- discover/gpu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index 53b3f75cc..13c607b18 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -423,7 +423,7 @@ func GetGPUInfo() GpuInfoList { continue } - gpuInfo.FlashAttention = C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0 // 0 means supported + gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0) // 0 means supported gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) From ee24b967f1b43ac0ba90f5e2986aa969d62f3a2c Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 10 Aug 2025 19:57:14 +0200 Subject: [PATCH 051/172] fixed flash attention logic enabling --- discover/gpu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index 13c607b18..a77f8896b 
100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -423,7 +423,7 @@ func GetGPUInfo() GpuInfoList { continue } - gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0) // 0 means supported + gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 1) // 1 means supported gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) From f6dd7070deba990c560accebf431140c970d4cb6 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 10 Aug 2025 21:22:26 +0200 Subject: [PATCH 052/172] vk_check_flash_attention 0 means supported --- discover/gpu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index a77f8896b..13c607b18 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -423,7 +423,7 @@ func GetGPUInfo() GpuInfoList { continue } - gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 1) // 1 means supported + gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0) // 0 means supported gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) From d1f74e17d47c822a33d9e80360c340dc35a9c961 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 10 Aug 2025 21:28:59 +0200 Subject: [PATCH 053/172] Update gpu.go --- discover/gpu.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index 13c607b18..bd3bd20df 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -423,7 +423,8 @@ func GetGPUInfo() GpuInfoList { continue } - gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0) // 0 means supported + // gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0) // 0 means supported + gpuInfo.FlashAttention = true gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) From e3627b2832ba96fc0b919a0c2b5a46c8c78c8042 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 11 Aug 2025 18:39:10 +0200 Subject: [PATCH 054/172] Add vulkan to Windows Build script --- scripts/build_windows.ps1 | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 27f3eb9d4..057b64c90 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -128,6 +128,15 @@ function buildOllama() { & cmake --install build --component "HIP" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } + if ($env:VULKAN_SDK) { + write-host "Building Vulkan backend libraries" + & cmake --fresh --preset Vulkan --install-prefix $script:DIST_DIR + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset Vulkan --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component Vulkan --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } } write-host "Building ollama CLI" & go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" . 
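Patches 049 through 052 above go back and forth on the return convention of the C helper: per patch 052, vk_check_flash_attention reports 0 when the device supports flash attention and non-zero otherwise, so the Go side must compare the result against 0, not 1. A minimal sketch of how that per-device check reads on the Go side, assuming the cgo declarations already used in discover/gpu.go (vHandles.vulkan and C.vk_check_flash_attention as introduced by these patches); the helper name below is illustrative and not part of the patch series:

// vulkanFlashAttentionSupported reports whether Vulkan device i supports
// flash attention. The C helper follows the usual C convention of
// returning 0 on support, hence the == 0 comparison.
func vulkanFlashAttentionSupported(vk *C.vk_handle_t, i int) bool {
    return C.vk_check_flash_attention(*vk, C.int(i)) == 0
}

Patches 053 and 056 later drop this per-device probe in favor of the library allow-list in FlashAttentionSupported.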
From 0c27f472e7668e3773df22d4468989523ea97670 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 11 Aug 2025 18:52:43 +0200 Subject: [PATCH 055/172] Remove commented out code --- discover/gpu.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/discover/gpu.go b/discover/gpu.go index bd3bd20df..8a36db854 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -422,8 +422,7 @@ func GetGPUInfo() GpuInfoList { C.free(unsafe.Pointer(memInfo.err)) continue } - - // gpuInfo.FlashAttention = (C.vk_check_flash_attention(*vHandles.vulkan, C.int(i)) == 0) // 0 means supported + gpuInfo.FlashAttention = true gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) From 49c4d154ae1f258d260a898fc1bd23013995deff Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 12 Aug 2025 21:55:19 +0200 Subject: [PATCH 056/172] Enable Vulkan Flash attention in FlashAttentionSupported --- discover/gpu.go | 3 +-- discover/types.go | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/discover/gpu.go b/discover/gpu.go index 8a36db854..123177d3a 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -422,8 +422,7 @@ func GetGPUInfo() GpuInfoList { C.free(unsafe.Pointer(memInfo.err)) continue } - - gpuInfo.FlashAttention = true + gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) diff --git a/discover/types.go b/discover/types.go index 183c51ae2..39830dc5a 100644 --- a/discover/types.go +++ b/discover/types.go @@ -182,7 +182,8 @@ func (l GpuInfoList) FlashAttentionSupported() bool { supportsFA := gpu.Library == "cpu" || gpu.Library == "metal" || (gpu.Library == "cuda" && gpu.DriverMajor >= 7) || - gpu.Library == "rocm" + gpu.Library == "rocm" || + gpu.Library == "vulkan" if !supportsFA { return false From 56050ad8ea86eaa410948689c297bfbf85bc5198 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Thu, 14 Aug 2025 22:42:30 +0200 Subject: [PATCH 057/172] Fix logging --- discover/gpu.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index 0ca0dde8d..72011e699 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -772,7 +772,12 @@ func LoadVulkanMgmt(vulkanLibPaths []string, capLibPaths []string) (int, *C.vk_h C.vk_init(vkLib, capLib, &resp) if resp.err != nil { - slog.Error("Unable to load vulkan", "library", vkLibPath, capLibPath, "error", C.GoString(resp.err)) + slog.Error( + "Unable to load vulkan", + "vulkan_library", vkLibPath, + "cap_library", capLibPath, + "error", C.GoString(resp.err), + ) C.free(unsafe.Pointer(resp.err)) } else { return int(resp.num_devices), &resp.ch, vkLibPath, capLibPath From 834a66689e365fa3f40569269c616c5bd1dd3938 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Fri, 15 Aug 2025 00:18:18 +0200 Subject: [PATCH 058/172] Update Vulkan backend to e54d41befcc1575f4c898c5ff4ef43970cead75f --- .../ggml/ggml/src/ggml-vulkan/CMakeLists.txt | 202 +- .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2919 +++++++++++++---- .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 9 + .../ggml-vulkan/vulkan-shaders/add_id.comp | 42 + .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 329 ++ .../vulkan-shaders/conv_transpose_1d.comp | 98 + .../vulkan-shaders/copy_from_quant.comp | 4 +- .../vulkan-shaders/copy_to_quant.comp | 75 +- .../vulkan-shaders/dequant_funcs.comp | 18 + .../vulkan-shaders/dequant_funcs_cm2.comp | 21 + .../vulkan-shaders/dequant_iq1_m.comp | 2 +- .../vulkan-shaders/dequant_mxfp4.comp | 32 + 
.../vulkan-shaders/dequant_q2_k.comp | 2 +- .../vulkan-shaders/dequant_q3_k.comp | 2 +- .../vulkan-shaders/dequant_q4_k.comp | 2 +- .../vulkan-shaders/dequant_q5_k.comp | 2 +- .../vulkan-shaders/dequant_q6_k.comp | 2 +- .../vulkan-shaders/flash_attn.comp | 230 +- .../vulkan-shaders/flash_attn_base.comp | 178 + .../vulkan-shaders/flash_attn_cm1.comp | 387 +++ .../vulkan-shaders/flash_attn_cm2.comp | 207 +- .../flash_attn_split_k_reduce.comp | 83 +- .../src/ggml-vulkan/vulkan-shaders/geglu.comp | 13 + .../ggml-vulkan/vulkan-shaders/geglu_erf.comp | 27 + .../vulkan-shaders/geglu_quick.comp | 11 + .../ggml-vulkan/vulkan-shaders/gelu_erf.comp | 39 + .../vulkan-shaders/generic_binary_head.comp | 2 + .../ggml-vulkan/vulkan-shaders/glu_head.comp | 19 + .../ggml-vulkan/vulkan-shaders/glu_main.comp | 29 + .../ggml-vulkan/vulkan-shaders/im2col.comp | 11 +- .../vulkan-shaders/mul_mat_vec_nc.comp | 18 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 209 +- .../vulkan-shaders/mul_mm_cm2.comp | 47 +- .../vulkan-shaders/mul_mmq_funcs.comp | 6 + .../src/ggml-vulkan/vulkan-shaders/reglu.comp | 9 + .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 21 +- .../src/ggml-vulkan/vulkan-shaders/roll.comp | 46 + .../ggml-vulkan/vulkan-shaders/rope_head.comp | 5 +- .../vulkan-shaders/rope_multi.comp | 16 +- .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 16 +- .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 16 +- .../src/ggml-vulkan/vulkan-shaders/rte.comp | 5 + .../src/ggml-vulkan/vulkan-shaders/scale.comp | 2 +- .../ggml-vulkan/vulkan-shaders/soft_max.comp | 36 +- .../ggml-vulkan/vulkan-shaders/swiglu.comp | 9 + .../vulkan-shaders/swiglu_oai.comp | 14 + .../src/ggml-vulkan/vulkan-shaders/types.comp | 55 + .../ggml-vulkan/vulkan-shaders/upscale.comp | 74 +- .../vulkan-shaders/vulkan-shaders-gen.cpp | 124 +- 49 files changed, 4407 insertions(+), 1318 deletions(-) create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt index 31816219c..b97e7bf99 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt @@ -15,6 +15,32 @@ function(detect_host_compiler) set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) endfunction() +# Function to test shader extension support +# Parameters: +# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product") +# TEST_SHADER_FILE - Path to the test shader file +# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result +function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE) + execute_process( + COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error + ) + + if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*") + message(STATUS "${EXTENSION_NAME} not supported by glslc") + set(${RESULT_VARIABLE} OFF PARENT_SCOPE) + else() + message(STATUS "${EXTENSION_NAME} supported by glslc") + set(${RESULT_VARIABLE} ON PARENT_SCOPE) + add_compile_definitions(${RESULT_VARIABLE}) + + # Ensure the extension support is forwarded to vulkan-shaders-gen + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON) + set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE) + endif() +endfunction() + if (Vulkan_FOUND) message(STATUS "Vulkan found") @@ -23,69 +49,32 @@ if (Vulkan_FOUND) ../../include/ggml-vulkan.h ) - # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + set(VULKAN_SHADER_GEN_CMAKE_ARGS "") - if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") - message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_KHR_cooperative_matrix supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - endif() + # Test all shader extensions + test_shader_extension_support( + "GL_KHR_cooperative_matrix" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" + ) - # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. - # If it's not, there will be an error to stderr. 
- # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + test_shader_extension_support( + "GL_NV_cooperative_matrix2" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" + ) - if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") - message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - endif() + test_shader_extension_support( + "GL_EXT_integer_dot_product" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" + ) - # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) - - if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*") - message(STATUS "GL_EXT_integer_dot_product not supported by glslc") - set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_EXT_integer_dot_product supported by glslc") - set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - endif() - - # Compile a test shader to determine whether GL_EXT_bfloat16 is supported. - # If it's not, there will be an error to stderr. 
- # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) - - if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*") - message(STATUS "GL_EXT_bfloat16 not supported by glslc") - set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_EXT_bfloat16 supported by glslc") - set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) - endif() + test_shader_extension_support( + "GL_EXT_bfloat16" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" + ) target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -110,10 +99,7 @@ if (Vulkan_FOUND) if (GGML_VULKAN_SHADER_DEBUG_INFO) add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) - endif() - - if (GGML_VULKAN_PERF) - add_compile_definitions(GGML_VULKAN_PERF) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_SHADER_DEBUG_INFO=ON) endif() if (GGML_VULKAN_VALIDATE) @@ -124,16 +110,8 @@ if (Vulkan_FOUND) add_compile_definitions(GGML_VULKAN_RUN_TESTS) endif() - if (NOT CMAKE_CROSSCOMPILING) - add_subdirectory(vulkan-shaders) - if (MSVC) - foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) - string(TOUPPER ${CONFIG} CONFIG) - set_target_properties(vulkan-shaders-gen PROPERTIES - RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) - endforeach() - endif() - else() + # Set up toolchain for host compilation whether cross-compiling or not + if (CMAKE_CROSSCOMPILING) if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) else() @@ -146,42 +124,59 @@ if (Vulkan_FOUND) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) endif() - message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") - - include(ExternalProject) - # Native build through ExternalProject_Add - ExternalProject_Add( - vulkan-shaders-gen - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders - CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} - -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} - -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT} - -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT} - -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT} - -DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT} - BUILD_COMMAND ${CMAKE_COMMAND} --build . - INSTALL_COMMAND ${CMAKE_COMMAND} --install . 
- INSTALL_DIR ${CMAKE_BINARY_DIR} - ) - ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + else() + # For non-cross-compiling, use empty toolchain (use host compiler) + set(HOST_CMAKE_TOOLCHAIN_FILE "") endif() - set (_ggml_vk_host_suffix $,.exe,>) - set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix}) - set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) - set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) - set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) - set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) - file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") - set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen) + include(ExternalProject) if (CMAKE_CROSSCOMPILING) - set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}) + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") endif() + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$ + -DCMAKE_INSTALL_BINDIR=. + -DCMAKE_BUILD_TYPE=$ + ${VULKAN_SHADER_GEN_CMAKE_ARGS} + + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $ + BUILD_ALWAYS TRUE + + # NOTE: When DESTDIR is set using Makefile generators and + # "make install" triggers the build step, vulkan-shaders-gen + # would be installed into the DESTDIR prefix, so it is unset + # to ensure that does not happen. + + INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR + ${CMAKE_COMMAND} --install . --config $ + ) + + set (_ggml_vk_host_suffix $,.exe,>) + set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$") + set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}") + set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp") + set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp") + set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders") + set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv") + + file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp") + + # Because external projects do not provide source-level tracking, + # the vulkan-shaders-gen sources need to be explicitly added to + # ensure that changes will cascade into shader re-generation. 
+ + file(GLOB _ggml_vk_shaders_gen_sources + CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.cpp" + "${_ggml_vk_input_dir}/*.h") + add_custom_command( OUTPUT ${_ggml_vk_header} - ${_ggml_vk_source} + ${_ggml_vk_source} COMMAND ${_ggml_vk_genshaders_cmd} --glslc ${Vulkan_GLSLC_EXECUTABLE} @@ -191,7 +186,10 @@ if (Vulkan_FOUND) --target-cpp ${_ggml_vk_source} --no-clean - DEPENDS ${_ggml_vk_shader_deps} + DEPENDS ${_ggml_vk_shader_files} + ${_ggml_vk_shaders_gen_sources} + vulkan-shaders-gen + COMMENT "Generate vulkan shaders" ) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e2b357fdc..4070e248b 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1,6 +1,6 @@ #include "ggml-vulkan.h" #include -#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS) +#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS) #include #include "ggml-cpu.h" #endif @@ -78,7 +78,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de -#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32 +#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 #define GGML_VK_MAX_NODES 8192 @@ -102,25 +102,11 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } struct ggml_backend_vk_context; -struct vk_queue { - uint32_t queue_family_index; - vk::Queue queue; - vk::CommandPool pool; - uint32_t cmd_buffer_idx; - std::vector cmd_buffers; - - vk::PipelineStageFlags stage_flags; - - bool transfer_only; -}; +#define MAX_PARAMETER_COUNT 8 struct vk_pipeline_struct { std::string name; vk::ShaderModule shader_module; - vk::DescriptorSetLayout dsl; - std::vector descriptor_pools; - std::vector descriptor_sets; - uint32_t descriptor_set_idx; vk::PipelineLayout layout; vk::Pipeline pipeline; uint32_t push_constant_size; @@ -167,6 +153,45 @@ struct ggml_backend_vk_buffer_type_context { vk_device device; }; +struct vk_queue; + +// Stores command pool/buffers. There's an instance of this +// for each (context,queue) pair and for each (device,queue) pair. +struct vk_command_pool { + void init(vk_device& device, vk_queue *q_); + void destroy(vk::Device& device); + + vk::CommandPool pool; + uint32_t cmd_buffer_idx; + std::vector cmd_buffers; + + vk_queue *q; +}; + +// Prevent simultaneous submissions to the same queue. +// This could be per vk_queue if we stopped having two vk_queue structures +// sharing the same vk::Queue. 
+static std::mutex queue_mutex; + +struct vk_queue { + uint32_t queue_family_index; + vk::Queue queue; + + vk_command_pool cmd_pool; + + vk::PipelineStageFlags stage_flags; + + bool transfer_only; + + // copy everything except the cmd_pool + void copyFrom(vk_queue &other) { + queue_family_index = other.queue_family_index; + queue = other.queue; + stage_flags = other.stage_flags; + transfer_only = other.transfer_only; + } +}; + static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); @@ -184,9 +209,7 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { #ifdef GGML_VULKAN_MEMORY_DEBUG class vk_memory_logger; #endif -#ifdef GGML_VULKAN_PERF class vk_perf_logger; -#endif static void ggml_vk_destroy_buffer(vk_buffer& buf); static constexpr uint32_t mul_mat_vec_max_cols = 8; @@ -198,6 +221,23 @@ enum vk_device_architecture { AMD_RDNA1, AMD_RDNA2, AMD_RDNA3, + INTEL_XE2, + NVIDIA_PRE_TURING, +}; + +// HSK x HSV +enum FaHeadSizes { + FA_HEAD_SIZE_64, + FA_HEAD_SIZE_80, + FA_HEAD_SIZE_96, + FA_HEAD_SIZE_112, + FA_HEAD_SIZE_128, + FA_HEAD_SIZE_192, + FA_HEAD_SIZE_192_128, + FA_HEAD_SIZE_256, + FA_HEAD_SIZE_576_512, + FA_HEAD_SIZE_UNSUPPORTED, + FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED, }; static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { @@ -248,12 +288,63 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& } return vk_device_architecture::AMD_RDNA2; } + } else if (props.vendorID == VK_VENDOR_ID_INTEL) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool subgroup_size_control = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + subgroup_size_control = true; + } + } + + if (!subgroup_size_control) { + return vk_device_architecture::OTHER; + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &subgroup_size_control_props; + device.getProperties2(&props2); + + if (subgroup_size_control_props.minSubgroupSize == 16) { + // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8. + // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value. + // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html + // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html + return vk_device_architecture::INTEL_XE2; + } + } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool cooperative_matrix = false; + + // Detect "pre-turing" based on lack of coopmat support. 
+ for (const auto& properties : ext_props) { + if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) { + cooperative_matrix = true; + break; + } + } + + if (!cooperative_matrix) { + return vk_device_architecture::NVIDIA_PRE_TURING; + } } return vk_device_architecture::OTHER; } +enum vk_conv_shapes { + CONV_SHAPE_128x128, + CONV_SHAPE_64x32, + CONV_SHAPE_32x256, + CONV_SHAPE_COUNT, +}; + struct vk_device_struct { - std::mutex mutex; + std::recursive_mutex mutex; vk::PhysicalDevice physical_device; vk::PhysicalDeviceProperties properties; @@ -261,6 +352,7 @@ struct vk_device_struct { uint64_t max_memory_allocation_size; uint64_t suballocation_block_size; bool fp16; + bool bf16; bool pipeline_robustness; vk::Device device; uint32_t vendor_id; @@ -288,6 +380,9 @@ struct vk_device_struct { bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; bool coopmat_bf16_support {}; + bool coopmat_support_16x16x16_f16acc {}; + bool coopmat_support_16x16x16_f32acc {}; + bool coopmat1_fa_support {}; uint32_t coopmat_m; uint32_t coopmat_n; uint32_t coopmat_k; @@ -311,6 +406,8 @@ struct vk_device_struct { // set to true to indicate that some shaders need to be compiled after the dryrun bool need_compiles {}; + vk::DescriptorSetLayout dsl; + vk_matmul_pipeline pipeline_matmul_f32 {}; vk_matmul_pipeline pipeline_matmul_f32_f16 {}; vk_matmul_pipeline pipeline_matmul_bf16 {}; @@ -352,33 +449,46 @@ struct vk_device_struct { vk_pipeline pipeline_div[2][2][2]; vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_add_id_f32; + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; - vk_pipeline pipeline_upscale_f32; + vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; vk_pipeline pipeline_sin_f32; vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; + vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_rms_norm_mul_f32; vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_l2_norm_f32; // [src/dst 0=fp32,1=fp16] vk_pipeline pipeline_gelu[2]; + vk_pipeline pipeline_gelu_erf[2]; vk_pipeline pipeline_gelu_quick[2]; vk_pipeline pipeline_silu[2]; vk_pipeline pipeline_relu[2]; vk_pipeline pipeline_tanh[2]; vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_geglu[2]; + vk_pipeline pipeline_reglu[2]; + vk_pipeline pipeline_swiglu[2]; + vk_pipeline pipeline_swiglu_oai[2]; + vk_pipeline pipeline_geglu_erf[2]; + vk_pipeline pipeline_geglu_quick[2]; + vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; @@ -395,32 +505,26 @@ struct vk_device_struct { vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; + vk_pipeline pipeline_conv_transpose_1d_f32; vk_pipeline 
pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_dw_whcn_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} - vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; + + vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_split_k_reduce; std::unordered_map pipelines; - std::unordered_map pipeline_descriptor_set_requirements; std::vector> pinned_memory; @@ -429,12 +533,17 @@ struct vk_device_struct { ggml_backend_buffer_type buffer_type; + bool disable_fusion; + bool disable_host_visible_vidmem; + #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; #endif -#ifdef GGML_VULKAN_PERF + + // for GGML_VK_PERF_LOGGER std::unique_ptr perf_logger; -#endif + vk::QueryPool query_pool; + int32_t num_queries; ~vk_device_struct() { VK_LOG_DEBUG("destroy device " << name); @@ -443,10 +552,8 @@ struct vk_device_struct { ggml_vk_destroy_buffer(sync_staging); - device.destroyCommandPool(compute_queue.pool); - if (!single_queue) { - device.destroyCommandPool(transfer_queue.pool); - } + compute_queue.cmd_pool.destroy(device); + transfer_queue.cmd_pool.destroy(device); for (auto& pipeline : pipelines) { if (pipeline.second.expired()) { @@ -458,10 +565,26 @@ struct vk_device_struct { } pipelines.clear(); + device.destroyDescriptorSetLayout(dsl); + device.destroy(); } }; +void vk_command_pool::init(vk_device& device, vk_queue *q_) { + cmd_buffer_idx = 0; + q = q_; + + vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index); + pool = device->device.createCommandPool(command_pool_create_info); +} + +void vk_command_pool::destroy(vk::Device& device) { + device.destroyCommandPool(pool); + pool = nullptr; + cmd_buffers.clear(); +} + struct vk_buffer_struct { vk::Buffer buffer = VK_NULL_HANDLE; vk::DeviceMemory device_memory = VK_NULL_HANDLE; @@ -547,6 +670,8 @@ struct vk_flash_attn_push_constants { uint32_t nev2; uint32_t nev3; uint32_t nem1; + uint32_t nem2; + uint32_t nem3; uint32_t nb01; uint32_t nb02; @@ -557,14 +682,12 @@ struct vk_flash_attn_push_constants { uint32_t nb21; uint32_t nb22; uint32_t nb23; - uint32_t nb31; float scale; float max_bias; 
float logit_softcap; - uint32_t mask; - uint32_t n_head_log2; + uint32_t mask_n_head_log2; float m0; float m1; @@ -572,6 +695,7 @@ struct vk_flash_attn_push_constants { uint32_t split_kv; uint32_t k_num; }; +static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128"); struct vk_op_push_constants { uint32_t KX; @@ -580,6 +704,15 @@ struct vk_op_push_constants { float param2; }; +struct vk_op_glu_push_constants { + uint32_t N; + uint32_t ne00; + uint32_t ne20; + uint32_t mode; // 0: default, 1: swapped, 2: split + float alpha; // for swiglu_oai + float limit; +}; + struct vk_op_unary_push_constants { uint32_t ne; uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; @@ -595,6 +728,37 @@ struct vk_op_unary_push_constants { }; static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); +static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) { + GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst))); + ne = ne != 0 ? ne : ggml_nelements(dst); + GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max()); + + vk_op_unary_push_constants p{}; + p.ne = (uint32_t)ne; + + size_t src0_tsize = ggml_type_size(src0->type); + p.ne00 = (uint32_t)src0->ne[0]; + p.ne01 = (uint32_t)src0->ne[1]; + p.ne02 = (uint32_t)src0->ne[2]; + p.ne03 = (uint32_t)src0->ne[3]; + p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); + p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); + p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); + p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); + + size_t dst_tsize = ggml_type_size(dst->type); + p.ne10 = (uint32_t)dst->ne[0]; + p.ne11 = (uint32_t)dst->ne[1]; + p.ne12 = (uint32_t)dst->ne[2]; + p.ne13 = (uint32_t)dst->ne[3]; + p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); + p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); + p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); + p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + + return p; // fastdiv values and offsets are initialized later in ggml_vk_op +} + // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. 
// Precompute mp (m' in the paper) and L such that division // can be computed using a multiply (high 32b of 64b result) @@ -636,6 +800,15 @@ struct vk_op_binary_push_constants { float param1; float param2; int32_t param3; }; +struct vk_op_add_id_push_constants { + uint32_t ne0; + uint32_t ne1; + uint32_t s01; + uint32_t s02; + uint32_t s11; + uint32_t s21; +}; + struct vk_op_diag_mask_push_constants { uint32_t ncols; uint32_t rows_per_channel; @@ -663,12 +836,21 @@ struct vk_op_rope_push_constants { struct vk_op_soft_max_push_constants { uint32_t KX; uint32_t KY; + uint32_t ne00; + uint32_t ne01; + uint32_t ne02; + uint32_t ne12; + uint32_t ne13; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; float scale; float max_bias; float m0; float m1; uint32_t n_head_log2; uint32_t nrows_x; + uint32_t has_sinks; }; struct vk_op_argsort_push_constants { @@ -696,6 +878,21 @@ struct vk_op_timestep_embedding_push_constants { uint32_t max_period; }; +struct vk_op_conv_transpose_1d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +}; + struct vk_op_pool2d_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; @@ -721,6 +918,52 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t H; }; +struct vk_op_conv2d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; +}; + +template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { + // Compute magic values to divide by KW, KW*KH, OW, OW*OH + init_fastdiv_values(p.KW, p.KWmp, p.KWL); + init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); +} + struct vk_op_conv2d_dw_push_constants { uint32_t ne; uint32_t batches; @@ -741,6 +984,7 @@ struct vk_op_conv2d_dw_push_constants { struct vk_op_upscale_push_constants { uint32_t ne; uint32_t a_offset; uint32_t d_offset; + uint32_t ne00; uint32_t ne01; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; float sf0; float sf1; float sf2; float sf3; @@ -764,7 +1008,7 @@ struct vk_context_struct { std::vector in_memcpys; std::vector out_memcpys; - vk_queue * q; + vk_command_pool * p {}; }; typedef std::shared_ptr vk_context; typedef std::weak_ptr vk_context_ref; @@ -818,21 +1062,46 @@ private: #define VK_LOG_MEMORY(msg) ((void) 0) #endif // GGML_VULKAN_MEMORY_DEBUG -#if defined(GGML_VULKAN_PERF) - class vk_perf_logger { -public: + public: void print_timings() { + if (timings.empty()) { + return; + } + uint64_t total_all_op_times = 0; std::cerr << "----------------\nVulkan Timings:" << std::endl; - for (const auto& t : timings) { - uint64_t total = 0; - for (const auto& time : t.second) { - total += time; + for (const auto & t : timings) { + uint64_t total_op_times = 0; + for (const auto & time : t.second) { + total_op_times += time; } - std::cerr << t.first << ": " << 
t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl; + std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0) + << " us"; + + // If we have as many flops entries as timing entries for the op, then compute and log the flops/S. + auto it = flops.find(t.first); + if (it != flops.end() && (it->second).size() == t.second.size()) { + uint64_t total_op_flops = 0; + for (const auto & elem : it->second) { + total_op_flops += elem; + } + std::cerr << " (" + << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) / + (double(total_op_times) / (1000.0 * 1000.0 * 1000.0)) + << " GFLOPS/s)"; + } + + total_all_op_times += total_op_times; + + std::cerr << std::endl; + } + + if (timings.size() > 0) { + std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl; } timings.clear(); + flops.clear(); } void log_timing(const ggml_tensor * node, uint64_t time) { @@ -841,24 +1110,46 @@ public: return; } if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { - const uint64_t m = node->src[0]->ne[1]; - const uint64_t n = node->src[1]->ne[1]; - const uint64_t k = node->src[1]->ne[0]; - std::string name = ggml_op_name(node->op); + const uint64_t m = node->src[0]->ne[1]; + const uint64_t n = node->src[1]->ne[1]; + const uint64_t k = node->src[1]->ne[0]; + std::string name = ggml_op_name(node->op); if (n == 1) { name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k); } else { name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); } timings[name].push_back(time); + flops[name].push_back(m * n * (k + (k - 1))); + return; + } + if (node->op == GGML_OP_CONV_2D) { + std::string name = ggml_op_name(node->op); + ggml_tensor * knl = node->src[0]; + uint64_t OW = node->ne[0]; + uint64_t OH = node->ne[1]; + uint64_t N = node->ne[3]; + uint64_t Cout = node->ne[2]; + uint64_t KW = knl->ne[0]; + uint64_t KH = knl->ne[1]; + uint64_t Cin = knl->ne[2]; + // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ + uint64_t size_M = Cout; + uint64_t size_K = Cin * KW * KH; + uint64_t size_N = N * OW * OH; + uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1)); + name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) + + ", N=N*OW*OH=" + std::to_string(size_N); + flops[name].push_back(n_flops); + timings[name].push_back(time); return; } timings[ggml_op_name(node->op)].push_back(time); } -private: + private: std::map> timings; + std::map> flops; }; -#endif // GGML_VULKAN_PERF struct ggml_backend_vk_context { std::string name; @@ -878,6 +1169,18 @@ struct ggml_backend_vk_context { vk_context_ref transfer_ctx; std::vector tensor_ctxs; + + std::vector descriptor_pools; + std::vector descriptor_sets; + uint32_t descriptor_set_idx {}; + uint32_t pipeline_descriptor_set_requirements {}; + + vk_command_pool compute_cmd_pool; + vk_command_pool transfer_cmd_pool; + + // number of additional consecutive nodes that are being fused with the + // node currently being processed + int num_additional_fused_ops {}; }; static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT @@ -941,6 +1244,14 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { struct vk_instance_t { vk::Instance instance; + bool debug_utils_support = false; // VK_EXT_debug_utils enabled + PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {}; + PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {}; + 
PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {}; + PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {}; + PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {}; + PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {}; + std::vector device_indices; vk_device devices[GGML_VK_MAX_DEVICES]; }; @@ -948,13 +1259,15 @@ struct vk_instance_t { static bool vk_instance_initialized = false; static vk_instance_t vk_instance; +static bool vk_perf_logger_enabled = false; + #ifdef GGML_VULKAN_CHECK_RESULTS static size_t vk_skip_checks; static size_t vk_output_tensor; static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); -static void ggml_vk_check_results_0(ggml_tensor * tensor); -static void ggml_vk_check_results_1(ggml_tensor * tensor); +static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); +static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); #endif typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); @@ -1006,39 +1319,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); GGML_ASSERT(parameter_count > 0); + GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT); GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); - std::vector dsl_binding; - std::vector dsl_binding_flags; - for (uint32_t i = 0; i < parameter_count; i++) { - dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); - dsl_binding_flags.push_back({}); - } - - vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; - vk::PushConstantRange pcr( vk::ShaderStageFlagBits::eCompute, 0, pipeline->push_constant_size ); - vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( - {}, - dsl_binding); - descriptor_set_layout_create_info.setPNext(&dslbfci); - pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); - - vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); - vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); - pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); - - pipeline->descriptor_set_idx = 0; - - vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr); + vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr); pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info); std::vector specialization_entries(specialization_constants.size()); @@ -1098,8 +1391,16 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } pipeline->compiled = true; + if (vk_instance.debug_utils_support) { + vk::DebugUtilsObjectNameInfoEXT duoni; + 
duoni.objectType = vk::ObjectType::ePipeline; + duoni.pObjectName = pipeline->name.c_str(); + duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast(pipeline->pipeline)); + vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast(duoni)); + } + { - std::lock_guard guard(device->mutex); + std::lock_guard guard(device->mutex); device->pipelines.insert({ pipeline->name, pipeline }); } @@ -1113,15 +1414,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")"); - for (auto& pool : pipeline->descriptor_pools) { - device.destroyDescriptorPool(pool); - } - pipeline->descriptor_pools.clear(); - pipeline->descriptor_sets.clear(); - pipeline->descriptor_set_idx = 0; - - device.destroyDescriptorSetLayout(pipeline->dsl); - device.destroyPipelineLayout(pipeline->layout); device.destroyShaderModule(pipeline->shader_module); @@ -1129,97 +1421,77 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) device.destroyPipeline(pipeline->pipeline); } -static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { +static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) { VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); - device->pipeline_descriptor_set_requirements[pipeline->name] += n; + ctx->pipeline_descriptor_set_requirements += n; if (!pipeline->compiled) { pipeline->needed = true; - device->need_compiles = true; + ctx->device->need_compiles = true; } } -static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { - std::lock_guard guard(device->mutex); +static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) { - for (auto& pair : device->pipeline_descriptor_set_requirements) { - vk_pipeline pipeline = device->pipelines.at(pair.first).lock(); - const uint64_t n = pair.second; + if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) { + // Enough descriptors are available + return; + } - VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")"); + vk_device& device = ctx->device; - if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) { - // Enough descriptors are available - continue; + uint32_t to_alloc = ctx->pipeline_descriptor_set_requirements - ctx->descriptor_sets.size(); + uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; + uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + while (to_alloc > 0) { + const uint32_t alloc_count = std::min(pool_remaining, to_alloc); + to_alloc -= alloc_count; + pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + if (pool_idx >= ctx->descriptor_pools.size()) { + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); } - uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size(); - uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - 
pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; - uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; - - while (to_alloc > 0) { - const uint32_t alloc_count = std::min(pool_remaining, to_alloc); - to_alloc -= alloc_count; - pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; - - if (pool_idx >= pipeline->descriptor_pools.size()) { - vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); - vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); - pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); - } - - std::vector layouts(alloc_count); - for (uint32_t i = 0; i < alloc_count; i++) { - layouts[i] = pipeline->dsl; - } - vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data()); - std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); - pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end()); - - pool_idx++; + std::vector layouts(alloc_count); + for (uint32_t i = 0; i < alloc_count; i++) { + layouts[i] = device->dsl; } + vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data()); + std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); + ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end()); + + pool_idx++; } } -static void ggml_pipeline_cleanup(vk_pipeline& pipeline) { - VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")"); - pipeline->descriptor_set_idx = 0; -} - -static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) { +static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) { VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); - std::lock_guard guard(device->mutex); - if (q.cmd_buffers.size() > q.cmd_buffer_idx) { + if (p.cmd_buffers.size() > p.cmd_buffer_idx) { // Reuse command buffer - return q.cmd_buffers[q.cmd_buffer_idx++]; + return p.cmd_buffers[p.cmd_buffer_idx++]; } vk::CommandBufferAllocateInfo command_buffer_alloc_info( - q.pool, + p.pool, vk::CommandBufferLevel::ePrimary, 1); const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); auto buf = cmd_buffers.front(); - q.cmd_buffers.push_back(buf); - q.cmd_buffer_idx++; + p.cmd_buffers.push_back(buf); + p.cmd_buffer_idx++; return buf; } -static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector wait_semaphores, std::vector signal_semaphores) { - VK_LOG_DEBUG("ggml_vk_create_submission()"); - vk_submission s; - s.buffer = ggml_vk_create_cmd_buffer(device, q); - s.wait_semaphores = std::move(wait_semaphores); - s.signal_semaphores = std::move(signal_semaphores); - return s; -} - static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { if (ctx->seqs.empty()) { if (fence) { - ctx->q->queue.submit({}, fence); + std::lock_guard guard(queue_mutex); + ctx->p->q->queue.submit({}, fence); } return; } @@ -1258,7 +1530,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { tl_signal_vals.push_back({}); tl_signal_semaphores.push_back({}); for (size_t i = 0; i < submission.wait_semaphores.size(); i++) { - stage_flags[idx].push_back(ctx->q->stage_flags); + 
stage_flags[idx].push_back(ctx->p->q->stage_flags); tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value); tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s); } @@ -1288,7 +1560,8 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { } } - ctx->q->queue.submit(submit_infos, fence); + std::lock_guard guard(queue_mutex); + ctx->p->q->queue.submit(submit_infos, fence); ctx->seqs.clear(); } @@ -1341,33 +1614,30 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector guard(device->mutex); + std::lock_guard guard(device->mutex); q.queue_family_index = queue_family_index; q.transfer_only = transfer_only; - vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index); - q.pool = device->device.createCommandPool(command_pool_create_info_compute); - - q.cmd_buffer_idx = 0; + q.cmd_pool.init(device, &q); q.queue = device->device.getQueue(queue_family_index, queue_index); q.stage_flags = stage_flags; } -static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) { +static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) { vk_context result = std::make_shared(); VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); ctx->gc.contexts.emplace_back(result); - result->q = &q; + result->p = &p; return result; } -static vk_context ggml_vk_create_temporary_context(vk_queue& q) { +static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) { vk_context result = std::make_shared(); VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); - result->q = &q; + result->p = &p; return result; } @@ -1400,15 +1670,29 @@ static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { return ctx->gc.events[ctx->event_idx++]; } -static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) { - VK_LOG_DEBUG("ggml_vk_queue_cleanup()"); - std::lock_guard guard(device->mutex); +static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) { + VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()"); // Requires command buffers to be done - device->device.resetCommandPool(q.pool); - q.cmd_buffer_idx = 0; + device->device.resetCommandPool(p.pool); + p.cmd_buffer_idx = 0; } +static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { + VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()"); + + // Arbitrary frequency to cleanup/reuse command buffers + static constexpr uint32_t cleanup_frequency = 10; + + if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool); + } + if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool); + } +} + + static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { vk::MemoryType memory_type = mem_props->memoryTypes[i]; @@ -1427,8 +1711,6 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); } - std::lock_guard guard(device->mutex); - vk_buffer buf = std::make_shared(); if (size == 0) { @@ -1523,6 +1805,8 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { } else if (device->uma) { 
// Fall back to host memory type buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + } else if (device->disable_host_visible_vidmem) { + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal); } else { // use rebar if available, otherwise fallback to device only visible memory buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -1557,11 +1841,11 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { static void ggml_vk_sync_buffers(vk_context& ctx) { VK_LOG_DEBUG("ggml_vk_sync_buffers()"); - const bool transfer_queue = ctx->q->transfer_only; + const bool transfer_queue = ctx->p->q->transfer_only; ctx->s->buffer.pipelineBarrier( - ctx->q->stage_flags, - ctx->q->stage_flags, + ctx->p->q->stage_flags, + ctx->p->q->stage_flags, {}, { { { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, @@ -1580,45 +1864,111 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ctx->s->buffer.waitEvents( events, - ctx->q->stage_flags, - ctx->q->stage_flags, + ctx->p->q->stage_flags, + ctx->p->q->stage_flags, {}, {}, {} ); } +enum FaCodePath { + FA_SCALAR, + FA_COOPMAT1, + FA_COOPMAT2, +}; + +static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) { + if (hsk != 192 && hsk != 576 && hsk != hsv) { + return FA_HEAD_SIZE_UNSUPPORTED; + } + switch (hsk) { + case 64: return FA_HEAD_SIZE_64; + case 80: return FA_HEAD_SIZE_80; + case 96: return FA_HEAD_SIZE_96; + case 112: return FA_HEAD_SIZE_112; + case 128: return FA_HEAD_SIZE_128; + case 192: + if (hsv == 192) { + return FA_HEAD_SIZE_192; + } else if (hsv == 128) { + return FA_HEAD_SIZE_192_128; + } else { + return FA_HEAD_SIZE_UNSUPPORTED; + } + case 256: return FA_HEAD_SIZE_256; + case 576: + if (hsv == 512) { + return FA_HEAD_SIZE_576_512; + } else { + return FA_HEAD_SIZE_UNSUPPORTED; + } + default: return FA_HEAD_SIZE_UNSUPPORTED; + } +} + // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; -static uint32_t get_fa_num_small_rows(bool scalar) { - return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows; +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { + if (hsv >= 512) { + return 2; + } else { + return 8; + } } -static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { - GGML_UNUSED(clamp); +// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. +// 128 threads split into four subgroups, each subgroup does 1/4 +// of the Bc dimension. 
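// [Editor's sketch - not part of the patch] Worked numbers for the comment
// above, using the constants defined right below (Bc = 64, 128 threads per
// workgroup). With a 32-wide subgroup (the width the coopmat1 FA pipelines
// request via required_subgroup_size) a workgroup holds four subgroups, and
// each subgroup covers Bc / 4 = 16 columns, i.e. exactly one 16x16x16 coopmat
// tile. The example_* names are illustrative; real subgroup width varies by GPU.
#include <cstdint>

constexpr uint32_t example_Bc             = 64;   // scalar_flash_attention_Bc
constexpr uint32_t example_workgroup_size = 128;  // threads per workgroup
constexpr uint32_t example_subgroup_size  = 32;   // assumed for illustration

constexpr uint32_t example_subgroups_per_wg  = example_workgroup_size / example_subgroup_size; // 4
constexpr uint32_t example_cols_per_subgroup = example_Bc / example_subgroups_per_wg;          // 16

static_assert(example_cols_per_subgroup == 16, "each subgroup handles one 16-wide coopmat tile");
// [End editor's sketch]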
+static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; +static constexpr uint32_t scalar_flash_attention_Bc = 64; +static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; - if (scalar) { +static uint32_t get_fa_num_small_rows(FaCodePath path) { + if (path == FA_COOPMAT2) { + return flash_attention_num_small_rows; + } else { + return scalar_flash_attention_num_small_rows; + } +} + +static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) { + GGML_UNUSED(clamp); + GGML_UNUSED(hsv); + + if (path == FA_SCALAR) { if (small_rows) { return {scalar_flash_attention_num_small_rows, 64}; } else { - return {scalar_flash_attention_num_large_rows, 32}; + return {get_fa_scalar_num_large_rows(hsv), 32}; + } + } + + if (path == FA_COOPMAT1) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; + } else { + return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; } } // small rows, large cols if (small_rows) { - return {get_fa_num_small_rows(scalar), 32}; + return {get_fa_num_small_rows(FA_COOPMAT2), 32}; } // small cols to reduce register count - if (ggml_is_quantized(type) || D == 256) { - return {64, 32}; + if (ggml_is_quantized(type) || hsk >= 256) { + if (hsk >= 512) { + return {32, 32}; + } else { + return {64, 32}; + } } return {64, 64}; -}; +} static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -1645,6 +1995,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec break; case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: lut_size = 4*16; break; default: @@ -1657,7 +2008,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t warps = warptile[0] / warptile[10]; const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; - const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0; + const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0; const uint32_t coopmat_stage = device->coopmat_support ? 
warptile[7] * warptile[8] / warps * sizeof(float) : 0; const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; @@ -1774,18 +2125,18 @@ static void ggml_vk_load_shaders(vk_device& device) { s_mmq_wg_denoms = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul (Qi_K) - l_warptile_mmq_k = { 256, 64, 128, 64, 1 }; - m_warptile_mmq_k = { 256, 32, 64, 64, 0 }; - s_warptile_mmq_k = { 256, 32, 32, 128, 0 }; - l_mmq_wg_denoms_k = { 64, 128, 1 }; - m_mmq_wg_denoms_k = { 32, 64, 1 }; - s_mmq_wg_denoms_k = { 32, 32, 1 }; + l_warptile_mmq_k = { 256, 128, 256, 64, 1 }; + m_warptile_mmq_k = { 256, 128, 128, 64, 1 }; + s_warptile_mmq_k = { 256, 32, 64, 128, 0 }; + l_mmq_wg_denoms_k = { 128, 256, 1 }; + m_mmq_wg_denoms_k = { 128, 128, 1 }; + s_mmq_wg_denoms_k = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 64, 16, 0 }; + l_warptile_mmqid = { 256, 128, 128, 16, 0 }; m_warptile_mmqid = { 256, 128, 64, 16, 0 }; s_warptile_mmqid = { 256, 128, 64, 16, 0 }; - l_mmqid_wg_denoms = { 128, 64, 1 }; + l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; @@ -1903,21 +2254,26 @@ static void ggml_vk_load_shaders(vk_device& device) { } compile_count++; } + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; - auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. // For scalar, use 128 (arbitrary) - uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128); - auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows); + // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. + const uint32_t D = (hsk|hsv); + uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) + ? scalar_flash_attention_workgroup_size + : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. 
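// [Editor's sketch - not part of the patch] One way to honor the two D_split
// constraints stated above: it must not exceed the subgroup size (because a
// subgroupShuffle reduces across it) and must not exceed the lowest set bit of
// D divided by 4 (because of vectorized loads in the shader). The selection
// code itself is not reproduced in this hunk, so treat this as an illustration
// of the constraints rather than the patch's exact formula.
#include <algorithm>
#include <cstdint>

static uint32_t example_max_d_split(uint32_t D, uint32_t subgroup_size) {
    const uint32_t d_lsb = D & (~D + 1u);        // lowest set bit of D
    return std::min(subgroup_size, d_lsb / 4u);  // both constraints from the comments above
}
// e.g. D = 80 has lowest set bit 16, so with a 32-wide subgroup D_split is capped at 16 / 4 = 4.
// [End editor's sketch]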
@@ -1926,39 +2282,49 @@ static void ggml_vk_load_shaders(vk_device& device) { // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); - return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; }; -#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], 
"flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ +#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ -#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256) +#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512) - CREATE_FA(GGML_TYPE_F16, f16, true, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, ) + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + } +#endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - CREATE_FA(GGML_TYPE_F16, f16, false, _cm2) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, 
_cm2) + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) } #endif #undef CREATE_FA2 @@ -1987,25 +2353,26 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) } #endif - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, 
vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) @@ -2032,6 +2399,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, 
vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 } else @@ -2041,17 +2409,17 @@ static void ggml_vk_load_shaders(vk_device& device) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); 
\ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -2073,47 +2441,49 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->coopmat_acc_f16_support) { - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - 
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, 
vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_XS, 
pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); @@ -2146,6 +2516,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + 
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } else { CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -2167,6 +2538,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } #undef CREATE_MM2 #undef CREATE_MM @@ -2188,13 +2560,19 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ -#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + } \ + if (device->mul_mat ## ID ## _m[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", 
PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + } \ + if (device->mul_mat ## ID ## _s[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + } \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -2208,34 +2586,35 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_M, 
pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, 
vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { - CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); } #endif @@ -2265,6 +2644,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM2 #undef CREATE_MMQ #undef CREATE_MM @@ -2284,13 +2664,13 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ -#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> 
PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); @@ -2319,14 +2699,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { - CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, 
vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); } #endif @@ -2356,6 +2737,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } // reusing CREATE_MM from the fp32 path if ((device->coopmat2 || device->coopmat_support) @@ -2414,6 +2796,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); @@ -2437,6 +2820,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], 
"mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -2461,6 +2845,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -2483,6 +2868,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -2502,6 +2888,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -2520,9 +2907,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); for (uint32_t i = 
0; i < p021_max_gqa_ratio; ++i) { @@ -2532,11 +2920,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); } } - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2553,19 +2942,41 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, 
"main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); } else { - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, 
device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + } + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, 
sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); @@ -2583,10 +2994,11 @@ static void ggml_vk_load_shaders(vk_device& device) { return s; }; + bool rte = device->float_controls_rte_fp16; #define CREATE_BINARY(name, namemod, spec) \ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ - #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); CREATE_BINARY(add, , {0}) @@ -2599,13 +3011,17 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_BINARY(div, _norepeat, {1}) #undef CREATE_BINARY + ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); + ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, 
sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); + ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1); ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -2617,6 +3033,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -2625,6 +3043,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); CREATE_UNARY(gelu) + CREATE_UNARY(gelu_erf) CREATE_UNARY(gelu_quick) CREATE_UNARY(silu) CREATE_UNARY(relu) @@ -2632,15 +3051,32 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(sigmoid) #undef CREATE_UNARY +#define CREATE_GLU(name) \ + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + } + + CREATE_GLU(geglu) + CREATE_GLU(reglu) + CREATE_GLU(swiglu) + CREATE_GLU(swiglu_oai) + CREATE_GLU(geglu_erf) + CREATE_GLU(geglu_quick) +#undef CREATE_GLU + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 
1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -2677,6 +3113,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); @@ -2685,6 +3123,108 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + // conv2d + for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { + uint32_t conv2d_WG_SIZE = 256; + uint32_t conv2d_BS_K = 128; + uint32_t conv2d_BS_CRS = 16; + uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. 
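// Illustrative sketch (hypothetical tensor shapes, not taken from this patch):
// the conv2d pipelines created in this loop treat the convolution as an implicit
// GEMM, tiling the output channels (K), the flattened output positions (N*P*Q)
// and the flattened reduction axis (C*R*S) with the block sizes chosen here.
// With the default 128x128x16 blocking, a hypothetical 3x3 convolution with 256
// input channels, 512 output channels and a 64x64 output would dispatch:
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t K   = 512;              // output channels (hypothetical)
    const uint32_t NPQ = 1 * 64 * 64;      // batch * out_h * out_w (hypothetical)
    const uint32_t CRS = 256 * 3 * 3;      // in_channels * kernel_h * kernel_w (hypothetical)

    const uint32_t BS_K = 128, BS_NPQ = 128, BS_CRS = 16;    // defaults used in this loop

    const uint32_t wg_k   = (K   + BS_K   - 1) / BS_K;       // tiles along output channels
    const uint32_t wg_npq = (NPQ + BS_NPQ - 1) / BS_NPQ;     // tiles along output positions
    const uint32_t steps  = (CRS + BS_CRS - 1) / BS_CRS;     // reduction slices per tile

    printf("%u x %u workgroups, %u reduction steps each\n", wg_k, wg_npq, steps);  // 4 x 32, 144
    return 0;
}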
+ uint32_t conv2d_BS_NPQ = 128; + uint32_t conv2d_TS_K = 8; + uint32_t conv2d_SHMEM_PAD = 4; + bool conv2d_UNROLL = true; + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + conv2d_SHMEM_PAD = 8; // 8 float16_t + } +#endif + + if (device->vendor_id == VK_VENDOR_ID_INTEL) { + conv2d_SHMEM_PAD = 0; + conv2d_UNROLL = false; + } else if (device->vendor_id == VK_VENDOR_ID_AMD) { + conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4; + } + + switch (s) { + default: + case CONV_SHAPE_128x128: + conv2d_BS_K = 128; + conv2d_BS_NPQ = 128; + conv2d_BS_CRS = 16; + if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) { + conv2d_UNROLL = false; + } + break; + case CONV_SHAPE_64x32: + conv2d_BS_K = 64; + conv2d_BS_NPQ = 32; + conv2d_BS_CRS = 32; + conv2d_TS_K = 4; + break; + case CONV_SHAPE_32x256: + conv2d_BS_K = 32; + conv2d_BS_NPQ = 256; + conv2d_BS_CRS = 16; + break; + } + + // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math. + bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA || + device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; + bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD || + device->architecture == vk_device_architecture::AMD_GCN; + + if (device->subgroup_shuffle && + device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316. + allow_collectives_nv && + allow_collectives_amd) { + use_collectives = 1; + conv2d_BS_CRS = std::min( + device->subgroup_size, + conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used. + } + + uint32_t conv2d_shmem_req = + (conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float); + if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { + conv2d_BS_CRS = 8; + if (use_collectives) { + conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); + } + } + + std::array wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 }; + std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + } else +#endif + if (conv2d_UNROLL) { + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + } else { + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, 
spec_constants, 1, true, use_collectives); + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + } + } + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -2707,9 +3247,9 @@ static vk_device ggml_vk_get_device(size_t idx) { #ifdef GGML_VULKAN_MEMORY_DEBUG device->memory_logger = std::unique_ptr(new vk_memory_logger()); #endif -#ifdef GGML_VULKAN_PERF - device->perf_logger = std::unique_ptr(new vk_perf_logger()); -#endif + if (vk_perf_logger_enabled) { + device->perf_logger = std::unique_ptr(new vk_perf_logger()); + } size_t dev_num = vk_instance.device_indices[idx]; @@ -2728,6 +3268,9 @@ static vk_device ggml_vk_get_device(size_t idx) { const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; + const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM"); + device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; @@ -2754,23 +3297,29 @@ static vk_device ggml_vk_get_device(size_t idx) { pipeline_robustness = true; } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { device->subgroup_size_control = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT")) { device->coopmat_support = true; device->coopmat_m = 0; device->coopmat_n = 0; device->coopmat_k = 0; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; +#endif #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { device->integer_dot_product = true; #endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; +#endif } } @@ -2991,6 +3540,12 @@ static vk_device ggml_vk_get_device(size_t idx) { device->fp16 = device->fp16 && vk12_features.shaderFloat16; +#if defined(VK_KHR_shader_bfloat16) + device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; +#else + device->bf16 = false; +#endif + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; if (device->subgroup_size_control) { @@ -3009,6 +3564,11 @@ static vk_device ggml_vk_get_device(size_t idx) { #if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; + + // coopmat1 fa shader currently assumes 32 invocations per subgroup + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && + 
device->subgroup_size_control && device->subgroup_min_size <= 32 && + device->subgroup_max_size >= 32; #endif if (coopmat2_support) { @@ -3143,6 +3703,9 @@ static vk_device ggml_vk_get_device(size_t idx) { // Only enable if shape is identical device->coopmat_acc_f32_support = true; } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f32acc = true; + } } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { // coopmat sizes not set yet @@ -3155,6 +3718,9 @@ static vk_device ggml_vk_get_device(size_t idx) { // Only enable if shape is identical device->coopmat_acc_f16_support = true; } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f16acc = true; + } } } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && @@ -3256,6 +3822,22 @@ static vk_device ggml_vk_get_device(size_t idx) { } } + + std::vector dsl_binding; + std::vector dsl_binding_flags; + for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) { + dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); + dsl_binding_flags.push_back({}); + } + + vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; + + vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( + {}, + dsl_binding); + descriptor_set_layout_create_info.setPNext(&dslbfci); + device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); + ggml_vk_load_shaders(device); if (!device->single_queue) { @@ -3263,7 +3845,8 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); } else { // TODO: Use pointer or reference to avoid copy - device->transfer_queue = device->compute_queue; + device->transfer_queue.copyFrom(device->compute_queue); + device->transfer_queue.cmd_pool.init(device, &device->transfer_queue); } device->buffer_type = { @@ -3276,6 +3859,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->idx = idx; + device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; + return device; } @@ -3303,6 +3888,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { bool coopmat_support = false; bool coopmat2_support = false; bool integer_dot_product = false; + bool bfloat16_support = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { @@ -3323,6 +3909,11 @@ static void ggml_vk_print_gpu_info(size_t idx) { } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { integer_dot_product = true; +#endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; #endif } } @@ -3389,10 +3980,25 @@ static void ggml_vk_print_gpu_info(size_t idx) { last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; } +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + 
last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + } +#endif + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; +#if defined(VK_KHR_shader_bfloat16) + bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; +#else + bool bf16 = false; +#endif + uint32_t default_subgroup_size = get_subgroup_size("", device_architecture); const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize; const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; @@ -3410,8 +4016,8 @@ static void ggml_vk_print_gpu_info(size_t idx) { std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", - idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size, props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { @@ -3422,6 +4028,8 @@ static void ggml_vk_print_gpu_info(size_t idx) { static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); +static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); + static void ggml_vk_instance_init() { if (vk_instance_initialized) { return; @@ -3442,7 +4050,7 @@ static void ggml_vk_instance_init() { #ifdef __APPLE__ const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); #endif - + const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr; std::vector layers; if (validation_ext) { @@ -3457,6 +4065,9 @@ static void ggml_vk_instance_init() { extensions.push_back("VK_KHR_portability_enumeration"); } #endif + if (debug_utils_ext) { + extensions.push_back("VK_EXT_debug_utils"); + } vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); #ifdef __APPLE__ if (portability_enumeration_ext) { @@ -3480,11 +4091,24 @@ static void ggml_vk_instance_init() { vk_instance.instance = vk::createInstance(instance_create_info); vk_instance_initialized = true; - size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + if (debug_utils_ext) { + vk_instance.debug_utils_support = true; + vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT"); + vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT"); + vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT"); + 
vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT"); + + } + + vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); if (devices_env != nullptr) { + size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + std::string devices(devices_env); std::replace(devices.begin(), devices.end(), ',', ' '); @@ -3500,9 +4124,9 @@ static void ggml_vk_instance_init() { } else { std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); - // Make sure at least one device exists + // If no vulkan devices are found, return early if (devices.empty()) { - std::cerr << "ggml_vulkan: Error: No devices found." << std::endl; + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); return; } @@ -3585,9 +4209,20 @@ static void ggml_vk_instance_init() { } } - // If no dedicated GPUs found, fall back to GPU 0 + // If no dedicated GPUs found, fall back to the first non-CPU device. + // If only CPU devices are available, return without devices. if (vk_instance.device_indices.empty()) { - vk_instance.device_indices.push_back(0); + for (size_t i = 0; i < devices.size(); i++) { + if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) { + vk_instance.device_indices.push_back(i); + break; + } + } + } + + if (vk_instance.device_indices.empty()) { + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); + return; } } GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); @@ -3616,6 +4251,9 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->fence = ctx->device->device.createFence({}); ctx->almost_ready_fence = ctx->device->device.createFence({}); + ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); + ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + #ifdef GGML_VULKAN_CHECK_RESULTS const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); vk_skip_checks = (skip_checks == NULL ? 
0 : atoi(skip_checks)); @@ -3647,6 +4285,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return nullptr; @@ -3656,7 +4295,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type } static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { - VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")"); if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_f32; } @@ -3684,7 +4323,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte // MMQ if (src1_type == GGML_TYPE_Q8_1) { - vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc; + vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { return nullptr; @@ -3717,6 +4356,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return nullptr; @@ -3724,9 +4364,12 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte if (ctx->device->coopmat2) { assert(src1_type == GGML_TYPE_F16); - return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc; + return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc; } - return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; + if (ctx->device->coopmat_support) { + return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; + } + return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? 
ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; } static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { @@ -3757,6 +4400,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return nullptr; @@ -3811,6 +4455,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return nullptr; @@ -3846,6 +4491,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return nullptr; @@ -3933,6 +4579,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) { return nullptr; } + std::lock_guard guard(device->mutex); device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); return buf->ptr; @@ -3943,6 +4590,8 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { return; } VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); + std::lock_guard guard(device->mutex); + vk_buffer buf; size_t index; for (size_t i = 0; i < device->pinned_memory.size(); i++) { @@ -3965,6 +4614,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { } static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { + std::lock_guard guard(device->mutex); buf = nullptr; buf_offset = 0; for (size_t i = 0; i < device->pinned_memory.size(); i++) { @@ -3978,9 +4628,9 @@ static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf } } -static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) { +static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) { vk_submission s; - s.buffer = ggml_vk_create_cmd_buffer(device, q); + s.buffer = ggml_vk_create_cmd_buffer(device, p); if (one_time) { s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); } else { @@ -3990,7 +4640,33 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo return s; } -static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array elements) { +template size_t push_constant_size(const T &t) { + static_assert(std::is_class::value, "T must be a struct/class"); + GGML_UNUSED(t); + return sizeof(T); +} +template size_t push_constant_size(const std::vector &t) { + GGML_UNUSED(t); + return sizeof(T) * t.size(); +} +template size_t push_constant_size(const std::array &t) { + GGML_UNUSED(t); + return sizeof(T) * N; +} + +template const T *push_constant_data(const T &t) { + static_assert(std::is_class::value, "T must be a struct/class"); + return &t; +} +template const T *push_constant_data(const std::vector &t) { + return t.data(); +} +template const T *push_constant_data(const std::array &t) { + return t.data(); +} + +template +static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, const T &push_constants, 
std::array elements) { const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); @@ -3999,14 +4675,15 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; } std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); - GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size()); - GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count); + GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); + GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); + GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); - vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++]; + vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); - subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants); + subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants)); subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeline->layout, @@ -4039,7 +4716,7 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { ggml_vk_ctx_end(subctx); } - subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) }); + subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) }); subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); } @@ -4240,7 +4917,9 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); } } else { - vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + std::lock_guard guard(dst->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); ggml_vk_ctx_end(subctx); @@ -4252,6 +4931,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * ggml_vk_submit(subctx, dst->device->fence); VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); } } @@ -4328,7 +5008,9 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ memcpy(dst, (uint8_t *) src->ptr + offset, size); } else { - vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + std::lock_guard guard(src->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); 
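// Illustrative sketch of the synchronous transfer pattern used by
// ggml_vk_buffer_write_2d / ggml_vk_buffer_read: record a one-time command
// buffer from the queue's pool, submit it with a fence, block on the fence,
// then release the resources. This is plain Vulkan-Hpp with hypothetical names,
// not the backend's own wrappers.
#include <cstdint>
#include <vulkan/vulkan.hpp>

void copy_buffer_blocking(vk::Device device, vk::Queue queue, vk::CommandPool pool,
                          vk::Buffer src, vk::Buffer dst, vk::DeviceSize size) {
    vk::CommandBufferAllocateInfo alloc_info(pool, vk::CommandBufferLevel::ePrimary, 1);
    vk::CommandBuffer cmd = device.allocateCommandBuffers(alloc_info)[0];

    cmd.begin(vk::CommandBufferBeginInfo(vk::CommandBufferUsageFlagBits::eOneTimeSubmit));
    cmd.copyBuffer(src, dst, vk::BufferCopy(0, 0, size));
    cmd.end();

    vk::Fence fence = device.createFence({});
    vk::SubmitInfo submit_info;
    submit_info.setCommandBuffers(cmd);
    queue.submit(submit_info, fence);

    // Block until the GPU has finished the copy, then clean up.
    (void) device.waitForFences(fence, VK_TRUE, UINT64_MAX);
    device.destroyFence(fence);
    device.freeCommandBuffers(pool, cmd);
}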
ggml_vk_ctx_end(subctx); @@ -4336,6 +5018,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ ggml_vk_submit(subctx, src->device->fence); VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); for (auto& cpy : subctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); @@ -4355,15 +5038,17 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { if (src->device == dst->device) { + std::lock_guard guard(src->device->mutex); VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); // Copy within the device - vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, src->device->fence); VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); } else { VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); // Copy device to device @@ -4388,7 +5073,8 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); - vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + std::lock_guard guard(dst->device->mutex); + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); ggml_vk_ctx_end(subctx); @@ -4396,28 +5082,40 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz ggml_vk_submit(subctx, dst->device->fence); VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); } -static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) { +static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) { VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); uint32_t split_k = 1; - if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) { + if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) { // If k is 'large' and the SMs will fill less than halfway, use split_k. 
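// Worked example of the split_k heuristic used in this function (a simplified
// transcription; the core count and matrix shape below are hypothetical):
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t guess_split_k(uint32_t m_tiles, uint32_t n_tiles, uint32_t k, uint32_t shader_core_count) {
    uint32_t split_k = 1;
    if (k >= 2048) {
        const uint32_t tiles = m_tiles * n_tiles;
        if (tiles <= shader_core_count / 2) {
            split_k = shader_core_count / tiles;     // fill the idle cores with extra k-splits
        } else if (tiles <= shader_core_count * 2 / 3) {
            split_k = 3;
        }
        split_k = std::min(split_k, 8u);             // cap the reduction overhead at 8x

        // Each split is rounded up to a multiple of 256; back off while that
        // rounding would leave the last split with no work.
        while (split_k > 1) {
            uint32_t k_split = (k + split_k - 1) / split_k;
            k_split = (k_split + 255u) & ~255u;
            if (k_split * (split_k - 1) < k) {
                break;
            }
            split_k--;
        }
    }
    return split_k;
}

int main() {
    // Hypothetical 60-core GPU, a 512x512x4096 matmul with 128x128 tiles:
    // 4*4 = 16 tiles <= 60/2, so split_k = 60/16 = 3, and the 256-alignment
    // check keeps it (ceil(4096/3) rounds up to 1536, 1536*2 = 3072 < 4096).
    printf("split_k = %u\n", guess_split_k(4, 4, 4096, 60));
    return 0;
}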
uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); - if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) { - split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); - // Clamp to 2 or 4 - split_k = std::min(split_k, 4u); - if (split_k == 3) { - split_k = 2; + + if (k >= 2048) { + if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) { + split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); + } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) { + split_k = 3; } - if (ctx->device->coopmat2) { - // coopmat2 shader expects splits to be aligned to 256 - while (split_k > 1 && ((k / split_k) % 256) != 0) { - split_k /= 2; + // Cap the split at 8x. Unless k is huge this is a lot of overhead. + split_k = std::min(split_k, 8u); + + // ggml_vk_matmul will align the splits to be a multiple of 256. + // If this rounded up size would cause the last split to be empty, + // then reduce the split count. + while (true) { + if (split_k == 1) { + break; } + uint32_t k_split = CEIL_DIV(k, split_k); + k_split = ROUNDUP_POW2(k_split, 256); + if (k_split * (split_k - 1) < k) { + break; + } + split_k--; } } } @@ -4429,9 +5127,22 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); if (ctx->device->coopmat2) { + const uint32_t shader_core_count = ctx->device->shader_core_count; + const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); + const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]); + // Use large shader when the N dimension is greater than the medium shader's tile size uint32_t crossover_large = mmp->m->wg_denoms[1]; - if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + + // Prefer large over medium if either: + // - medium or large tiles would overfill the GPU + // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not + // (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead) + bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count || + // split_k==3 with large tiles likely better than medium tiles with no split_k. + (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); + + if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size @@ -4449,6 +5160,8 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, return aligned ? mmp->a_m : mmp->m; } return aligned ? 
mmp->a_l : mmp->l; + + GGML_UNUSED(src1_type); } static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { @@ -4467,18 +5180,22 @@ static void ggml_vk_matmul( ggml_vk_sync_buffers(subctx); if (split_k == 1) { const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); return; } GGML_ASSERT(batch_stride_d == m * n); - const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n }; + // Round the split size up to a multiple of 256 (k-quant alignment) + uint32_t k_split = CEIL_DIV(k, split_k); + k_split = ROUNDUP_POW2(k_split, 256); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; // Make sure enough workgroups get assigned for split k to work - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); ggml_vk_sync_buffers(subctx); const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); } static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { @@ -4526,14 +5243,14 @@ static void ggml_vk_matmul_id( ggml_vk_sync_buffers(subctx); const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, nei0, nei1, nbi1, ne11, padded_n }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as }); } static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { return tensor->nb[0] == ggml_type_size(tensor->type) && tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; + (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]); } static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { @@ -4604,6 +5321,27 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const } } + if (src->type == to) { + // Copy two or four bytes at a time, depending on block size. + // For quantized types, we scale by block size/type size. 
But + // this path is also used for bf16->bf16 for example, where the + // type size must be exactly 2 or 4. + GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4); + if ((ggml_type_size(src->type) % 4) == 0) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_cpy_f32_f32; + } + } else { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f16; + } else { + return ctx->device->pipeline_cpy_f16_f16; + } + } + } + std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; GGML_ABORT("fatal error"); } @@ -4634,7 +5372,7 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& }; init_pushconst_fastdiv(pc); ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); } static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { @@ -4653,15 +5391,15 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); } static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { - VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; - std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; - std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" 
<< dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); - GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT const uint64_t ne00 = src0->ne[0]; @@ -4794,18 +5532,18 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } if (quantize_y) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1); } return; } @@ -4853,7 +5591,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); } if (y_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); @@ -4889,7 +5627,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << (dryrun ? 
"dryrun" : "") << "),)"); - GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT const uint64_t ne00 = src0->ne[0]; @@ -4987,12 +5725,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& // Request descriptor sets if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); return; } @@ -5069,7 +5807,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, - sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); } static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5125,7 +5863,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); return; } @@ -5157,7 +5895,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c } ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z }); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z }); } static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5174,7 +5912,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t ne00 = src0->ne[0]; const uint64_t ne01 = src0->ne[1]; const uint64_t ne02 = src0->ne[2]; - // const uint64_t ne03 = src0->ne[3]; + const uint64_t ne03 = src0->ne[3]; const uint64_t nb01 = src0->nb[1]; const uint64_t nb02 = src0->nb[2]; @@ -5186,7 +5924,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con 
const uint64_t ne12 = src1->ne[2]; // const uint64_t ne13 = src1->ne[3]; + const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t)); + const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float)); + const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float)); + GGML_ASSERT(ne11 == 1); + GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; @@ -5202,7 +5945,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con src1_uma = d_Qy != nullptr; } - const uint64_t d_ne = ne01 * ne11 * ne12; + const uint64_t d_ne = ne01 * ne11 * ne12 * ne03; const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); @@ -5214,7 +5957,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); return; } @@ -5237,10 +5980,10 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; // compute - const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 }; ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, - { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5401,12 +6144,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } return; } @@ -5456,7 +6199,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, 
(uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, - { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); } if (y_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); @@ -5490,7 +6233,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); - GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -5595,12 +6338,12 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte // Request descriptor sets if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); return; } @@ -5676,7 +6419,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, - sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z }); + pc, { groups_x, (uint32_t)nei0, groups_z }); } static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { @@ -5684,15 +6427,94 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); } else { - ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + // Split based on number of ids, to fit in shared memory + const uint32_t nei0 = 
(uint32_t)src2->ne[0]; + const uint32_t nei1 = (uint32_t)src2->ne[1]; + + GGML_ASSERT(nei0 <= 4096); + const uint32_t split_size = std::min(nei1, 4096u / nei0); + + ggml_tensor src1_copy = *src1; + ggml_tensor src2_copy = *src2; + ggml_tensor dst_copy = *dst; + + for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) { + const uint32_t n_tokens = std::min(split_size, nei1 - token_start); + + src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2]; + src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1]; + dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2]; + + src1_copy.ne[2] = n_tokens; + src2_copy.ne[1] = n_tokens; + dst_copy.ne[2] = n_tokens; + + ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun); + } } } -static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) { + // Needs to be kept up to date on shader changes + GGML_UNUSED(hsv); + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = get_fa_scalar_num_large_rows(hsv); + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); + + const uint32_t masksh = Bc * Br * sizeof(float); + + const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) { + // Needs to be kept up to date on shader changes + GGML_UNUSED(hsv); + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = coopmat1_flash_attention_num_large_rows; + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t acctype = f32acc ? 4 : 2; + const uint32_t f16vec4 = 8; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * acctype; + + const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4; + + const uint32_t sfshstride = (hsk <= 128) ? 
(Br + 8) : Br; + const uint32_t sfsh = Bc * sfshstride * acctype; + + const uint32_t kshstride = hsk / 4 + 2; + const uint32_t ksh = Bc * kshstride * f16vec4; + + const uint32_t slope = Br * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + +static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + if (sinks) { + std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3]; + } std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); GGML_TENSOR_LOCALS(int64_t, neq, q, ne) @@ -5705,13 +6527,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const uint32_t nem1 = mask ? mask->ne[1] : 0; - const uint32_t nbm1 = mask ? mask->nb[1] : 0; + const uint32_t nem2 = mask ? mask->ne[2] : 0; + const uint32_t nem3 = mask ? 
mask->ne[3] : 0; - const uint32_t D = neq0; + const uint32_t HSK = nek0; + const uint32_t HSV = nev0; uint32_t N = neq1; const uint32_t KV = nek1; - GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne0 == HSV); GGML_ASSERT(ne2 == N); // input tensor rows must be contiguous @@ -5719,12 +6543,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type)); - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(neq0 == HSK); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nev0 == D); GGML_ASSERT(nev1 == nek1); @@ -5738,7 +6559,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); - bool scalar = !ctx->device->coopmat2; + FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : + ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT1) { + const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || + (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); + + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32); + + if (!coopmat_shape_supported || !coopmat_shmem_supported) { + path = FA_SCALAR; + } + } uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; @@ -5746,12 +6579,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - // For scalar FA, we can use the "large" size to accommodate qga. - // For coopmat FA, we always use the small size (which is still pretty large for gqa). - const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false); + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. + // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). + uint32_t max_gqa; + switch (path) { + case FA_SCALAR: + case FA_COOPMAT1: + // We may switch from coopmat1 to scalar, so use the scalar limit for both + max_gqa = get_fa_scalar_num_large_rows(HSV); + break; + case FA_COOPMAT2: + max_gqa = get_fa_num_small_rows(FA_COOPMAT2); + break; + default: + GGML_ASSERT(0); + } if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && - qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 // and change addressing calculations to index Q's dimension 2. 
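The ggml_vk_flash_attn_scalar_shmem_support helper added above gates the scalar flash-attention path on whether the per-workgroup scratch (reduction buffers, mask tile, and Q tile) fits into maxComputeSharedMemorySize. A minimal standalone sketch of the same budgeting arithmetic, with placeholder tile constants (WG_SIZE, BR, BC) standing in for the values the real shaders use:

    // Sketch of the scalar-FA shared-memory budget check, with illustrative
    // tile sizes; the real Br/Bc/workgroup size come from the shader constants.
    #include <cstdint>

    static bool fa_scalar_fits_in_shmem(uint32_t hsk, uint32_t max_shmem_bytes) {
        const uint32_t WG_SIZE = 128; // assumed workgroup size
        const uint32_t BR      = 32;  // assumed query rows per tile
        const uint32_t BC      = 64;  // assumed KV columns per tile

        const uint32_t tmpsh   = WG_SIZE * sizeof(float);                // reduction scratch
        const uint32_t tmpshv4 = WG_SIZE * 4 * sizeof(float);            // vec4 reduction scratch
        const uint32_t masksh  = BC * BR * sizeof(float);                // mask tile
        const uint32_t qf      = BR * (hsk / 4 + 2) * 4 * sizeof(float); // Q tile, padded vec4 rows

        return tmpsh + tmpshv4 + masksh + qf <= max_shmem_bytes;
    }

With these placeholder numbers and HSK = 128, the Q tile is 32 * 34 * 16 = 17408 bytes and the total comes to 28160 bytes, comfortably inside a 48 KiB limit; larger head sizes are what push the scalar path over the budget and force small_rows, as the check later in this function does.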
@@ -5761,34 +6606,41 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } vk_pipeline *pipelines; - // XXX TODO other backends may be changing accumulator precision to default to f32 soon - bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32; - bool small_rows = N <= get_fa_num_small_rows(scalar); + bool small_rows = N <= get_fa_num_small_rows(path); - if (scalar) { - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; - default: - GGML_ASSERT(!"unsupported D value"); - return; - } - } else { - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break; - default: - GGML_ASSERT(!"unsupported D value"); - return; - } + // coopmat1 does not actually support "small rows" (it needs 16 rows). + // So use scalar instead. + if (small_rows && path == FA_COOPMAT1) { + path = FA_SCALAR; + } + + // scalar is faster than coopmat2 when N==1 + if (N == 1 && path == FA_COOPMAT2) { + path = FA_SCALAR; + } + + // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory + if (path == FA_SCALAR && + !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { + small_rows = true; + } + + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + + FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]); + + switch (path) { + case FA_SCALAR: + pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0]; + break; + case FA_COOPMAT1: + pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0]; + break; + case FA_COOPMAT2: + pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0]; + break; + default: + GGML_ASSERT(0); } assert(pipelines); @@ -5813,21 +6665,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { + if (workgroups_x == 1 && shader_core_count > 0) { // Try to run two workgroups per SM. 
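With the fixed KV >= 512 gate dropped just above, the split_k sizing in the lines that follow has to degrade gracefully on its own: split_k is first derived from how many workgroups it takes to give each shader core two of them, then snapped so every split covers an aligned chunk of KV, which collapses back to a single split for small KV. A rough worked example under assumed numbers (an illustrative 32-core GPU and an alignment of 256, not values from any particular device):

    // Mirrors the split_k sizing in the hunk below; CEIL_DIV/ROUNDUP_POW2 are
    // redefined locally so the snippet is self-contained.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    #define CEIL_DIV(a, b)     (((a) + (b) - 1) / (b))
    #define ROUNDUP_POW2(a, b) (((a) + (b) - 1) & ~uint32_t((b) - 1))

    int main() {
        const uint32_t shader_core_count = 32, workgroups_y = 8, workgroups_z = 1;
        const uint32_t KV = 4096, align = 256;

        uint32_t split_k  = shader_core_count * 2 / (workgroups_y * workgroups_z); // 8
        uint32_t split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), align);       // 512
        split_k           = CEIL_DIV(KV, split_kv);                                // 8

        std::printf("split_kv=%u split_k=%u\n", split_kv, split_k);
        return 0;
    }

Each of the eight splits then writes its own partial O matrix plus per-row m and L values, which the split_k reduce dispatch further down combines.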
- split_k = ctx->device->shader_core_count * 2 / workgroups_y; + split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. - split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align); + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align); split_k = CEIL_DIV(KV, split_kv); workgroups_x = split_k; } } - // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1) - // and the per-row m and L values (ne1 rows). - const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0; + // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. + const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; if (split_k_size > ctx->device->max_memory_allocation_size) { GGML_ABORT("Requested preallocation size is too large"); } @@ -5837,9 +6689,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); } return; } @@ -5861,10 +6713,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; - size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; + vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr; + size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0; - bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false; + bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false; if (ctx->device->uma) { ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); @@ -5879,6 +6731,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); M_uma = d_M != nullptr; } + if (sinks) { + ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset); + S_uma = d_S != nullptr; + } } @@ -5914,18 +6770,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } + if (!S_uma) { + d_S = d_Q; + s_buf_offset = q_buf_offset; + if (sinks) { + ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context; + d_S = s_buf_ctx->dev_buffer; + s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs; + } + } + + uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; + const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, 
(uint32_t)nev2, (uint32_t)nev3, - nem1, + nem1, nem2, nem3, q_stride, (uint32_t)nbq2, (uint32_t)nbq3, k_stride, (uint32_t)nbk2, (uint32_t)nbk3, v_stride, (uint32_t)nbv2, (uint32_t)nbv3, - nbm1, scale, max_bias, logit_softcap, - mask != nullptr, n_head_log2, m0, m1, + mask_n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; ggml_vk_sync_buffers(subctx); @@ -5937,22 +6804,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, }, // We only use split_k when group query attention is enabled, which means // there's no more than one tile of rows (i.e. workgroups_x would have been // one). We reuse workgroups_x to mean the number of splits, so we need to // cancel out the divide by wg_denoms[0]. - sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); ggml_vk_sync_buffers(subctx); - const std::array pc2 = { D, (uint32_t)ne1, split_k }; + const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, { vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, - pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 }); + pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { @@ -5960,12 +6829,41 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, - sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z }); + pc, { workgroups_x, workgroups_y, workgroups_z }); } } +static std::array ggml_vk_get_conv_elements(const ggml_tensor *dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + // src0 - kernel: [KW, KH, Cin, Cout] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + // parallelize in {OW/BS_K, OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[3]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); + int64_t NPQ = N * OW * OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + std::array elements = { static_cast(Cout), static_cast(NPQ), 1 }; + return elements; +} + static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor 
* src2, ggml_tensor * dst, ggml_op op) { switch (op) { case GGML_OP_GET_ROWS: @@ -6016,6 +6914,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const break; } return nullptr; + case GGML_OP_ADD_ID: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_add_id_f32; + } + return nullptr; case GGML_OP_CONCAT: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_concat_f32; @@ -6028,8 +6931,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_UPSCALE: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) { - return ctx->device->pipeline_upscale_f32; + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + int mode = ggml_get_op_params_i32(dst, 0); + switch (mode) { + case GGML_SCALE_MODE_NEAREST: + return ctx->device->pipeline_upscale_nearest_f32; + case GGML_SCALE_MODE_BILINEAR: + return ctx->device->pipeline_upscale_bilinear_f32; + case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS: + return ctx->device->pipeline_upscale_bilinear_ac_f32; + } } return nullptr; case GGML_OP_SCALE: @@ -6062,6 +6973,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_pad_f32; } return nullptr; + case GGML_OP_ROLL: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_roll_f32; + } + return nullptr; case GGML_OP_REPEAT: if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { return ctx->device->pipeline_repeat_f32; @@ -6076,6 +6992,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_CONT: case GGML_OP_DUP: return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); + case GGML_OP_SET_ROWS: + return ctx->device->pipeline_set_rows[dst->type]; case GGML_OP_SILU_BACK: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_silu_back_f32; @@ -6093,7 +7011,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_RMS_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_rms_norm_f32; + return ctx->num_additional_fused_ops > 0 ? 
ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; } return nullptr; case GGML_OP_RMS_NORM_BACK: @@ -6118,6 +7036,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_GELU_ERF: + return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU_QUICK: return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_RELU: @@ -6130,6 +7050,30 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const break; } return nullptr; + case GGML_OP_GLU: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_GEGLU: + return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_REGLU: + return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_SWIGLU: + return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_SWIGLU_OAI: + return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_ERF: + return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_QUICK: + return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16]; + default: + break; + } + return nullptr; case GGML_OP_DIAG_MASK_INF: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_diag_mask_inf_f32; @@ -6137,6 +7081,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; @@ -6223,6 +7168,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_timestep_embedding_f32; } return nullptr; + case GGML_OP_CONV_TRANSPOSE_1D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv_transpose_1d_f32; + } + return nullptr; case GGML_OP_POOL_2D: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_pool2d_f32; @@ -6248,6 +7198,36 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_leaky_relu_f32; } return nullptr; + case GGML_OP_CONV_2D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + auto elements = ggml_vk_get_conv_elements(dst); + vk_conv_shapes shape; + + uint32_t tiles[CONV_SHAPE_COUNT]; + for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) { + tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]); + } + + // We can't query number of shader cores on Intel, use 32 as a placeholder + // so small convolutions will still choose a smaller tile. 
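The conv2d shape selection in the lines that follow counts how many output tiles each candidate pipeline would produce and only picks a large tile when there are enough tiles to keep every shader core busy twice over. A rough sketch of that choice, with per-shape tile sizes assumed from the enum names (the real values come from each compiled pipeline's wg_denoms):

    // Sketch of the conv2d tile-shape heuristic below; the 128x128 / 32x256 /
    // 64x32 tile dimensions are assumptions based on the shape names.
    #include <cstdint>

    enum conv_shape { CONV_SHAPE_128x128, CONV_SHAPE_64x32, CONV_SHAPE_32x256 };

    static conv_shape pick_conv_shape(uint32_t Cout, uint32_t NPQ, uint32_t shader_cores) {
        auto ceil_div = [](uint32_t a, uint32_t b) { return (a + b - 1) / b; };
        const uint32_t cores = shader_cores > 0 ? shader_cores : 32;  // Intel fallback

        const uint32_t tiles_large  = ceil_div(Cout, 128) * ceil_div(NPQ, 128);
        const uint32_t tiles_narrow = ceil_div(Cout, 32)  * ceil_div(NPQ, 256);

        if (Cout > 64  && tiles_large  >= cores * 2) return CONV_SHAPE_128x128; // wide output, plenty of tiles
        if (Cout <= 32 && tiles_narrow >= cores * 2) return CONV_SHAPE_32x256;  // few channels, many positions
        return CONV_SHAPE_64x32;                                                // small or awkward shapes
    }

For example, a 16-channel output over 900 positions yields only four 32x256 tiles against a threshold of 64, so it falls back to the 64x32 shape, which is exactly why a placeholder core count of 32 is used where the real value cannot be queried.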
+ const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32; + + if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) { + shape = CONV_SHAPE_128x128; + } else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) { + shape = CONV_SHAPE_32x256; + } else { + shape = CONV_SHAPE_64x32; + } + + if (src0->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv2d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + return ctx->device->pipeline_conv2d_f16_f32[shape]; + } + } + return nullptr; case GGML_OP_CONV_2D_DW: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (ggml_is_contiguous(src1)) { @@ -6272,6 +7252,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_ADD_ID: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SQR: @@ -6284,6 +7265,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_ROPE: case GGML_OP_RMS_NORM: case GGML_OP_CONV_2D_DW: + case GGML_OP_IM2COL: + case GGML_OP_SET_ROWS: return true; default: return false; @@ -6396,7 +7379,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } if (dryrun) { - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return; } @@ -6556,6 +7539,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co uint32_t half_ceil = (dim + 1) / 2; elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1} + } break; case GGML_OP_POOL_2D: { const uint32_t N = dst->ne[3]; @@ -6564,6 +7551,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t OW = dst->ne[0]; elements = { N * OC * OH * OW, 1, 1}; } break; + case GGML_OP_CONV_2D: + { + elements = ggml_vk_get_conv_elements(dst); + } break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: @@ -6574,15 +7565,32 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: + case GGML_OP_ROLL: case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_CPY: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_UNARY: + case GGML_OP_GLU: case GGML_OP_CONV_2D_DW: { - const uint32_t ne = ggml_nelements(dst); + uint32_t ne = ggml_nelements(dst); + if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + // copy_to_quant has block size of 32, and each thread does QUANT_K elements. + // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements. + // So divide by block size here before splitting into 512x512 groups. 
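The CPY dispatch sizing just above and below converts logical element counts into work items: a quantized-to-quantized copy is a raw byte copy counted in 2- or 4-byte units, while a float-to-quantized copy runs one thread per destination block. A worked restatement using Q8_0 as the example type (32 elements and 34 bytes per block in the standard ggml layout):

    // Element-count adjustments for CPY dispatch, using Q8_0 as the example type.
    #include <cstdint>

    static uint32_t cpy_dispatch_elements(uint32_t nelements, bool src_quant, bool dst_quant) {
        const uint32_t blck_size = 32;  // Q8_0 elements per block
        const uint32_t type_size = 34;  // Q8_0 bytes per block (fp16 scale + 32 int8)
        uint32_t ne = nelements;
        if (src_quant && dst_quant) {
            ne /= blck_size;                                             // number of blocks
            ne *= (type_size % 4 == 0) ? type_size / 4 : type_size / 2;  // 17 two-byte units per block
        } else if (!src_quant && dst_quant) {
            ne = (ne + blck_size - 1) / blck_size;                       // one thread per destination block
        }
        return ne;  // 4096 elements: 2176 units (quant->quant) or 128 blocks (f32->quant)
    }

The 512x512xZ split that comes right after then spreads whichever count results across workgroups.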
+ if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + ne = CEIL_DIV(ne, ggml_blck_size(dst->type)); + } if (ne > 262144) { elements = { 512, 512, CEIL_DIV(ne, 262144) }; } else if (ne > 512) { @@ -6591,6 +7599,29 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { ne, 1, 1 }; } } break; + case GGML_OP_ADD_ID: + { + elements = { (uint32_t)ne01, (uint32_t)ne02, 1 }; + } break; + case GGML_OP_SET_ROWS: + { + uint32_t ne = ggml_nelements(src0); + if (ggml_is_quantized(dst->type)) { + // quants run 32 threads each doing QUANT_K elements + ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type)); + } else { + // scalar types do one element per thread, running 512 threads + ne = CEIL_DIV(ne, 512); + } + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + } + break; default: elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; break; @@ -6611,8 +7642,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - if (op == GGML_OP_SOFT_MAX) { - // Empty src1 is possible in soft_max, but the shader needs a buffer + if (op == GGML_OP_GLU) { + // Empty src1 is possible in glu, but the shader needs a buffer vk_subbuffer subbuf_y; if (use_src1) { subbuf_y = { d_Y, y_buf_offset, y_sz }; @@ -6621,7 +7652,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_SOFT_MAX) { + // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer + vk_subbuffer subbuf_y; + if (use_src1) { + subbuf_y = { d_Y, y_buf_offset, y_sz }; + } else { + subbuf_y = { d_X, 0, x_sz }; + } + + vk_subbuffer subbuf_z; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; + } + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { // Empty src2 is possible in rope, but the shader needs a buffer vk_subbuffer subbuf_z; @@ -6632,26 +7681,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_IM2COL) { // im2col uses only src1 and dst buffers ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, 
y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_COUNT_EQUAL) { ggml_vk_sync_buffers(subctx); // count_equal assumes that destination buffer is initialized with zeroes ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (use_src2) { ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (use_src1) { ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else { ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } } @@ -6750,6 +7799,21 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t src2_type_size = ggml_type_size(src2->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, { + (uint32_t)dst->ne[0], + (uint32_t)dst->ne[1], + (uint32_t)src0->nb[1] / src0_type_size, + (uint32_t)src0->nb[2] / src0_type_size, + (uint32_t)src1->nb[1] / src1_type_size, + (uint32_t)src2->nb[1] / src2_type_size, + }, dryrun); +} + static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) { GGML_ASSERT(version == 6 || version == 7); int num_srcs = version == 6 ? 
6 : 7; @@ -6764,7 +7828,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(pipeline != nullptr); if (dryrun) { - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return; } @@ -6820,7 +7884,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, vk_subbuffer{ d_D, dst_offset, dst_size } - }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); + }, pc, elements); } else if (version == 7) { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, @@ -6831,7 +7895,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, vk_subbuffer{ d_D, dst_offset, dst_size } - }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements); + }, pc, elements); } else { // shouldn't happen GGML_ASSERT(false); @@ -6903,7 +7967,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont GGML_ASSERT(pipeline != nullptr); if (dryrun) { - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return; } @@ -6968,7 +8032,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont vk_subbuffer{ d_GM, gm_offset, gm_size }, vk_subbuffer{ d_GV, gv_offset, gv_size }, vk_subbuffer{ d_P, p_offset, p_size }, - }, sizeof(vk_op_push_constants), &pc, elements); + }, pc, elements); } static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { @@ -7000,14 +8064,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0); - const float sf0 = (float)dst->ne[0] / src0->ne[0]; - const float sf1 = (float)dst->ne[1] / src0->ne[1]; - const float sf2 = (float)dst->ne[2] / src0->ne[2]; - const float sf3 = (float)dst->ne[3] / src0->ne[3]; + float sf0 = (float)dst->ne[0] / src0->ne[0]; + float sf1 = (float)dst->ne[1] / src0->ne[1]; + float sf2 = (float)dst->ne[2] / src0->ne[2]; + float sf3 = (float)dst->ne[3] / src0->ne[3]; + + if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); + sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + } ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { (uint32_t)ggml_nelements(dst), 0, 0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], sf0, sf1, sf2, sf3, @@ -7015,130 +8086,98 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c } static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - float * op_params = (float *)dst->op_params; - const uint32_t 
src0_type_size = ggml_type_size(src0->type); - const uint32_t dst_type_size = ggml_type_size(dst->type); + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = ggml_get_op_params_f32(dst, 0); + p.param2 = ggml_get_op_params_f32(dst, 1); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, { - (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, - 0, - op_params[0], 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun); } static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - const uint32_t src0_type_size = ggml_type_size(src0->type); - const uint32_t dst_type_size = ggml_type_size(dst->type); - - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, { - (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, - 0, - 0.0f, 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun); } static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - const uint32_t src0_type_size = ggml_type_size(src0->type); - const uint32_t dst_type_size = ggml_type_size(dst->type); - - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, { - (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, - 0, - 0.0f, 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun); } static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - const uint32_t src0_type_size = ggml_type_size(src0->type); - const uint32_t dst_type_size = ggml_type_size(dst->type); - - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, 
GGML_OP_COS, { - (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, - 0, - 0.0f, 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun); } static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - float * op_params = (float *)dst->op_params; - const uint32_t src0_type_size = ggml_type_size(src0->type); - const uint32_t dst_type_size = ggml_type_size(dst->type); + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = ggml_get_op_params_f32(dst, 0); + p.param2 = ggml_get_op_params_f32(dst, 1); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, { - (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, - 0, - op_params[0], op_params[1], - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun); } static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - const uint32_t src0_type_size = ggml_type_size(src0->type); - const uint32_t dst_type_size = ggml_type_size(dst->type); + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); +} - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, { - (uint32_t)ggml_nelements(dst), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, - 0, - 0.0f, 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, dryrun); +static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t s2 = ggml_get_op_params_i32(dst, 2); + const int32_t s3 = ggml_get_op_params_i32(dst, 3); + const uint32_t 
s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000); + const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000); + + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + memcpy(&p.param1, &s01_packed, sizeof(float)); + memcpy(&p.param2, &s23_packed, sizeof(float)); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun); } static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - const uint32_t src0_type_size = ggml_type_size(src0->type); - const uint32_t dst_type_size = ggml_type_size(dst->type); - - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, { - (uint32_t)ggml_nelements(dst), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, - 0, - 0.0f, 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, dryrun); + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun); } static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - const uint32_t src0_type_size = ggml_type_size(src0->type); - const uint32_t dst_type_size = ggml_type_size(dst->type); - - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, { - (uint32_t)ggml_nelements(dst), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, - 0, - 0.0f, 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, dryrun); + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun); } static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + uint32_t ne = (uint32_t)ggml_nelements(src0); + if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. 
+ ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun); +} + +static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { + // Skip empty skip_rows operations. For most ops the empty check at the start + // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst + // with empty srcs. + if (ggml_is_empty(src0) || ggml_is_empty(src1)) { + return; + } + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, { (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - 0.0f, 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.0f, 0.0f, 0, }, dryrun); } @@ -7163,18 +8202,18 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } -static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - float * op_params = (float *)dst->op_params; +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], 
(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + op_params[0], 0.0f, 0, }, dryrun); } @@ -7192,12 +8231,43 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const float * op_params_f = (const float *)dst->op_params; + + const bool swapped = (bool)dst->op_params[1]; + const bool split = src1 != nullptr; + const float alpha = op_params_f[2]; + const float limit = op_params_f[3]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + + if (!split) { + GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); + } else { + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + GGML_ASSERT(src0->ne[0] == dst->ne[0]); + GGML_ASSERT(src0->type == src1->type); + } + + const uint32_t mode = split ? 2 : (swapped ? 1 : 0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, + { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], + (uint32_t)dst->ne[0], + mode, + alpha, + limit + }, dryrun); +} + static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { int32_t * op_params = (int32_t *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); } -static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; float scale = op_params[0]; @@ -7207,19 +8277,29 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); const uint32_t nrows_y = (uint32_t)src0->ne[1]; - const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t ne12 = src1 ? 
(uint32_t)(src1->ne[2]) : 0u; + const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u; + const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u; + const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u; + const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u; + + const uint32_t n_head_kv = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, { ncols, src1 != nullptr ? nrows_y : (uint32_t)0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], + ne12, ne13, + nb11, nb12, nb13, scale, max_bias, m0, m1, n_head_log2, nrows_x, + src2 != nullptr }, dryrun); } @@ -7339,6 +8419,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context }, dryrun); } +static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + // src0: (K, Cout, Cin, 1) -- kernel + // src1: (L, Cin, 1, 1) -- input + // dst: (*, Cout, 1, 1) + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + const int32_t s0 = dst->op_params[0]; + + vk_op_conv_transpose_1d_push_constants p{}; + p.Cout = static_cast(ne01); + p.Cin = static_cast(ne02); + p.K = static_cast(ne00); + p.L = static_cast(ne10); + p.KL = static_cast(ne0); + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb11 = static_cast(nb11 / nb10); + p.nb1 = static_cast(nb1 / nb0); + p.s0 = static_cast(s0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun); +} + static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { uint32_t op = static_cast(dst->op_params[0]); const int32_t k1 = dst->op_params[1]; @@ -7367,6 +8478,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv2d_push_constants p{}; + p.Cout = static_cast(ne03); + p.Cin = static_cast(ne02); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[1]); + p.p0 = static_cast(dst->op_params[2]); + p.p1 = static_cast(dst->op_params[3]); + p.d0 = static_cast(dst->op_params[4]); + p.d1 = static_cast(dst->op_params[5]); + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 
= static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + GGML_ASSERT(ne03 == ne2); + GGML_ASSERT(ne02 == ne12); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); +} + static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { vk_op_conv2d_dw_push_constants p{}; p.ne = ggml_nelements(dst); @@ -7539,9 +8699,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } - ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { // Resize buffer @@ -7556,7 +8716,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_load_shaders(ctx->device); } - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -7598,7 +8758,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); - vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ggml_vk_ctx_begin(ctx->device, subctx); for (size_t i = 0; i < num_it; i++) { ggml_vk_matmul( @@ -7614,6 +8774,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_submit(subctx, ctx->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences"); ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); auto end = std::chrono::high_resolution_clock::now(); double time = std::chrono::duration_cast(end-begin).count() / 1000.0; @@ -7715,16 +8876,13 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t free(d_chk); - ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); - ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); ggml_vk_destroy_buffer(d_X); ggml_vk_destroy_buffer(d_Y); ggml_vk_destroy_buffer(d_D); - ggml_pipeline_cleanup(p); - ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce); - free(x); free(y); free(d); @@ -7802,20 +8960,20 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_vk_quantize_data(x, qx, ne, quant); ggml_vk_dequantize_data(qx, x_ref, ne, quant); - ggml_pipeline_request_descriptor_sets(ctx->device, p, 
1); + ggml_pipeline_request_descriptor_sets(ctx, p, 1); if (ctx->device->need_compiles) { ggml_vk_load_shaders(ctx->device); } - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); - vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ggml_vk_ctx_begin(ctx->device, subctx); const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; - ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); + ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1}); ggml_vk_ctx_end(subctx); auto begin = std::chrono::high_resolution_clock::now(); @@ -7823,6 +8981,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_vk_submit(subctx, ctx->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); auto end = std::chrono::high_resolution_clock::now(); @@ -7902,17 +9061,17 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ // // vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant); // -// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); +// ggml_pipeline_request_descriptor_sets(ctx, p, 1); // // if (ctx->device->need_compiles) { // ggml_vk_load_shaders(ctx->device); // } // -// ggml_pipeline_allocate_descriptor_sets(ctx->device); +// ggml_pipeline_allocate_descriptor_sets(ctx); // // ggml_vk_buffer_write(x_buf, 0, x, x_sz); // -// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); +// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); // ggml_vk_ctx_begin(ctx->device, subctx); // ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne); // ggml_vk_ctx_end(subctx); @@ -7922,6 +9081,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ // ggml_vk_submit(subctx, ctx->fence); // VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences"); // ctx->device->device.resetFences({ ctx->fence }); +// ggml_vk_queue_command_pools_cleanup(ctx->device); // // auto end = std::chrono::high_resolution_clock::now(); // @@ -8061,9 +9221,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, // y[i] = i % k; } - ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { // Resize buffer @@ -8074,19 +9234,19 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, } } if (mmq) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it); 
} if (ctx->device->need_compiles) { ggml_vk_load_shaders(ctx->device); } - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); ggml_vk_buffer_write(y_buf, 0, y, y_sz); - vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ggml_vk_ctx_begin(ctx->device, subctx); if (mmq) { for (size_t i = 0; i < num_it; i++) { @@ -8115,6 +9275,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ggml_vk_submit(subctx, ctx->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); auto end = std::chrono::high_resolution_clock::now(); @@ -8333,11 +9494,12 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); // Returns true if node has enqueued work into the queue, false otherwise // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. -static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ +static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ + ggml_tensor * node = cgraph->nodes[node_idx]; if (ggml_is_empty(node) || !node->buffer) { return false; } @@ -8362,6 +9524,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: @@ -8371,10 +9534,24 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + break; + default: + return false; + } + break; case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ACC: case GGML_OP_SUB: case GGML_OP_MUL: @@ -8387,7 +9564,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: + case GGML_OP_ROLL: case GGML_OP_CPY: + case GGML_OP_SET_ROWS: case GGML_OP_CONT: case GGML_OP_DUP: case GGML_OP_SILU_BACK: @@ -8410,7 +9589,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ -8428,7 
+9609,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod if (!dryrun) { if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ctx->compute_ctx = compute_ctx; ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { @@ -8453,6 +9634,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_CPY: + case GGML_OP_SET_ROWS: case GGML_OP_CONT: case GGML_OP_DUP: case GGML_OP_SILU_BACK: @@ -8462,6 +9644,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RMS_NORM_BACK: case GGML_OP_L2_NORM: case GGML_OP_UNARY: + case GGML_OP_GLU: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -8474,14 +9657,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: { // These operations all go through ggml_vk_op_f32, so short-circuit and // do the only thing needed for the dryrun. vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return false; } default: @@ -8521,6 +9706,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_DIV: ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_ADD_ID: + ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); + break; case GGML_OP_CONCAT: ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun); @@ -8553,12 +9742,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_PAD: ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_ROLL: + ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_SET_ROWS: + ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_SILU_BACK: ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun); @@ -8573,8 +9770,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod break; case GGML_OP_RMS_NORM: - ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); - + if (ctx->num_additional_fused_ops > 0) { + // fused rms_norm + mul + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + ggml_tensor *other_src = mul->src[0] == node ? 
mul->src[1] : mul->src[0]; + ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun); + } else { + ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun); + } break; case GGML_OP_RMS_NORM_BACK: ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); @@ -8588,6 +9791,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: @@ -8598,12 +9802,26 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun); + break; + default: + return false; + } + break; case GGML_OP_DIAG_MASK_INF: ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); break; case GGML_OP_SOFT_MAX: - ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); + ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun); break; case GGML_OP_SOFT_MAX_BACK: @@ -8645,10 +9863,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_TIMESTEP_EMBEDDING: ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_POOL_2D: ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_2D: + ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_CONV_2D_DW: ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); @@ -8668,7 +9894,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod break; case GGML_OP_FLASH_ATTN_EXT: - ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); + ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun); break; @@ -8696,7 +9922,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ctx->tensor_ctxs[node_idx] = compute_ctx; -#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF) +#if defined(GGML_VULKAN_CHECK_RESULTS) // Force context reset on each node so that each tensor ends up in its own context // and can be run and compared to its CPU equivalent separately last_node = true; @@ -8715,12 +9941,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ctx->compute_ctx.reset(); - bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready); + bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready); if (!ok) { if (node->op == GGML_OP_UNARY) { std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; - } - else { + } else if (node->op == GGML_OP_GLU) { + std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } else { std::cerr << __func__ << ": error: op not supported " << node->name << " (" << 
ggml_op_name(node->op) << ")" << std::endl; } } @@ -8729,7 +9956,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { + GGML_UNUSED(cgraph); ggml_backend_buffer * buf = nullptr; switch (tensor->op) { @@ -8739,6 +9967,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_ADD_ID: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SCALE: @@ -8747,7 +9976,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: + case GGML_OP_ROLL: case GGML_OP_CPY: + case GGML_OP_SET_ROWS: case GGML_OP_CONT: case GGML_OP_DUP: case GGML_OP_SILU_BACK: @@ -8773,7 +10004,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ -8788,6 +10021,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: @@ -8798,6 +10032,20 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + buf = tensor->buffer; + break; + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: case GGML_OP_FLASH_ATTN_EXT: @@ -8824,7 +10072,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * // Only run if ctx hasn't been submitted yet if (!subctx->seqs.empty()) { #ifdef GGML_VULKAN_CHECK_RESULTS - ggml_vk_check_results_0(tensor); + ggml_vk_check_results_0(ctx, cgraph, tensor_idx); use_fence = true; #endif @@ -8844,7 +10092,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * ggml_vk_wait_for_fence(ctx); } #ifdef GGML_VULKAN_CHECK_RESULTS - ggml_vk_check_results_1(tensor); + ggml_vk_check_results_1(ctx, cgraph, tensor_idx); #endif } @@ -8868,19 +10116,8 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { } ctx->gc.temp_buffers.clear(); - for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) { - vk_pipeline_ref plr = ctx->device->pipelines[dsr.first]; - - if (plr.expired()) { - continue; - } - - vk_pipeline pl = plr.lock(); - ggml_pipeline_cleanup(pl); - } - - ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); - ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { 
ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); @@ -8901,7 +10138,8 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->tensor_ctxs.clear(); ctx->gc.contexts.clear(); - ctx->device->pipeline_descriptor_set_requirements.clear(); + ctx->pipeline_descriptor_set_requirements = 0; + ctx->descriptor_set_idx = 0; } // Clean up on backend free @@ -8928,6 +10166,15 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->device->device.destroyFence(ctx->fence); ctx->device->device.destroyFence(ctx->almost_ready_fence); + + for (auto& pool : ctx->descriptor_pools) { + ctx->device->device.destroyDescriptorPool(pool); + } + ctx->descriptor_pools.clear(); + ctx->descriptor_sets.clear(); + + ctx->compute_cmd_pool.destroy(ctx->device->device); + ctx->transfer_cmd_pool.destroy(ctx->device->device); } static int ggml_vk_get_device_count() { @@ -9115,8 +10362,7 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_ try { ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); } catch (vk::SystemError& e) { - std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl; - std::cerr << "ggml_vulkan: " << e.what() << std::endl; + GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what()); // fallback to cpu buffer return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); } @@ -9136,6 +10382,12 @@ static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer UNUSED(buft); } +static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + return vk_instance.devices[0]->suballocation_block_size; + + UNUSED(buft); +} + // Should be changed to return device-specific host buffer type // but that probably requires changes in llama.cpp ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { @@ -9144,7 +10396,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { /* .get_name = */ ggml_backend_vk_host_buffer_type_name, /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_max_size = */ ggml_backend_vk_host_buffer_type_get_max_size, /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, }, @@ -9195,7 +10447,7 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = transfer_ctx; ggml_vk_ctx_begin(ctx->device, transfer_ctx); } else { @@ -9218,7 +10470,7 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = transfer_ctx; ggml_vk_ctx_begin(ctx->device, transfer_ctx); } else { @@ -9241,7 +10493,7 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + transfer_ctx = ggml_vk_create_context(ctx, 
ctx->transfer_cmd_pool); ctx->transfer_ctx = transfer_ctx; ggml_vk_ctx_begin(ctx->device, transfer_ctx); } else { @@ -9287,22 +10539,71 @@ static bool ggml_vk_is_empty(ggml_tensor * node) { return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; } +static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + // additional constraints specific to this fusion + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && + !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + return true; +} + static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if (vk_instance.debug_utils_support) { + vk::DebugUtilsLabelEXT dul = {}; + dul.pLabelName = "ggml_backend_vk_graph_compute"; + dul.color = std::array{1.0f, 1.0f, 1.0f, 1.0f}; + vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast(&dul)); + } + uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false); + if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } + ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) { + // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode. 
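            // src0 is the [KW, KH, Cin, Cout] kernel and the node is the [OW, OH, Cout, N] output,
            // so CRS_size = KW*KH*Cin and NPQ_size = OW*OH*N below; their product is roughly the
            // element count of the im2col matrix that an im2col->mul_mat conv would feed through
            // mul_mat, which keeps the submit-batching estimate comparable between the two paths.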
+ auto CRS_size = + cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2]; + auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3]; + total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type); } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; } if (ctx->device->need_compiles) { ggml_vk_load_shaders(ctx->device); } ggml_vk_preallocate_buffers(ctx); - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); int last_node = cgraph->n_nodes - 1; @@ -9317,6 +10618,29 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool first_node_in_batch = true; // true if next node will be first node in a batch int submit_node_idx = 0; // index to first node in a batch + vk_context compute_ctx; + if (vk_perf_logger_enabled) { + // allocate/resize the query pool + if (ctx->device->num_queries < cgraph->n_nodes + 1) { + if (ctx->device->query_pool) { + ctx->device->device.destroyQueryPool(ctx->device->query_pool); + } + vk::QueryPoolCreateInfo query_create_info; + query_create_info.queryType = vk::QueryType::eTimestamp; + query_create_info.queryCount = cgraph->n_nodes + 100; + ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info); + ctx->device->num_queries = query_create_info.queryCount; + } + + ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1); + + GGML_ASSERT(ctx->compute_ctx.expired()); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0); + } + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB // (and scaled down based on model size, so smaller models submit earlier). 
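For reference, the per-node timing path enabled by vk_perf_logger_enabled boils down to the host-side flow sketched below. This is a condensed, illustrative sketch only: the `device`, `cmd_buf`, `query_pool`, `fence`, `timestamp_period` and `n_nodes` parameters stand in for the members the patch reads off `ctx` / `ctx->device`, and the graph-recording step is elided.

```cpp
#include <cstdint>
#include <vector>
#include <vulkan/vulkan.hpp>

// Sketch of the perf-logger flow: one timestamp before the first node, one after
// every node, then the raw ticks are read back and scaled by timestampPeriod.
std::vector<uint64_t> time_graph_nodes(vk::Device device, vk::CommandBuffer cmd_buf,
                                       vk::QueryPool query_pool, vk::Fence fence,
                                       int n_nodes, float timestamp_period) {
    device.resetQueryPool(query_pool, 0, n_nodes + 1);
    cmd_buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, query_pool, 0);
    // ... record the graph here; after recording node i, write timestamp i + 1 ...

    // After submitting cmd_buf, wait for completion and read the raw ticks back.
    (void) device.waitForFences({ fence }, true, UINT64_MAX);
    std::vector<uint64_t> ticks(n_nodes + 1);
    (void) device.getQueryPoolResults(query_pool, 0, n_nodes + 1,
                                      ticks.size() * sizeof(uint64_t), ticks.data(),
                                      sizeof(uint64_t),
                                      vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);

    std::vector<uint64_t> node_ns(n_nodes);
    for (int i = 0; i < n_nodes; i++) {
        // timestampPeriod converts device ticks into nanoseconds.
        node_ns[i] = uint64_t((ticks[i + 1] - ticks[i]) * timestamp_period);
    }
    return node_ns;
}
```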
@@ -9335,14 +10659,32 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } + if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool submit = (submitted_nodes >= nodes_per_submit) || (mul_mat_bytes >= mul_mat_bytes_per_submit) || - (i == last_node) || + (i + ctx->num_additional_fused_ops == last_node) || (almost_ready && !ctx->almost_ready_fence_pending); - bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit); + bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit); + + if (vk_perf_logger_enabled) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple + for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) { + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1); + } + } if (enqueued) { ++submitted_nodes; @@ -9363,11 +10705,31 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } submit_count++; } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; } -#ifdef GGML_VULKAN_PERF - ctx->device->perf_logger->print_timings(); -#endif + if (vk_perf_logger_enabled) { + // End the command buffer and submit/wait + GGML_ASSERT(!ctx->compute_ctx.expired()); + compute_ctx = ctx->compute_ctx.lock(); + ggml_vk_ctx_end(compute_ctx); + + ggml_vk_submit(compute_ctx, ctx->device->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); + ctx->device->device.resetFences({ ctx->device->fence }); + + // Get the results and pass them to the logger + std::vector timestamps(cgraph->n_nodes + 1); + VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); + for (int i = 0; i < cgraph->n_nodes; i++) { + if (!ggml_vk_is_empty(cgraph->nodes[i])) { + ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod)); + } + } + + ctx->device->perf_logger->print_timings(); + } ggml_vk_graph_cleanup(ctx); @@ -9405,10 +10767,10 @@ ggml_backend_t ggml_backend_vk_init(size_t dev_num) { ggml_vk_init(ctx, dev_num); ggml_backend_t vk_backend = new ggml_backend { - /* .guid = */ ggml_backend_vk_guid(), - /* .interface = */ ggml_backend_vk_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), - /* .context = */ ctx, + /* .guid = */ ggml_backend_vk_guid(), + /* .iface = */ ggml_backend_vk_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), + /* .context = */ 
ctx, }; return vk_backend; @@ -9506,6 +10868,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: @@ -9519,15 +10882,33 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { ggml_type src0_type = op->src[0]->type; ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; const vk_device& device = ggml_vk_get_device(ctx->device); - if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { - // If there's not enough shared memory for row_ids and the result tile, fallback to CPU - return false; + if (op->op == GGML_OP_MUL_MAT_ID) { + if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { + // If there's not enough shared memory for row_ids and the result tile, fallback to CPU + return false; + } } switch (src0_type) { case GGML_TYPE_F32: @@ -9552,6 +10933,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return false; @@ -9585,19 +10967,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; auto device = ggml_vk_get_device(ctx->device); bool coopmat2 = device->coopmat2; - switch (op->src[0]->ne[0]) { - case 64: - case 80: - case 96: - case 112: - case 128: - case 256: - break; - default: + FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]); + if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) { return false; } - if (op->src[1]->ne[0] != op->src[2]->ne[0]) { - // different head sizes of K and V are not supported yet + if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) { return false; } if (op->src[0]->type != GGML_TYPE_F32) { @@ -9671,6 +11045,24 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + return true; + default: + return false; + } + } break; + case GGML_OP_SET_ROWS: + { + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: return true; default: @@ -9718,6 +11110,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { return true; } + + // We can handle copying from a type to the same type if it's + // contiguous (memcpy). 
We use f16 or f32 shaders to do the copy, + // so the type/block size must be a multiple of 4. + if (src0_type == src1_type && + ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) && + (ggml_type_size(src0_type) % 2) == 0) { + return true; + } return false; } break; case GGML_OP_REPEAT: @@ -9744,6 +11145,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); + case GGML_OP_ADD_ID: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 && + op->type == GGML_TYPE_F32; case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM_BACK: case GGML_OP_SQR: @@ -9752,11 +11156,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: - return op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_ACC: case GGML_OP_CONCAT: case GGML_OP_SCALE: case GGML_OP_PAD: + case GGML_OP_ROLL: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -9774,6 +11178,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_LEAKY_RELU: case GGML_OP_OPT_STEP_ADAMW: return true; + case GGML_OP_CONV_TRANSPOSE_1D: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_CONV_2D: + { + // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE; + // Channel-contiguous format is not supported yet. + return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op)) && !is_Apple; + } default: return false; } @@ -9917,11 +11337,28 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve UNUSED(instance_extensions); } +// Extension availability +static bool ggml_vk_instance_debug_utils_ext_available( + const std::vector & instance_extensions) { + // Check for portability enumeration extension for MoltenVK support + for (const auto & properties : instance_extensions) { + if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) { + return true; + } + } + + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl; + return false; + + UNUSED(instance_extensions); +} + static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { switch (props.vendorID) { case VK_VENDOR_ID_INTEL: - // Intel drivers don't support coopmat properly yet - return false; + // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost, + // while some older hardware (ex. 
Arc A770) has performance regressions + return arch == vk_device_architecture::INTEL_XE2; case VK_VENDOR_ID_AMD: if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { // Workaround for AMD proprietary driver reporting support on all GPUs @@ -10028,11 +11465,21 @@ void * comp_result; size_t comp_size; size_t comp_nb[GGML_MAX_DIMS]; size_t check_counter = 0; -static void ggml_vk_check_results_0(ggml_tensor * tensor) { - if (tensor->op == GGML_OP_TRANSPOSE) { +static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; + if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { return; } + bool fused_rms_norm_mul = false; + int rms_norm_idx = -1; + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + fused_rms_norm_mul = true; + tensor = cgraph->nodes[tensor_idx + 1]; + } + check_counter++; if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { return; @@ -10060,6 +11507,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { for (int i = 0; i < 6; i++) { ggml_tensor * srci = tensor->src[i]; + if (fused_rms_norm_mul) { + rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1; + ggml_tensor *rms_norm = tensor->src[rms_norm_idx]; + switch (i) { + case 0: srci = rms_norm->src[0]; break; + case 1: srci = tensor->src[1 - rms_norm_idx]; break; + default: continue; + } + } if (srci == nullptr) { continue; } @@ -10110,6 +11566,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); + if (src_clone[4]) { + ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]); + } } else if (tensor->op == GGML_OP_MUL_MAT) { tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL_MAT_ID) { @@ -10117,16 +11576,21 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_SUB) { tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL) { - tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + if (fused_rms_norm_mul) { + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params); + tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]); + } else { + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + } } else if (tensor->op == GGML_OP_DIV) { tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_CONCAT) { tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]); + tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; - tensor_clone 
= ggml_scale(ggml_ctx, src_clone[0], params[0]); + tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SIN) { @@ -10205,6 +11669,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { case GGML_UNARY_OP_GELU: tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_GELU_ERF: + tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]); + break; case GGML_UNARY_OP_GELU_QUICK: tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); break; @@ -10221,6 +11688,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); } + } else if (tensor->op == GGML_OP_GLU) { + if (src_clone[1] == nullptr) { + tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); + } else { + tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]); + } } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { if (src1 == nullptr) { tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); @@ -10265,6 +11738,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){ + const int32_t s0 = tensor->op_params[0]; + const int32_t p0 = tensor->op_params[1]; + const int32_t d0 = tensor->op_params[2]; + tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0); } else if (tensor->op == GGML_OP_POOL_2D) { enum ggml_op_pool op = static_cast(tensor->op_params[0]); const int32_t k0 = tensor->op_params[1]; @@ -10275,6 +11753,14 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t p1 = tensor->op_params[6]; tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); + } else if (tensor->op == GGML_OP_CONV_2D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); @@ -10294,10 +11780,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { GGML_ABORT("fatal error"); } - ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); - ggml_build_forward_expand(cgraph, tensor_clone); + ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph_cpu, tensor_clone); - ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8); + ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8); if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { ggml_vk_print_tensor(tensor_clone, "tensor_clone"); @@ -10320,10 +11806,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); } -static void ggml_vk_check_results_1(ggml_tensor * tensor) { - if (tensor->op == GGML_OP_TRANSPOSE) { +static void 
ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; + if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { return; } + bool fused_rms_norm_mul = false; + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + fused_rms_norm_mul = true; + tensor = cgraph->nodes[tensor_idx + 1]; + } + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { return; } @@ -10373,6 +11868,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { } else if (tensor->type == GGML_TYPE_F16) { correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_BF16) { + correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); + result = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); } else if (tensor->type == GGML_TYPE_I32) { correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); @@ -10412,7 +11910,8 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { ggml_vk_print_graph_origin(tensor, done); GGML_ABORT("fatal error"); } - if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) { + const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? 
std::fabs(correct) : 1e-8) : 1.0f; + if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) { first_error[0] = i0; first_error[1] = i1; first_error[2] = i2; @@ -10424,7 +11923,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { // Special case, value is infinite, avoid NaN result in avg_err // NaN also appears in results, if both are nan error is 0 if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) { - avg_err += std::fabs(correct - result); + avg_err += std::fabs(correct - result) / denom; } counter++; } @@ -10459,7 +11958,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { ggml_vk_print_graph_origin(tensor, done); } - if (avg_err > 0.05 || std::isnan(avg_err)) { + if (avg_err > 0.5 || std::isnan(avg_err)) { std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; if (src0 != nullptr) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index ad13f69b3..e1f613fb4 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -5,16 +5,25 @@ find_package (Threads REQUIRED) if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat glslc support") endif() if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 glslc support") endif() if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + message(STATUS "Enabling dot glslc support") endif() if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + message(STATUS "Enabling bfloat16 glslc support") endif() +if (GGML_VULKAN_SHADER_DEBUG_INFO) + add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + message(STATUS "Enabling shader debug info") +endif() + set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp new file mode 100644 index 000000000..3ae8f0116 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#include "types.comp" + +layout (push_constant) uniform parameter +{ + uint ne0; + uint ne1; + uint s01; + uint s02; + uint s11; + uint s21; +} p; + +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) readonly buffer Z {int32_t data_c[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i1 = gl_WorkGroupID.x; + 
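    // One workgroup per output row: i1 walks ne1 and i2 the outer dimension, while
    // data_c[i1 + i2 * s21] supplies the row index i11 that selects which row of
    // data_b gets added; the BLOCK_SIZE threads then stride over the ne0 elements.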
const uint i2 = gl_WorkGroupID.y; + + const uint i11 = data_c[i1 + i2 * p.s21]; + + const uint s1 = p.ne0; + const uint s2 = p.ne0 * p.ne1; + + const uint d0 = i1 * s1 + i2 * s2; + const uint a0 = i1 * p.s01 + i2 * p.s02; + const uint b0 = i11 * p.s11; + + for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) { + data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp new file mode 100644 index 000000000..86bafba4a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -0,0 +1,329 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#ifdef COOPMAT2 +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#ifdef USE_COLLECTIVES +# extension GL_KHR_shader_subgroup_shuffle : enable +#endif + +#include "types.comp" + +// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j +layout(binding = 0) readonly buffer A { + A_TYPE knl_data[]; +}; // src0 - kernel: [KW, KH, Cin, Cout] + +layout(binding = 1) readonly buffer B { + B_TYPE src_data[]; +}; // src1 - input: [W, H, Cin, N] -- channel_first format + +layout(binding = 2) writeonly buffer D { + D_TYPE dst_data[]; +}; // dst - result: [OW, OH, Cout, N] + +layout(push_constant) uniform parameter { + // I/O channels, batch size + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + // Tensor spatial sizes: kernel, input, output + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + // Parameters: stride, padding, dilation - 0=y, 1=x + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + // Strides in elements + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // fastdiv helper values + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; +} + +p; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +// Blocktile sizes +layout(constant_id = 1) const uint BS_K = 128; +layout(constant_id = 2) const uint BS_CRS = 16; +layout(constant_id = 3) const uint BS_NPQ = 128; +// Thread-tile sizes +layout(constant_id = 4) const uint TS_K = 8; +layout(constant_id = 5) const uint use_collectives = 1; +layout(constant_id = 6) const uint SHMEM_PAD = 4; + +uint32_t tid = gl_LocalInvocationID.x; +const uint32_t WG_SIZE = gl_WorkGroupSize.x; + +uint splitWork(uint work_size, uint block_size) { + return (block_size + work_size - 1) / block_size; +} + +uint32_t K = p.Cout; +uint32_t CRS = p.Cin * p.KH * p.KW; +uint32_t NPQ = p.N * p.OH * p.OW; + +uint32_t n_elems_out = K * NPQ; + +// Number of blocktiles per input +uint32_t NB_CRS = splitWork(CRS, BS_CRS); + +#ifdef COOPMAT2 +#define SHMEM_TYPE float16_t +#else +#define SHMEM_TYPE float +#endif + +const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; +const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; + +const uint32_t Ash_numel = BS_K * BS_CRS; +const uint32_t Bsh_numel = BS_CRS * BS_NPQ; + +const uint32_t Ash_len = BS_K * Ash_stride; +const uint32_t Bsh_len = BS_CRS * Bsh_stride; + +shared SHMEM_TYPE Ash[Ash_len]; // K x CRS +shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ + +// Threadtile sizes 
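// Each thread accumulates a TS_K x TS_NPQ register tile of the BS_K x BS_NPQ block tile,
// so TS_NPQ = BS_K * BS_NPQ / (WG_SIZE * TS_K); with the default BS_K = 128, BS_NPQ = 128,
// TS_K = 8 and an (assumed) 256-invocation workgroup that works out to TS_NPQ = 8.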
+const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; + +// Number of threadtiles per blocktile +const uint32_t NT_K = BS_K / TS_K; +const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; + +/* +Compute +KxCRS @ CRSxNPQ = K x NPQ +K=Cout +C=Cin +R,S=KH,KW +P,Q=OH,OW +*/ + +uint32_t B_idx_K = gl_WorkGroupID.x; +uint32_t B_idx_NPQ = gl_WorkGroupID.y; + +uint32_t T_y = tid / NT_NPQ; +uint32_t T_x = tid % NT_NPQ; + +uint32_t Ar = tid / BS_CRS; +uint32_t Ac = tid % BS_CRS; +const uint32_t ArpWg = WG_SIZE / BS_CRS; + +uint32_t Br = tid / BS_NPQ; +uint32_t Bc = tid % BS_NPQ; +const uint32_t BrpWg = WG_SIZE / BS_NPQ; + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +#ifdef COOPMAT2 +#define ACC_TYPE float16_t + +ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem) +{ + uint32_t K_idx = B_idx_K * BS_K + r; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (K_idx < K && NPQ_idx < NPQ) { + dst_data[dst_idx] = D_TYPE(elem); + } + return elem; +} +#endif + +void main() { +#ifdef COOPMAT2 + coopmat matC; + matC = coopmat(0.0); +#else + float regC[TS_K][TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = 0.0; + } + } +#endif + /* Advance block in CRS dim */ + for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { + uint32_t CRS_idx_a; + uint32_t Cin_idx_a; + uint32_t KH_idx_a; + uint32_t KW_idx_a; + +#ifdef USE_COLLECTIVES + uint32_t cached_CRS_idx; + uint32_t cached_Cin_idx; + uint32_t cached_KH_idx; + uint32_t cached_KW_idx; + if (use_collectives == 1) { + cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID; + cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH); + cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW; + + CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac); + Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac); + KH_idx_a = subgroupShuffle(cached_KH_idx, Ac); + KW_idx_a = subgroupShuffle(cached_KW_idx, Ac); + } else { + CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; + KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_a = CRS_remainder - KH_idx_a * p.KW; + } +#else + CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH); + CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; + KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_a = CRS_remainder - KH_idx_a * p.KW; +#endif + + /* Load kernel to A_block: (BS_K x BS_CRS)*/ + for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { + uint32_t B_ly = r_offset + Ar; + uint32_t B_lx = Ac; + uint32_t K_idx = 
B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); + float val = knl_data[knl_idx]; + if (K_idx >= K || CRS_idx_a >= CRS) { + val = 0.0; + } + Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); + } + /* Load input to B_block: (BS_CRS x BS_NPQ) */ + UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { + uint32_t B_ly = r_offset + Br; /* Row index of B block */ + uint32_t B_lx = Bc; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW; + + uint32_t CRS_idx_b; + uint32_t Cin_idx_b; + uint32_t KH_idx_b; + uint32_t KW_idx_b; +#ifdef USE_COLLECTIVES + if (use_collectives == 1) { + CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br); + Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br); + KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br); + KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); + } else { + CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; + KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_b = CRS_remainder - KH_idx_b * p.KW; + } +#else + CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; + KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_b = CRS_remainder - KH_idx_b * p.KW; +#endif + + uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1; + uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0; + uint32_t src_idx = + min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + float val = src_data[src_idx]; + if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) { + val = 0.0; + } + Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); + } + barrier(); +#ifdef COOPMAT2 + coopmat matA; + coopmat matB; + + coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + matC = coopMatMulAdd(matA, matB, matC); +#else + if (T_y * TS_K < K) { + UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { + float regA[TS_K]; + float regB[TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; + } + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; + } + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + } + } + } + } +#endif + barrier(); + } + /* Save C* */ +#ifdef COOPMAT2 + coopMatPerElementNV(matC, matC, perElemOpStore); +#else + if (T_y * TS_K < K) { + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; 
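                // The flat NPQ column index below is unflattened into (N, OH, OW) using the
                // fastdiv magic constants, then scattered to dst via the element strides nb1/nb2/nb3.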
+ uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (K_idx < K && NPQ_idx < NPQ) { + dst_data[dst_idx] = regC[T_ly][T_lx]; + } + } + } + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp new file mode 100644 index 000000000..b17b4e83e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp @@ -0,0 +1,98 @@ +#version 450 + +#include "types.comp" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin] +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin] +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout] + +layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +} p; + + +uint32_t Cout_idx = gl_WorkGroupID.x; +const uint32_t bs = gl_WorkGroupSize.x; +uint32_t tid = gl_LocalInvocationID.x; +// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K. +uint32_t tmp_len = bs*p.s0+p.K; +shared D_TYPE tmp[4096]; + +uint splitWork(uint workSize){ + return (bs + workSize -1) / bs; +} + +void main(){ + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx < tmp_len){ + tmp[idx] = 0.0; + } + } + + uint32_t L_blocks = splitWork(p.L); + for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){ + if(L_block_id > 0){ + barrier(); + // Shift values in tmp to the current processing window + for(int i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx >= bs*p.s0 && idx < tmp_len){ + tmp[idx-bs*p.s0] = tmp[idx]; + tmp[idx] = 0.0; + }else if(idx >= p.K && idx < bs*p.s0){ + tmp[idx] = 0.0; + } + } + } + barrier(); + + // Save contributions of the block to tmp + uint32_t L_idx = L_block_id*bs + tid; + for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){ + D_TYPE dp = 0.0; + for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){ + A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02]; + if(L_idx < p.L){ + B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11]; + dp = fma(elemKrn, elemInp, dp); + } + } + tmp[tid*p.s0 + K_idx] += dp; + barrier(); + } + + // Save the computed values except the last block that can have different size + uint32_t KLb_idx = L_block_id*bs*p.s0; + if(L_block_id < L_blocks-1){ + for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){ + uint32_t sh_idx = p.s0*tid+s0_idx; + uint32_t KL_idx = KLb_idx+sh_idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx]; + } + } + } + } + + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx]; + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index dbc7daa33..978d43003 100644 --- 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -4,8 +4,8 @@ #include "generic_unary_head.comp" #include "dequant_funcs.comp" -#if defined(DATA_A_IQ4_NL) -// 16 invocations needed for init_iq4nl_shmem +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) +// 16 invocations needed for init_iq_shmem layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; #else layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index 9c76437d9..27d6b7464 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,22 +1,26 @@ #version 450 -#if RTE16 -#extension GL_EXT_spirv_intrinsics : enable -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif // RTE16 - +#include "rte.comp" #include "types.comp" -#include "generic_unary_head.comp" -#if defined(DATA_A_IQ4_NL) -// 16 invocations needed for init_iq4nl_shmem -layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#if defined(SET_ROWS) && QUANT_K == 1 +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; +const uint BLOCK_SIZE = 512; #else -layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; +const uint BLOCK_SIZE = 32; #endif layout (binding = 0) readonly buffer S {float data_s[];}; + +#if defined(SET_ROWS) +#include "generic_binary_head.comp" +layout (binding = 1) readonly buffer C {uvec2 data_i[];}; +layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; +#else +#include "generic_unary_head.comp" layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; +#endif #if defined(DATA_A_Q4_0) void quantize(uint dst_idx, uint src_idx) @@ -221,15 +225,56 @@ void quantize(uint dst_idx, uint src_idx) } #endif +#if defined(DATA_A_F32) || defined(DATA_A_F16) +void quantize(uint dst_idx, uint src_idx) +{ + data_q[dst_idx] = A_TYPE(data_s[src_idx]); +} +#endif + +#if defined(DATA_A_BF16) +void quantize(uint dst_idx, uint src_idx) +{ + data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx])); +} +#endif + +#if defined(SET_ROWS) + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); - if (gl_LocalInvocationIndex.x != 0) { - return; - } #endif - const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + uint i12 = fastmod(i03, p.ne12); + uint i11 = fastmod(i02, p.ne11); + uint i10 = i01; + + uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x; + + uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset(); + uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset(); + + quantize(dst_idx, src0_idx); +} + +#else + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K; if (idx >= p.ne) { return; @@ -240,3 +285,5 @@ void main() { quantize(dst_idx, 
src_idx); } + +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 0d9739d40..d3127fbd9 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -434,6 +434,18 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_MXFP4) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + vec2 v0 = dequantize(ib, iqs, a_offset); + vec2 v1 = dequantize(ib, iqs + 1, a_offset); + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); @@ -455,6 +467,12 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_MXFP4) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0); +} +#endif + #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 9cb7da2da..706540fd8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor } #endif +#if defined(DATA_A_MXFP4) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 { + block_mxfp4 block; +}; + +float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float d = e8m0_to_fp32(bl.block.e); + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(kvalues_mxfp4[qs] * d); + return ret; +} +#endif + #if defined(DATA_A_Q4_0) #define dequantFuncA dequantFuncQ4_0 #elif defined(DATA_A_Q4_1) @@ -696,4 +715,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #define dequantFuncA dequantFuncIQ4_XS #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL +#elif defined(DATA_A_MXFP4) +#define dequantFuncA dequantFuncMXFP4 #endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp index 39184ef58..b604c1881 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -1,6 +1,6 @@ #version 450 -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #include "dequant_head.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp new file mode 100644 index 000000000..ee496e9d5 --- /dev/null +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = e8m0_to_fp32(data_a[ib].e); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp index 157154af3..d4e4e6bae 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp index c17dd0d99..3661f771c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = uint(gl_WorkGroupID.x * 256 + wgy); - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp index 987f113a3..1370db365 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint ib = gl_WorkGroupID.x * 256 + wgy; - if (ib >= p.M * p.K / QUANT_K) { + if (ib >= p.nel / QUANT_K) { return; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp index 6db5403b6..3f3b839e1 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint ib = gl_WorkGroupID.x * 256 + wgy; - if (ib >= p.M * p.K / QUANT_K) { + if (ib >= p.nel / QUANT_K) { return; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp index 
0b9131755..9cf34256e 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } const uint tid = gl_LocalInvocationID.x; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e6545160d..d40848e15 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -9,59 +9,14 @@ #extension GL_KHR_shader_subgroup_shuffle : enable #include "types.comp" +#include "flash_attn_base.comp" -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +const uint32_t HSK_per_thread = HSK / D_split; +const uint32_t HSV_per_thread = HSV / D_split; -layout (constant_id = 1) const uint32_t Br = 1; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; - -layout (constant_id = 5) const uint32_t D_split = 16; -const uint32_t D_per_thread = D / D_split; - -const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split; +const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_thread = Bc / cols_per_iter; -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; @@ -70,147 +25,47 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#if defined(A_TYPE_PACKED16) -#define BINDING_IDX_K 0 -#define BINDING_IDX_V 1 -layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; -#endif - -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 - -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); -} -#endif - -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = 
unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); -} -#endif - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { - uint32_t offset = (iq2 + r) * D + c; + uint32_t offset = (iq2 + r) * HSV + c; data_o[o_offset + offset] = D_TYPE(elem); return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. -ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r % p.gqa_ratio); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - -shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; -shared vec4 tmpshv4[gl_WorkGroupSize.x]; +shared FLOAT_TYPE tmpsh[WorkGroupSize]; +shared vec4 tmpshv4[WorkGroupSize]; shared float masksh[Bc][Br]; -shared vec4 Qf[Br][D / 4]; +shared vec4 Qf[Br][HSK / 4]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif - const uint32_t tid = gl_LocalInvocationIndex; - const uint32_t N = p.N; - const uint32_t KV = p.KV; + init_indices(); + const uint32_t tid = gl_LocalInvocationIndex; const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = gl_LocalInvocationIndex / D_split; - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. - // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. - const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; - - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? 
(p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). - // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; - [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (D / 4); - uint32_t r = (idx + tid) / (D / 4); - if (r < Br && d < D / 4 && + [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t r = (idx + tid) / (HSK / 4); + if (r < Br && d < HSK / 4 && i * Br + r < N) { Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; } } barrier(); - vec4 Of[Br][D_per_thread / 4]; - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + vec4 Of[Br][HSV_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] = vec4(0.0); } @@ -245,6 +100,10 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -258,7 +117,7 @@ void main() { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; @@ -290,13 +149,13 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); } } barrier(); @@ -337,14 +196,14 @@ void main() { Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] = eMf[r] * Of[r][d]; } } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; @@ -401,7 +260,7 @@ void main() { Lf[r] = tmpsh[d_tid]; barrier(); - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { Of[r][d] = eMf * Of[r][d]; tmpshv4[tid] = Of[r][d]; @@ -423,11 +282,11 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. 
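The rescaling the resolve pass has to do can be written out directly: each split stores an unnormalized output O_k together with its running max M_k and exp-sum L_k, and the resolve combines them log-sum-exp style before the single final division. A minimal C sketch under those assumptions (sinks omitted, names illustrative):

    #include <math.h>

    /* Combine per-split partials (O_k, L_k, M_k) for one output row.
     * Sinks are omitted here; the resolve shader folds them into L as
     * one extra term. Names and layout are illustrative only.        */
    static void splitk_resolve_row(int k_num, int D,
                                   const float *O,  /* [k_num][D] unnormalized */
                                   const float *L,  /* [k_num] exp-sums        */
                                   const float *M,  /* [k_num] running maxima  */
                                   float *out       /* [D] final output        */) {
        float m_max = -INFINITY;
        for (int k = 0; k < k_num; k++) m_max = fmaxf(m_max, M[k]);

        float l = 0.0f;
        for (int k = 0; k < k_num; k++) l += expf(M[k] - m_max) * L[k];

        for (int d = 0; d < D; d++) {
            float o = 0.0f;
            for (int k = 0; k < k_num; k++) o += expf(M[k] - m_max) * O[k * D + d];
            out[d] = o / l;
        }
    }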
if (p.k_num > 1) { - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); } @@ -435,7 +294,7 @@ void main() { } } - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -446,23 +305,44 @@ void main() { return; } + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > Mf[r]) { + ms = exp(Mf[r] - sink); + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] *= ms; + } + } else { + vs = exp(sink - Mf[r]); + } + + Lf[r] = Lf[r]*ms + vs; + } + } + float Lfrcp[Br]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Lfrcp[r] = 1.0 / Lf[r]; } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] *= Lfrcp[r]; } } - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); } @@ -472,9 +352,9 @@ void main() { } else { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (i * Br + r < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp new file mode 100644 index 000000000..b57c9dcfc --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -0,0 +1,178 @@ + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t HSK = 32; +layout (constant_id = 4) const uint32_t HSV = 32; +layout (constant_id = 5) const uint32_t Clamp = 0; +layout (constant_id = 6) const uint32_t D_split = 16; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + uint32_t nem2; + uint32_t nem3; + 
+ uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask_n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +#define SINK_ENABLE_BIT (1<<24) +#define MASK_ENABLE_BIT (1<<16) +#define N_LOG2_MASK 0xFFFF + +layout (binding = 4) readonly buffer S {float data_s[];}; + +layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK; + + const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +// Load the sink value, indexed by Q's dimension 2. +ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + return ACC_TYPE(data_s[h]); +} + +uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, + iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + q_stride, k_stride, v_stride, m_stride; + +void init_indices() +{ + N = p.N; + KV = p.KV; + + i = gl_WorkGroupID.x; + split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + Tr = CEIL_DIV(N, Br); + + start_j = split_k_index * p.split_kv / Bc; + end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. 
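The mask_n_head_log2 push constant above packs three things into one word: the low 16 bits carry n_head_log2 for the ALiBi slopes, bit 16 enables the mask, and bit 24 enables attention sinks. A hypothetical host-side packing helper could look like the sketch below; how ggml-vulkan.cpp actually fills this field is not shown in this patch.

    #include <stdint.h>
    #include <stdbool.h>

    #define MASK_ENABLE_BIT (1u << 16)   /* mirrors the shader define */
    #define SINK_ENABLE_BIT (1u << 24)   /* mirrors the shader define */
    #define N_LOG2_MASK     0xFFFFu

    /* Hypothetical packing helper; the real host code may differ. */
    static uint32_t pack_mask_n_head_log2(uint32_t n_head_log2, bool has_mask, bool has_sinks) {
        uint32_t v = n_head_log2 & N_LOG2_MASK;
        if (has_mask)  v |= MASK_ENABLE_BIT;
        if (has_sinks) v |= SINK_ENABLE_BIT;
        return v;
    }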
+ // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + iq2 = gl_WorkGroupID.y * p.gqa_ratio; + iq3 = gl_WorkGroupID.z; + + // broadcast factors + rk2 = p.neq2/p.nek2; + rk3 = p.neq3/p.nek3; + + rv2 = p.neq2/p.nev2; + rv3 = p.neq3/p.nev3; + + // k indices + ik3 = iq3 / rk3; + ik2 = iq2 / rk2; + + // v indices + iv3 = iq3 / rv3; + iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + k_stride = p.nb11; + v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp new file mode 100644 index 000000000..230e815f2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -0,0 +1,387 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable + +#include "types.comp" +#include "flash_attn_base.comp" + +const uint32_t HSK_per_thread = HSK / D_split; +const uint32_t HSV_per_thread = HSV / D_split; + +const uint32_t row_split = 4; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * HSV + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; + +const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4 +shared f16vec4 Qf[Br * qstride]; + +// Avoid padding for hsk==256 to make it fit in 48KB shmem. +const uint32_t sfshstride = (HSK <= 128) ? 
(Br + 8) : Br; +shared ACC_TYPE sfsh[Bc * sfshstride]; + +const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4 +shared f16vec4 ksh[Bc * kshstride]; + +shared float slope[Br]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + const uint32_t tid = gl_LocalInvocationIndex; + + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + +#define tile_row(r) (row_tid * rows_per_thread + (r)) + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t r = (idx + tid) / (HSK / 4); + if (r < Br && d < HSK / 4 && + i * Br + r < N) { + Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + } + } + barrier(); + + ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = ACC_TYPEV4(0.0); + } + } + + float Lf[rows_per_thread], Mf[rows_per_thread]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + // ALiBi + if (p.max_bias > 0.0f) { + if (tid < Br) { + uint r = tid; + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + barrier(); + } else { + if (tid < Br) { + uint r = tid; + slope[r] = 1.0; + } + barrier(); + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (c < Bc && d < HSK / 4) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + + ksh[c * kshstride + d] = K_Tf; + } + } + barrier(); + + // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br + // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 + // This is written transposed in order to allow for N being 8 if implementations need it + coopmat SfMat = coopmat(0); + coopmat KMat; + coopmat QMat; + + for (uint32_t d = 0; d < HSK / 16; ++d) { + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + + SfMat = 
coopMatMulAdd(KMat, QMat, SfMat); + } + + uint coord = gl_SubgroupID * MatBc * sfshstride; + coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); + barrier(); + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / Br; + uint32_t r = (idx + tid) % Br; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + } + } + barrier(); + } + + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); + } + } + barrier(); + } + + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + } + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + } + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + float Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf); + } + } + } + + barrier(); + } + + // reduce across threads + + float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE M = Mf[r]; + tmpsh[tid] = M; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + M = max(M, tmpsh[tid ^ s]); + barrier(); + tmpsh[tid] = M; + barrier(); + } + rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + eMf[r] = exp(Moldf[r] - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE L = Lf[r]; + tmpsh[tid] = L; + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; 
s >= D_split; s >>= 1) { + L += tmpsh[tid ^ s]; + barrier(); + tmpsh[tid] = L; + barrier(); + } + Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + Of[r][d] += tmpshv4[tid ^ s]; + barrier(); + tmpshv4[tid] = Of[r][d]; + barrier(); + } + Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > Mf[r]) { + ms = exp(Mf[r] - sink); + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] *= ACC_TYPE(ms); + } + } else { + vs = exp(sink - Mf[r]); + } + + Lf[r] = Lf[r]*ms + vs; + } + } + + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= float16_t(Lfrcp[r]); + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (i * Br + tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index b926a578a..b0564ca0b 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -18,62 +18,12 @@ #include "types.comp" #include "dequant_funcs_cm2.comp" - -layout(local_size_x_id = 0, 
local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 1) const uint32_t Br = 32; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; -layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; - -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; +#include "flash_attn_base.comp" layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; layout (binding = 2) readonly buffer V {uint8_t data_v[];}; layout (binding = 3) readonly buffer M {uint8_t data_m[];}; -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { return max(x, y); @@ -111,74 +61,19 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { - if (r < N && c < D) { - uint32_t offset = (iq2 + r) * D + c; + if (r < N && c < HSV) { + uint32_t offset = (iq2 + r) * HSV + c; data_o[o_offset + offset] = D_TYPE(elem); } return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. -ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r % p.gqa_ratio); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif - const uint32_t N = p.N; - const uint32_t KV = p.KV; - - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. - // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. 
- const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; + init_indices(); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); @@ -191,21 +86,10 @@ void main() { tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); #endif - tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); - tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); - tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); + tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK); + tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK); + tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV); - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). - // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? 
(p.gqa_ratio >> 16) : KV; // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { @@ -220,16 +104,16 @@ void main() { tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); - coopmat Q; - coopmat Qf16; + coopmat Q; + coopmat Qf16; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; - coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK)); - Qf16 = coopmat(Q); + Qf16 = coopmat(Q); Qf16 *= float16_t(p.scale); - coopmat O = coopmat(0); + coopmat O = coopmat(0); coopmat L, M; @@ -246,15 +130,20 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + } + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { coopmat S = coopmat(0); - coopmat K_T; + coopmat K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC); S = coopMatMulAdd(Qf16, K_T, S); if (p.logit_softcap != 0.0f) { @@ -264,14 +153,14 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); coopmat mv; - coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); S += slopeMat*coopmat(mv); } @@ -319,46 +208,74 @@ void main() { rowsum = coopmat(0.0); rowsum = coopMatMulAdd(P_A, One, rowsum); - coopmat V; + coopmat V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC); L = eM*L + rowsum; // This is the "diagonal" matrix in the paper, but since we do componentwise // multiply rather than matrix multiply it has the diagonal element smeared // across the row - coopmat eMdiag; + coopmat eMdiag; // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); // multiply with fp16 accumulation, then add to O. - coopmat PV = coopmat(0); + coopmat PV = coopmat(0); PV = coopMatMulAdd(P_A, V, PV); - O = eMdiag * O + coopmat(PV); + O = eMdiag * O + coopmat(PV); } // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. 
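The SINK_ENABLE_BIT paths added to these flash attention shaders fold a per-head sink logit into the softmax normalizer: the sink acts like one extra attention score with no value contribution, so only the running max M and the exp-sum L (and, when the max moves, the accumulated output) need adjusting. A scalar C sketch of that update, mirroring the ms/vs logic in the shader variants shown above:

    #include <math.h>

    /* Fold a per-head sink logit into the accumulated (M, L, O) state.
     * The sink contributes to the normalizer only; if it becomes the
     * new maximum, the existing terms are rescaled by exp(M - sink).  */
    static void apply_sink(float sink, float *M, float *L, float *O, int D) {
        float ms = 1.0f;   /* rescale for the existing terms       */
        float vs = 1.0f;   /* contribution of the sink itself      */
        if (sink > *M) {
            ms = expf(*M - sink);
            for (int d = 0; d < D; d++) O[d] *= ms;
            *M = sink;
        } else {
            vs = expf(sink - *M);
        }
        *L = *L * ms + vs;
    }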
if (p.k_num > 1) { - coopmat O_D = coopmat(O); + coopmat O_D = coopmat(O); - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); return; } - coopmat Ldiag; + coopmat Ldiag; // resize L by using smear/reduce coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + coopmat S; + coopMatPerElementNV(S, S, perElemOpGetSink, iq2); + + coopmat Mr; + + // resize M by using smear/reduce + coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce); + + // O, Ldiag, Mr all have the same type so all element locations match + [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) { + ACC_TYPE sink = S[i]; + + ACC_TYPE ms = ACC_TYPE(1.0f); + ACC_TYPE vs = ACC_TYPE(1.0f); + + if (sink > Mr[i]) { + ms = exp(Mr[i] - sink); + + O[i] *= ms; + } else { + vs = exp(sink - Mr[i]); + } + + Ldiag[i] = Ldiag[i]*ms + vs; + } + } + [[unroll]] for (int k = 0; k < Ldiag.length(); ++k) { Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; @@ -366,18 +283,18 @@ void main() { O = Ldiag*O; - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; - coopmat O_D = coopmat(O); + coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); } else { tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV); // permute dimensions tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); - coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute); + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute); } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index a7e395685..76ef4b6df 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -2,58 +2,115 @@ #extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 32 +layout(constant_id = 0) const uint BLOCK_SIZE = 32; -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {float data_a[];}; -layout (binding = 1) writeonly buffer D {float data_d[];}; +layout (binding = 1) readonly buffer B {float data_s[];}; +layout (binding = 2) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; uint N; + uint ne3; uint k_num; + uint sinks; } p; +shared float tmpsh[BLOCK_SIZE]; + void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + 
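The m_max and L loops below no longer assume every thread walks all k_num splits: each thread accumulates a strided subset, and the workgroup then combines the per-thread partials through shared memory, halving the active stride each step. The same pattern, written out sequentially in C for clarity (each halving step corresponds to one barrier on the GPU):

    #include <math.h>

    /* Sequential model of the shared-memory reduction used just below:
     * tmpsh[] holds one partial per thread, and each step folds the top
     * half of the active range into the bottom half.                   */
    static float workgroup_max(float *tmpsh, int block_size) {
        for (int s = block_size / 2; s > 0; s >>= 1) {
            for (int tid = 0; tid < s; tid++) {
                tmpsh[tid] = fmaxf(tmpsh[tid], tmpsh[tid + s]);
            }
            /* barrier() between steps on the GPU */
        }
        return tmpsh[0];   /* every thread then reads the reduced value */
    }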
const uint iq3 = gl_WorkGroupID.z; uint D = p.D; uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * k_num + n; - uint m_offset = D * N * k_num + N + n; + uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; + uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; uint lm_stride = N * 2; // Compute the max m value for the row float m_max = -1.0/0.0; - [[unroll]] for (uint k = 0; k < k_num; ++k) { - float m = data_a[m_offset + k * lm_stride]; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float m = data_a[m_offset + (k + tid) * lm_stride]; m_max = max(m_max, m); } + // reduce across the workgroup + tmpsh[tid] = m_max; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + m_max = max(m_max, tmpsh[tid + s]); + tmpsh[tid] = m_max; + } + barrier(); + } + m_max = tmpsh[0]; + + barrier(); + // Compute L based on m_max float L = 0; - [[unroll]] for (uint k = 0; k < k_num; ++k) { - float l = data_a[l_offset + k * lm_stride]; - float m = data_a[m_offset + k * lm_stride]; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float l = data_a[l_offset + (k + tid) * lm_stride]; + float m = data_a[m_offset + (k + tid) * lm_stride]; L += exp(m - m_max) * l; } + // reduce across the workgroup + tmpsh[tid] = L; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + L += tmpsh[tid + s]; + tmpsh[tid] = L; + } + barrier(); + } + L = tmpsh[0]; + + float sink; + if (p.sinks != 0) { + sink = data_s[n]; + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > m_max) { + ms = exp(m_max - sink); + } else { + vs = exp(sink - m_max); + } + + L = L*ms + vs; + } + L = 1.0 / L; + // D dimension is split across workgroups in the y dimension + uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE; // Scale and sum the O contributions based on m_max and store the result to memory - for (uint d = tid; d < D; d += BLOCK_SIZE) { + if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * k + D * n + d; + uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } + if (p.sinks != 0) { + if (sink > m_max) { + float ms = 1.0f; + ms = exp(m_max - sink); + O *= ms; + } + } O *= L; - data_d[D * n + d] = O; + data_d[iq3 * D * N + D * n + d] = O; } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp new file mode 100644 index 000000000..f4268ed24 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp @@ -0,0 +1,13 @@ +#version 450 + +#include "glu_head.comp" + +const float GELU_COEF_A = 0.044715f; +const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +float op(float a, float b) { + const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a); + return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b; +} + +#include "glu_main.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp new file mode 100644 index 000000000..cbd4cb36b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "glu_head.comp" + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +const float p_erf = 0.3275911f; +const float a1_erf = 0.254829592f; 
+const float a2_erf = -0.284496736f; +const float a3_erf = 1.421413741f; +const float a4_erf = -1.453152027f; +const float a5_erf = 1.061405429f; + +const float SQRT_2_INV = 0.70710678118654752440084436210484f; + +float op(float a, float b) { + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + return 0.5f * a * (1.0f + erf_approx) * b; +} + +#include "glu_main.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp new file mode 100644 index 000000000..3a2a6897b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp @@ -0,0 +1,11 @@ +#version 450 + +#include "glu_head.comp" + +const float GELU_QUICK_COEF = -1.702f; + +float op(float a, float b) { + return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b; +} + +#include "glu_main.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp new file mode 100644 index 000000000..5fd5a5e70 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation + // ref: https://www.johndcook.com/blog/python_erf/ + const float p_erf = 0.3275911f; + const float a1_erf = 0.254829592f; + const float a2_erf = -0.284496736f; + const float a3_erf = 1.421413741f; + const float a4_erf = -1.453152027f; + const float a5_erf = 1.061405429f; + + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float a = float(data_a[i]); + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp index 062e2a4cd..4b4316cf3 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp @@ -1,6 +1,8 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require +#include "rte.comp" + layout (push_constant) uniform parameter { uint ne; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp new file mode 100644 index 000000000..51d70869d --- /dev/null +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp @@ -0,0 +1,19 @@ +#extension GL_EXT_shader_16bit_storage : require + +#include "rte.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {A_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter +{ + uint N; + uint ne00; + uint ne20; + uint mode; + float alpha; + float limit; +} p; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp new file mode 100644 index 000000000..85cf65a9e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp @@ -0,0 +1,29 @@ +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.N) { + return; + } + + const uint row = i / p.ne20; + const uint col = i - row * p.ne20; + + if (p.mode == 0) { + // Default + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + } else if (p.mode == 1) { + // Swapped + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + } else { + // Split + const uint idx = row * p.ne00 + col; + + data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 09aa849e8..fdbcf7eba 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -1,12 +1,9 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_spirv_intrinsics: enable #extension GL_EXT_control_flow_attributes : require -#if RTE16 -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif +#include "rte.comp" layout (push_constant) uniform parameter { @@ -43,12 +40,10 @@ void main() { const uint src_base = ic * p.offset_delta + batch * p.batch_offset; const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH); const int oh_s1 = int(oh) * p.s1; - const uint ksize = p.OW * (p.KH > 1 ? 
p.KW : 1); + const uint ksize = p.OW * p.KH; const uint base_linear_idx = gidx * NUM_ITER; - const uint max_ky = ksize / p.OW; - uint current_kx = base_linear_idx / ksize; const uint rem = base_linear_idx - (current_kx * ksize); uint current_ky = rem / p.OW; @@ -79,7 +74,7 @@ void main() { if (++current_ix == p.OW) { current_ix = 0; - if (++current_ky == max_ky) { + if (++current_ky == p.KH) { current_ky = 0; current_kx++; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp index bc633369f..638878d94 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -26,6 +26,9 @@ layout (push_constant) uniform parameter uint ne12; uint b_offset; uint d_offset; + uint nb03; + uint nb13; + uint nb23; } p; shared FLOAT_TYPE tmp[BLOCK_SIZE]; @@ -34,6 +37,7 @@ void main() { const uint tid = gl_LocalInvocationID.x; const uint row_x = gl_GlobalInvocationID.y; const uint channel = gl_GlobalInvocationID.z; + const uint i3 = gl_WorkGroupID.x; const uint channel_x = channel / p.channel_x_divisor; const uint channel_y = channel % p.ne12; @@ -41,7 +45,7 @@ void main() { const uint nrows_dst = p.nrows_x; const uint row_dst = row_x; - const uint idst = channel*nrows_dst + row_dst; + const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst; FLOAT_TYPE temp = 0.0f; @@ -58,8 +62,8 @@ void main() { const uint row_y = col_x; - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel_y*p.channel_stride_y + row_y; + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; const vec4 av4 = vec4(data_a_v4[ix / 4]); const vec4 bv4 = vec4(data_b_v4[iy / 4]); @@ -74,8 +78,8 @@ void main() { const uint row_y = col_x; - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel_y*p.channel_stride_y + row_y; + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; const vec4 av4 = vec4(data_a_v4[ix / 4]); const vec4 bv4 = vec4(data_b_v4[iy / 4]); @@ -91,8 +95,8 @@ void main() { const uint row_y = col_x; - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel_y*p.channel_stride_y + row_y; + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 7859a1a60..8c5114a79 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -7,7 +7,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif #if defined(DATA_A_IQ1_M) -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #endif #if defined(DATA_A_BF16) && defined(COOPMAT) @@ -18,6 +18,7 @@ #extension GL_KHR_cooperative_matrix : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_shader_subgroup_basic : enable +#extension 
GL_KHR_shader_subgroup_ballot : enable #endif #ifdef MUL_MAT_ID @@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID shared u16vec2 row_ids[4096]; +uint _ne1; +#ifdef COOPMAT +shared uint _ne1_sh; +#endif #endif // MUL_MAT_ID #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -172,7 +177,47 @@ void main() { const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; #ifdef MUL_MAT_ID - uint _ne1 = 0; +#ifdef COOPMAT + // Spread the search across all elements in the first subgroup + if (gl_SubgroupID == 0) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_SubgroupInvocationID; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + uint idx = subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx) { + row_ids[_ne1 + idx] = u16vec2(ii0, ii1); + } + _ne1 += subgroupBallotBitCount(ballot); + iter &= 15; + } + _ne1_sh = _ne1; + } + + barrier(); + + _ne1 = _ne1_sh; +#else + _ne1 = 0; for (uint ii1 = 0; ii1 < p.nei1; ii1++) { for (uint ii0 = 0; ii0 < p.nei0; ii0++) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { @@ -183,6 +228,7 @@ void main() { } barrier(); +#endif // Workgroup has no work if (ic * BN >= _ne1) return; @@ -500,10 +546,9 @@ void main() { const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx % 128) / 4; - const int i8 = 2 * int(idx % 4); + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; const float d = float(data_a[ib].d); const uint qh = data_a[ib].qh[ib32]; @@ -512,22 +557,16 @@ void main() { const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; const uint ib16 = ib8 / 2; - const int i8 = 2 * int(idx % 4); const uint16_t[4] scales = data_a[ib].scales; const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; @@ -538,21 +577,17 @@ void main() { const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[8 * ib32 + ib8]; @@ -562,63 +597,81 @@ void main() { data_a[ib].qs[8*ib32 + 6], data_a[ib].qs[8*ib32 + 7] )); - const float db = d * 0.25 * (0.5 + (signs >> 28)); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? 
-grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; // 0..3 + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 const float d = float(data_a[ib].d); const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const float db = d * 0.25 * (0.5 + scale); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); const uint qs = data_a[ib].qs[4 * ib32 + ib8]; const uint sign7 = qs >> 9; - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; const uint qs = data_a[ib].qs[ib8]; const uint qh = data_a[ib].qh[ib32]; const uint qhshift = 2 * (ib8 % 4); - const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; const float d = float(data_a[ib].d); - const float db = d * 0.25 * (0.5 + scale); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? 
-grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values const float d = float(data_a[ib].d); @@ -631,33 +684,36 @@ void main() { )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint iqh = iqs / 8; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[iqs]; const uint qh = data_a[ib].qh[iqh]; - const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); const uint scale = data_a[ib].scales[iqs / 16]; const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? 
-v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -691,6 +747,21 @@ void main() { buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; +#elif defined(DATA_A_MXFP4) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; + + const uint ib = idx / 8; + const uint iqs = (idx & 0x07) * 2; + + const float d = e8m0_to_fp32(data_a[ib].e); + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d); + buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d); + buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d); #endif } [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 918465757..29e4b5c9c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -162,17 +162,32 @@ void main() { _ne1 = 0; uint num_elements = p.nei1 * p.nei0; - for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_SubgroupInvocationID; bool in_range = i < num_elements; - uint ii0 = i % p.nei0; uint ii1 = i / p.nei0; - uint id = in_range ? 
data_ids[ii1*p.nbi1 + ii0] : 0; + uint ii0 = i % p.nei0; + uint id = ids[iter++]; uvec4 ballot = subgroupBallot(in_range && id == expert_idx); uint idx = subgroupBallotExclusiveBitCount(ballot); if (in_range && id == expert_idx) { row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); } _ne1 += subgroupBallotBitCount(ballot); + iter &= 15; } _ne1_sh = _ne1; } @@ -414,17 +429,31 @@ void main() { fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); } - coopmat mat_a; - coopmat mat_b; + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); #else - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif - sum = coopMatMulAdd(mat_a, mat_b, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } } // Convert from ACC_TYPE to D_TYPE diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp index 63b15471b..34e8db977 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp @@ -92,6 +92,12 @@ FLOAT_TYPE get_d(uint ib) { } #endif +#if defined(DATA_A_MXFP4) +FLOAT_TYPE get_d(uint ib) { + return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); +} +#endif + #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) FLOAT_TYPE_VEC2 get_dm(uint ib) { return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp new file mode 100644 index 000000000..0073d8f76 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.comp" + +float op(float a, float b) { + return max(a, 0.0f) * b; +} + +#include "glu_main.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index deb8ee996..bdd7db2d6 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -1,11 +1,13 @@ #version 450 -#include "generic_unary_head.comp" +#include 
"generic_binary_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 +layout (constant_id = 1) const bool do_multiply = false; + layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; shared FLOAT_TYPE sum[BLOCK_SIZE]; @@ -25,6 +27,7 @@ void main() { const uint stride_sample = p.nb03; uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp @@ -46,7 +49,19 @@ void main() { const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + if (do_multiply) { + if (ncols > p.ne10) { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } + } else { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } + } else { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp new file mode 100644 index 000000000..b9abe8ded --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp @@ -0,0 +1,46 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +uint wrap_idx(int i, uint ne) { + if (i < 0) { + return i + ne; + } else if (i >= ne) { + return i - ne; + } + return i; +} + +void main() { + const uint idx = get_idx(); + if (idx >= p.ne) { + return; + } + + const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; + const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L); + const uint i2_offset = i2*p.ne11*p.ne10; + const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L); + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; + + const uint p1 = floatBitsToUint(p.param1); + const uint p2 = floatBitsToUint(p.param2); + const int s0 = int(p1 >> 16) - 0x8000; + const int s1 = int(p1 & 0xFFFF) - 0x8000; + const int s2 = int(p2 >> 16) - 0x8000; + const int s3 = int(p2 & 0xFFFF) - 0x8000; + + const uint i00 = wrap_idx(int(i0) - s0, p.ne10); + const uint i01 = wrap_idx(int(i1) - s1, p.ne11); + const uint i02 = wrap_idx(int(i2) - s2, p.ne12); + const uint i03 = wrap_idx(int(i3) - s3, p.ne13); + + const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; + const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10; + + data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp index 96c9c4cbd..00e203e73 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -1,11 +1,8 @@ #include "types.comp" #extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_spirv_intrinsics: enable -#if RTE16 -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif +#include "rte.comp" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 4f5b1a0ec..5808710cc 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -14,21 +14,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; const int sec_w = p.sections[1] + p.sections[0]; const uint sector = (i0 / 2) % sect_dims; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index db775c456..366a7b1c4 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -13,21 +13,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 4ad35e549..9643bca96 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -13,21 +13,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; + if (i0 >= p.n_dims) { + data_d[idst + 0] = data_a[ix + 0]; + data_d[idst + 1] = data_a[ix + 1]; + + return; + } + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp new file mode 100644 index 000000000..ad51c1e80 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp @@ -0,0 +1,5 @@ + +#if RTE16 +#extension GL_EXT_spirv_intrinsics : enable +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif // RTE16 diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp index 4663428de..f10b0a02b 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -18,7 +18,7 @@ void main() { continue; } - data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2)); idx += num_threads; } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index 51fc2dc7e..5f20a1ee7 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -6,12 +6,21 @@ layout (push_constant) uniform parameter { uint KX; uint KY; + uint ne00; + uint ne01; + uint ne02; + uint ne12; + uint ne13; + uint nb11; + uint nb12; + uint nb13; float scale; float max_bias; float m0; float m1; uint n_head_log2; uint nrows_x; + uint has_sinks; } p; #include "types.comp" @@ -21,7 +30,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; -layout (binding = 2) buffer D {D_TYPE data_d[];}; +layout (binding = 2) readonly buffer Z {float data_c[];}; +layout (binding = 3) buffer D {D_TYPE data_d[];}; shared FLOAT_TYPE vals[BLOCK_SIZE]; @@ -31,7 +41,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE]; void soft_max(uint num_iters) { const uint tid = gl_LocalInvocationID.x; const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } if (rowx >= p.nrows_x) { return; @@ -41,16 +59,16 @@ void soft_max(uint num_iters) { // ALiBi if (p.max_bias > 0.0f) { - const uint h = rowx/p.KY; // head index + const uint h = (rowx / p.ne01) % p.ne02; // head index const float base = h < p.n_head_log2 ? p.m0 : p.m1; - const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; slope = pow(base, exp); } // Find max - FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; // Cache values while we compute the max, so we don't need to read them // again when we're ready to compute exp(x-max). 
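    // (Sketch of the sink handling in this shader: when p.has_sinks != 0,
    // data_c[i02] acts as one extra per-head logit. It seeds max_val above,
    // and its exp() is added to the softmax denominator further down, but it
    // is never written to data_d.)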
@@ -67,7 +85,7 @@ void soft_max(uint num_iters) { FLOAT_TYPE b = FLOAT_TYPE(0); if (p.KY > 0 && col < p.KX) { - b = data_b[rowy * p.KX + col]; + b = data_b[rowy_start + col]; } FLOAT_TYPE v = a * p.scale + slope * b; @@ -111,7 +129,7 @@ void soft_max(uint num_iters) { if (idx < DATA_CACHE_SIZE) { val = exp(data_cache[idx] - max_val); } else { - val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); } sum += val; if (idx < DATA_CACHE_SIZE) { @@ -132,6 +150,10 @@ void soft_max(uint num_iters) { } sum = vals[0]; + if (p.has_sinks != 0) { + sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val)); + } + FLOAT_TYPE rcpdivisor = 1.0/sum; [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp new file mode 100644 index 000000000..a28e7c6cc --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.comp" + +float op(float a, float b) { + return a / (1.0f + exp(-a)) * b; +} + +#include "glu_main.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp new file mode 100644 index 000000000..970750eec --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp @@ -0,0 +1,14 @@ +#version 450 + +#include "glu_head.comp" + +float op(float a, float b) { + float xi = min(a, p.limit); + float gi = max(min(b, p.limit), -p.limit); + + float out_glu = xi / (1.0f + exp(-xi * p.alpha)); + out_glu = out_glu * (1.0f + gi); + return out_glu; +} + +#include "glu_main.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index 3bde71783..a36c33e26 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -1337,6 +1337,29 @@ struct block_iq4_nl_packed16 #define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif +#define QUANT_K_MXFP4 32 +#define QUANT_R_MXFP4 2 + +struct block_mxfp4 +{ + uint8_t e; + uint8_t qs[QUANT_K_MXFP4/2]; +}; + +//struct block_mxfp4_packed16 +//{ +// uint8_t e; +// uint16_t qs[QUANT_K_MXFP4/2/2]; +//}; + +#if defined(DATA_A_MXFP4) +#define QUANT_K QUANT_K_MXFP4 +#define QUANT_R QUANT_R_MXFP4 +#define QUANT_AUXF 1 +#define A_TYPE block_mxfp4 +//#define A_TYPE_PACKED16 block_mxfp4_packed16 +#endif + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), @@ -1356,6 +1379,25 @@ void init_iq_shmem(uvec3 wgsize) } #endif +#if defined(DATA_A_MXFP4) +const FLOAT_TYPE kvalues_mxfp4_const[16] = { + FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f), + FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f) +}; + +shared FLOAT_TYPE kvalues_mxfp4[16]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // 
copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) { + kvalues_mxfp4[i] = kvalues_mxfp4_const[i]; + } + barrier(); +} +#endif + // returns the bfloat value in the low 16b. // See ggml_compute_fp32_to_bf16 uint32_t fp32_to_bf16(float f) @@ -1370,4 +1412,17 @@ float bf16_to_fp32(uint32_t u) return uintBitsToFloat(u << 16); } +float e8m0_to_fp32(uint8_t x) { + uint32_t bits; + + if (x == 0) { + bits = 0x00400000; + } else { + bits = x; + bits = bits << 23; + } + + return uintBitsToFloat(bits); +} + #endif // !defined(GGML_TYPES_COMP) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index 6f607380d..74771def0 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -3,6 +3,7 @@ layout (push_constant) uniform parameter { uint ne; uint a_offset; uint d_offset; + uint ne00; uint ne01; uint nb00; uint nb01; uint nb02; uint nb03; uint ne10; uint ne11; uint ne12; uint ne13; float sf0; float sf1; float sf2; float sf3; @@ -15,6 +16,61 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag +#define NEAREST 0 +#define BILINEAR 1 +#define ALIGN_CORNERS (1 << 8) + +layout (constant_id = 0) const uint scale_mode = 0; + +float fetch_nearest(uint i10, uint i11, uint i12, uint i13) { + const uint i00 = uint(i10 / p.sf0); + const uint i01 = uint(i11 / p.sf1); + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + + return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]; +} + +float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) { + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02; + + const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00]; + const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00]; + const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00]; + const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00]; + + return + v00 * (1.0-d.x) * (1.0-d.y) + + v01 * d.x * (1.0-d.y) + + v10 * (1.0-d.x) * d.y + + v11 * d.x * d.y; +} + +float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { + const ivec2 ne0 = ivec2(p.ne00, p.ne01); + + const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5; + const vec2 c0f = floor(c); + const vec2 d = c - c0f; + const ivec2 c0 = max(ivec2(c0f), 0); + const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1); + + return fetch_bilinear(c0, c1, d, i12, i13); +} + +float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) { + const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1); + const vec2 c0f = floor(c); + const vec2 d = c - c0f; + const ivec2 c0 = ivec2(c0f); + const ivec2 c1 = c0 + 1; + + return fetch_bilinear(c0, c1, d, i12, i13); +} + void main() { const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; @@ -27,10 +83,18 @@ void main() { const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; - const uint i00 = uint(i10 / p.sf0); - const uint i01 = uint(i11 / p.sf1); - const uint i02 = uint(i12 / 
p.sf2); - const uint i03 = uint(i13 / p.sf3); + float result; + switch (scale_mode) { + case NEAREST: + result = fetch_nearest(i10, i11, i12, i13); + break; + case BILINEAR: + result = interpolate_bilinear(i10, i11, i12, i13); + break; + case BILINEAR | ALIGN_CORNERS: + result = interpolate_bilinear_align_corners(i10, i11, i12, i13); + break; + } - data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); + data_d[p.d_offset + idx] = D_TYPE(result); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d196137eb..4cd94c51e 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -64,6 +64,7 @@ const std::vector type_names = { "iq3_s", "iq4_xs", "iq4_nl", + "mxfp4", "bf16", }; @@ -118,7 +119,7 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s CloseHandle(pi.hProcess); CloseHandle(pi.hThread); #else -int stdout_pipe[2]; + int stdout_pipe[2]; int stderr_pipe[2]; if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { @@ -215,7 +216,7 @@ static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); std::string out_fname = join_paths(output_dir, name + ".spv"); std::string in_path = join_paths(input_dir, in_fname); @@ -360,9 +361,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) load_vec_quant = "4"; if (tname == "bf16") { @@ -424,6 +425,7 @@ void process_shaders() { // flash attention for (const auto& f16acc : {false, true}) { std::string acctype = f16acc ? "float16_t" : "float"; + std::string acctypev4 = f16acc ? 
"f16vec4" : "vec4"; for (const auto& tname : type_names) { if (tname == "f32") { @@ -440,6 +442,16 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } #endif if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", @@ -486,7 +498,7 @@ void process_shaders() { // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -507,6 +519,11 @@ void process_shaders() { string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } + for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + } + auto get_type_str = [](bool f16) { return f16 ? "float16_t" : "float"; }; @@ -521,8 +538,10 @@ void process_shaders() { for (auto src0_f16 : {false, true}) { for (auto src1_f16 : {false, true}) { for (auto dst_f16 : {false, true}) { - auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); - string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}}); + for (auto rte : {false, true}) { + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); + string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? 
"1" : "0"}}); + } } } } @@ -563,6 +582,8 @@ void process_shaders() { string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -574,6 +595,22 @@ void process_shaders() { string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + for (auto rte : {false, true}) { + std::string suffix = rte ? "_rte" : ""; + string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? 
"1" : "0"}}); + } + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -611,6 +648,8 @@ void process_shaders() { string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); @@ -619,9 +658,24 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); + string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); + + string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); + string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); + string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); +#endif + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + for (auto &c : compiles) { c.wait(); } @@ -676,11 +730,59 @@ void write_output_files() { std::remove(path.c_str()); } } + + std::string suffixes[2] = {"_f32", "_f16"}; for (const char *op : {"add", "sub", "mul", "div"}) { - fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op); - fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op); - fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op); - fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op); + fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", 
op); + fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op); + std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = "; + std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = "; + for (uint32_t t0 = 0; t0 < 2; ++t0) { + if (t0 == 0) { + data += "{"; + len += "{"; + } + for (uint32_t t1 = 0; t1 < 2; ++t1) { + if (t1 == 0) { + data += "{"; + len += "{"; + } + for (uint32_t t2 = 0; t2 < 2; ++t2) { + if (t2 == 0) { + data += "{"; + len += "{"; + } + for (uint32_t rte = 0; rte < 2; ++rte) { + if (rte == 0) { + data += "{"; + len += "{"; + } + data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); + len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); + data += "_data,"; + len += "_len,"; + if (rte == 1) { + data += "}, "; + len += "}, "; + } + } + if (t2 == 1) { + data += "}, "; + len += "}, "; + } + } + if (t1 == 1) { + data += "}, "; + len += "}, "; + } + } + if (t0 == 1) { + data += "};\n"; + len += "};\n"; + } + } + fputs(data.c_str(), src); + fputs(len.c_str(), src); } fclose(hdr); fclose(src); From af5f5bdf602f87997dc9395b894dca31864b5411 Mon Sep 17 00:00:00 2001 From: Masato Nakasaka Date: Wed, 27 Aug 2025 11:51:53 +0900 Subject: [PATCH 059/172] Removed libcap related code libcap is not directly related to Vulkan and should be added by its own PR. It adds additional library dependencies for building and also requires users to run setcap or run ollama as root, which is not ideal for easy use --- Dockerfile | 4 +- discover/gpu.go | 47 +++++++++----------- discover/gpu_info_vulkan.c | 90 +++++--------------------------------- discover/gpu_info_vulkan.h | 16 +------ discover/gpu_linux.go | 11 ----- 5 files changed, 34 insertions(+), 134 deletions(-) diff --git a/Dockerfile b/Dockerfile index 416e1bb0a..e0d568066 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,7 +20,7 @@ ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH ARG VULKANVERSION RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ && tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ - && dnf -y install ninja-build libcap-devel \ + && dnf -y install ninja-build \ && ln -s /usr/bin/python3 /usr/bin/python \ && /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \ && /${VULKANVERSION}/vulkansdk -j 8 shaderc @@ -126,7 +126,7 @@ COPY --from=build /bin/ollama /bin/ollama FROM ubuntu:24.04 RUN apt-get update \ - && apt-get install -y ca-certificates libcap2 libvulkan1 \ + && apt-get install -y ca-certificates libvulkan1 \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* COPY --from=archive /bin /usr/bin diff --git a/discover/gpu.go b/discover/gpu.go index 72011e699..056191f6a 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -57,7 +57,6 @@ var ( cudartLibPath string oneapiLibPath string vulkanLibPath string - libcapLibPath string nvmlLibPath string rocmGPUs []RocmGPUInfo oneapiGPUs []OneapiGPUInfo @@ -187,19 +186,18 @@ func initVulkanHandles() *vulkanHandles { vHandles := &vulkanHandles{} // Short Circuit if we already know which library to use - if vulkanLibPath != "" && libcapLibPath != "" { - vHandles.deviceCount, vHandles.vulkan, _, _ = LoadVulkanMgmt([]string{vulkanLibPath}, []string{libcapLibPath}) + if vulkanLibPath != "" { + vHandles.deviceCount, vHandles.vulkan, _ = LoadVulkanMgmt([]string{vulkanLibPath}) return vHandles } vulkanPaths := FindGPULibs(VulkanMgmtName, VulkanGlobs) - libcapPaths := 
FindLibCapLibs() - if len(vulkanPaths) > 0 && len(libcapPaths) > 0 { - slog.Info("vulkan: load libvulkan and libcap ok") - vHandles.deviceCount, vHandles.vulkan, vulkanLibPath, libcapLibPath = LoadVulkanMgmt(vulkanPaths, libcapPaths) + if len(vulkanPaths) > 0 { + slog.Info("vulkan: load libvulkan ok") + vHandles.deviceCount, vHandles.vulkan, vulkanLibPath = LoadVulkanMgmt(vulkanPaths) } else { - slog.Info("vulkan: failed to load libvulkan or libcap") + slog.Info("vulkan: failed to load libvulkan") } return vHandles @@ -760,32 +758,27 @@ func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, e return 0, nil, "", err } -func LoadVulkanMgmt(vulkanLibPaths []string, capLibPaths []string) (int, *C.vk_handle_t, string, string) { +func LoadVulkanMgmt(vulkanLibPaths []string) (int, *C.vk_handle_t, string) { var resp C.vk_init_resp_t resp.ch.verbose = getVerboseState() for _, vkLibPath := range vulkanLibPaths { - for _, capLibPath := range capLibPaths { - vkLib := C.CString(vkLibPath) - capLib := C.CString(capLibPath) - defer C.free(unsafe.Pointer(vkLib)) - defer C.free(unsafe.Pointer(capLib)) + vkLib := C.CString(vkLibPath) + defer C.free(unsafe.Pointer(vkLib)) - C.vk_init(vkLib, capLib, &resp) - if resp.err != nil { - slog.Error( - "Unable to load vulkan", - "vulkan_library", vkLibPath, - "cap_library", capLibPath, - "error", C.GoString(resp.err), - ) - C.free(unsafe.Pointer(resp.err)) - } else { - return int(resp.num_devices), &resp.ch, vkLibPath, capLibPath - } + C.vk_init(vkLib, &resp) + if resp.err != nil { + slog.Error( + "Unable to load vulkan", + "vulkan_library", vkLibPath, + "error", C.GoString(resp.err), + ) + C.free(unsafe.Pointer(resp.err)) + } else { + return int(resp.num_devices), &resp.ch, vkLibPath } } - return 0, nil, "", "" + return 0, nil, "" } func getVerboseState() C.uint16_t { diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index 29eaaeb7f..b520c8436 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -2,28 +2,6 @@ #include -int check_perfmon(vk_handle_t* rh) { -#ifdef __linux__ - cap_t caps; - const cap_value_t cap_list[1] = {CAP_PERFMON}; - - caps = (*rh->cap_get_proc)(); - if (caps == NULL) - return -1; - - if ((*rh->cap_set_flag)(caps, CAP_EFFECTIVE, 1, cap_list, CAP_SET) == -1) - return -1; - - if ((*rh->cap_set_proc)(caps) == -1) - return -1; - - if ((*rh->cap_free)(caps) == -1) - return -1; -#endif - - return 0; -} - int is_extension_supported(vk_handle_t* rh, VkPhysicalDevice device, char* extension) { VkPhysicalDeviceProperties properties; (*rh->vkGetPhysicalDeviceProperties)(device, &properties); @@ -53,30 +31,22 @@ int is_extension_supported(vk_handle_t* rh, VkPhysicalDevice device, char* exten return 0; } -void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { +void vk_init(char* vk_lib_path, vk_init_resp_t *resp) { const int buflen = 256; char buf[buflen + 1]; int i; struct lookup { - int is_cap; char *s; void **p; } l[] = { -#ifdef __linux__ - {1, "cap_get_proc", (void *)&resp->ch.cap_get_proc}, - {1, "cap_get_bound", (void *)&resp->ch.cap_get_bound}, - {1, "cap_set_flag", (void *)&resp->ch.cap_set_flag}, - {1, "cap_set_proc", (void *)&resp->ch.cap_set_proc}, - {1, "cap_free", (void *)&resp->ch.cap_free}, -#endif - {0, "vkGetPhysicalDeviceProperties", (void *)&resp->ch.vkGetPhysicalDeviceProperties}, - {0, "vkEnumerateDeviceExtensionProperties", (void *)&resp->ch.vkEnumerateDeviceExtensionProperties}, - {0, "vkCreateInstance", (void *)&resp->ch.vkCreateInstance}, - {0, 
"vkEnumeratePhysicalDevices", (void *)&resp->ch.vkEnumeratePhysicalDevices}, - {0, "vkGetPhysicalDeviceMemoryProperties2", (void *)&resp->ch.vkGetPhysicalDeviceMemoryProperties2}, - {0, "vkDestroyInstance", (void *)&resp->ch.vkDestroyInstance}, - {0, NULL, NULL}, + {"vkGetPhysicalDeviceProperties", (void *)&resp->ch.vkGetPhysicalDeviceProperties}, + {"vkEnumerateDeviceExtensionProperties", (void *)&resp->ch.vkEnumerateDeviceExtensionProperties}, + {"vkCreateInstance", (void *)&resp->ch.vkCreateInstance}, + {"vkEnumeratePhysicalDevices", (void *)&resp->ch.vkEnumeratePhysicalDevices}, + {"vkGetPhysicalDeviceMemoryProperties2", (void *)&resp->ch.vkGetPhysicalDeviceMemoryProperties2}, + {"vkDestroyInstance", (void *)&resp->ch.vkDestroyInstance}, + {NULL, NULL}, }; resp->ch.vk_handle = LOAD_LIBRARY(vk_lib_path, RTLD_LAZY); @@ -91,39 +61,13 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { return; } -#ifdef __linux__ - resp->ch.cap_handle = LOAD_LIBRARY(cap_lib_path, RTLD_LAZY); - if (!resp->ch.cap_handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", cap_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Vulkan GPUs: %s", - cap_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } -#endif - for (i = 0; l[i].s != NULL; i++) { - if (l[i].is_cap) -#ifdef __linux__ - *l[i].p = LOAD_SYMBOL(resp->ch.cap_handle, l[i].s); -#else - continue; -#endif - else - *l[i].p = LOAD_SYMBOL(resp->ch.vk_handle, l[i].s); + *l[i].p = LOAD_SYMBOL(resp->ch.vk_handle, l[i].s); if (!*l[i].p) { char *msg = LOAD_ERR(); LOG(resp->ch.verbose, "dlerr: %s\n", msg); - if (l[i].is_cap) { - UNLOAD_LIBRARY(resp->ch.cap_handle); - resp->ch.cap_handle = NULL; - } else { - UNLOAD_LIBRARY(resp->ch.vk_handle); - resp->ch.vk_handle = NULL; - } + UNLOAD_LIBRARY(resp->ch.vk_handle); + resp->ch.vk_handle = NULL; snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, msg); free(msg); @@ -132,12 +76,6 @@ void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp) { } } - if (check_perfmon(&resp->ch) != 0) { - resp->err = strdup("performance monitoring is not allowed. 
Please enable CAP_PERFMON or run as root to use Vulkan."); - LOG(resp->ch.verbose, "vulkan: %s", resp->err); - return; - } - VkInstance instance; VkApplicationInfo appInfo = {}; @@ -277,10 +215,4 @@ void vk_release(vk_handle_t rh) { (*rh.vkDestroyInstance)(rh.vk, NULL); UNLOAD_LIBRARY(rh.vk_handle); rh.vk_handle = NULL; - -#ifdef __linux__ - LOG(rh.verbose, "releasing libcap library\n"); - UNLOAD_LIBRARY(rh.cap_handle); - rh.cap_handle = NULL; -#endif } diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index 1f19be58e..245d58f1f 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -4,29 +4,15 @@ #include "gpu_info.h" -#ifdef __linux__ -#include -#endif - #include typedef struct { void* vk_handle; - void* cap_handle; uint16_t verbose; VkInstance vk; int num_devices; -#ifdef __linux__ - cap_t (*cap_get_proc)(void); - - int (*cap_get_bound)(cap_value_t); - int (*cap_set_flag)(cap_t, cap_flag_t, int, const cap_value_t *, cap_flag_value_t); - int (*cap_set_proc)(cap_t); - int (*cap_free)(cap_t); -#endif - void (*vkGetPhysicalDeviceProperties)( VkPhysicalDevice physicalDevice, VkPhysicalDeviceProperties* pProperties); @@ -58,7 +44,7 @@ typedef struct vk_init_resp vk_handle_t ch; } vk_init_resp_t; -void vk_init(char* vk_lib_path, char* cap_lib_path, vk_init_resp_t *resp); +void vk_init(char* vk_lib_path, vk_init_resp_t *resp); void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp); int vk_check_flash_attention(vk_handle_t rh, int i); void vk_release(vk_handle_t rh); diff --git a/discover/gpu_linux.go b/discover/gpu_linux.go index c603ecea9..2631dd2cf 100644 --- a/discover/gpu_linux.go +++ b/discover/gpu_linux.go @@ -53,7 +53,6 @@ var ( NvmlMgmtName = "" // not currently wired on linux OneapiMgmtName = "libze_intel_gpu.so*" VulkanMgmtName = "libvulkan.so*" - libcapMgmtName = "libcap.so*" ) var VulkanGlobs = []string{ @@ -62,16 +61,6 @@ var VulkanGlobs = []string{ "/usr/lib*/libvulkan.so*", } -var capLinuxGlobs = []string{ - "/usr/lib/x86_64-linux-gnu/libcap.so*", - "/usr/lib/aarch64-linux-gnu/libvulkan.so*", - "/usr/lib*/libcap.so*", -} - -func FindLibCapLibs() []string { - return FindGPULibs(libcapMgmtName, capLinuxGlobs) -} - func GetCPUMem() (memInfo, error) { var mem memInfo var total, available, free, buffers, cached, freeSwap uint64 From 8300a55e1d18f12597fd37dbe6bc57f679d0ceda Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 30 Aug 2025 20:26:53 +0200 Subject: [PATCH 060/172] Fix Unit Test (Add Vulkan Library) --- discover/gpu_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/gpu_test.go b/discover/gpu_test.go index 0c6ef7bad..76ab7dfba 100644 --- a/discover/gpu_test.go +++ b/discover/gpu_test.go @@ -11,7 +11,7 @@ import ( func TestBasicGetGPUInfo(t *testing.T) { info := GetGPUInfo() assert.NotEmpty(t, len(info)) - assert.Contains(t, "cuda rocm cpu metal", info[0].Library) + assert.Contains(t, "cuda rocm cpu metal vulkan", info[0].Library) if info[0].Library != "cpu" { assert.Greater(t, info[0].TotalMemory, uint64(0)) assert.Greater(t, info[0].FreeMemory, uint64(0)) From 1fc323958206fafe9faeb8b9d9cb2c3bb0679a7c Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 30 Aug 2025 20:33:06 +0200 Subject: [PATCH 061/172] Add vulkan to TestHomogeneousGPUs Test --- server/sched_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/server/sched_test.go b/server/sched_test.go index 3892fbbab..9bca2d0cf 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -717,11 +717,14 @@ func 
TestHomogeneousGPUs(t *testing.T) { gpus := []discover.GpuInfo{ {Library: "cuda"}, {Library: "rocm"}, + {Library: "vulkan"}, } gpus[0].TotalMemory = 1 * format.GibiByte gpus[0].FreeMemory = 256 * format.MebiByte gpus[1].TotalMemory = 1 * format.GibiByte gpus[1].FreeMemory = 256 * format.MebiByte + gpus[2].TotalMemory = 1 * format.GibiByte + gpus[2].FreeMemory = 256 * format.MebiByte return gpus } s.getCpuFn = getCpuFn From 603d3ab0ca21efc06c8f5e8c644af247d1ba1364 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Mon, 18 Aug 2025 17:49:32 +0800 Subject: [PATCH 062/172] vulkan: get GPU ID (ollama v0.11.5) Signed-off-by: Xiaodong Ye --- discover/gpu_info_vulkan.c | 10 ++ ...023-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 115 ++++++++++++++++++ .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 35 ++++++ 3 files changed, 160 insertions(+) create mode 100644 llama/patches/0023-vulkan-get-GPU-ID-ollama-v0.11.5.patch diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index b520c8436..a85754a29 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -203,6 +203,16 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; strncpy(&resp->gpu_name[0], properties.deviceName, GPU_NAME_LEN - 1); + resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; + const uint8_t *uuid = properties.pipelineCacheUUID; + snprintf(&resp->gpu_id[0], GPU_ID_LEN, + "GPU-%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X", + uuid[0], uuid[1], uuid[2], uuid[3], + uuid[4], uuid[5], + uuid[6], uuid[7], + uuid[8], uuid[9], + uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15] + ); resp->total = (uint64_t) device_memory_total_size; resp->free = (uint64_t) device_memory_heap_budget; resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); diff --git a/llama/patches/0023-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0023-vulkan-get-GPU-ID-ollama-v0.11.5.patch new file mode 100644 index 000000000..f5b8f428d --- /dev/null +++ b/llama/patches/0023-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -0,0 +1,115 @@ +From e0ba120c913a2931010a31e0fdf160697a15b9f1 Mon Sep 17 00:00:00 2001 +From: Xiaodong Ye +Date: Mon, 18 Aug 2025 12:48:07 +0800 +Subject: [PATCH] vulkan: get GPU ID (ollama v0.11.5) + +Signed-off-by: Xiaodong Ye +--- + discover/gpu_info_vulkan.c | 9 +++++ + .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 35 +++++++++++++++++++ + 2 files changed, 44 insertions(+) + +diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c +index 6d67353d..afac97dd 100644 +--- a/discover/gpu_info_vulkan.c ++++ b/discover/gpu_info_vulkan.c +@@ -171,6 +171,15 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); + strncpy(&resp->gpu_name[0], properties.deviceName, GPU_NAME_LEN - 1); + resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; ++ const uint8_t *uuid = properties.pipelineCacheUUID; ++ snprintf(&resp->gpu_id[0], GPU_ID_LEN, ++ "GPU-%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X", ++ uuid[0], uuid[1], uuid[2], uuid[3], ++ uuid[4], uuid[5], ++ uuid[6], uuid[7], ++ uuid[8], uuid[9], ++ uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15] ++ ); + resp->total = (uint64_t) device_memory_total_size; + resp->free = (uint64_t) device_memory_heap_budget; + resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); +diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp 
b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index 4070e248..1c8c15d5 100644 +--- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -10194,6 +10194,27 @@ static void ggml_vk_get_device_description(int device, char * description, size_ + snprintf(description, description_size, "%s", props.deviceName.data()); + } + ++static std::string ggml_vk_get_device_id(int device) { ++ ggml_vk_instance_init(); ++ ++ std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); ++ ++ vk::PhysicalDeviceProperties props; ++ devices[device].getProperties(&props); ++ ++ const auto& uuid = props.pipelineCacheUUID; ++ char id[64]; ++ snprintf(id, sizeof(id), ++ "GPU-%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X", ++ uuid[0], uuid[1], uuid[2], uuid[3], ++ uuid[4], uuid[5], ++ uuid[6], uuid[7], ++ uuid[8], uuid[9], ++ uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15] ++ ); ++ return std::string(id); ++} ++ + // backend interface + + #define UNUSED GGML_UNUSED +@@ -10790,6 +10811,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size + ggml_vk_get_device_description(dev_idx, description, description_size); + } + ++std::string ggml_backend_vk_get_device_id(int device) { ++ GGML_ASSERT(device < (int) vk_instance.device_indices.size()); ++ int dev_idx = vk_instance.device_indices[device]; ++ return ggml_vk_get_device_id(dev_idx); ++} ++ + void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + +@@ -10812,6 +10839,7 @@ struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; ++ std::string id; + }; + + static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { +@@ -10824,6 +10852,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de + return ctx->description.c_str(); + } + ++static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ return ctx->id.c_str(); ++} ++ + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; + ggml_backend_vk_get_device_memory(ctx->device, free, total); +@@ -10847,6 +10880,7 @@ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_d + static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_vk_device_get_name(dev); + props->description = ggml_backend_vk_device_get_description(dev); ++ props->id = ggml_backend_vk_device_get_id(dev); + props->type = ggml_backend_vk_device_get_type(dev); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { +@@ -11265,6 +11299,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + ctx->device = i; + ctx->name = GGML_VK_NAME + std::to_string(i); + ctx->description = desc; ++ ctx->id = ggml_backend_vk_get_device_id(i); + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, +-- +2.25.1 + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 4070e248b..1c8c15d52 
100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10194,6 +10194,27 @@ static void ggml_vk_get_device_description(int device, char * description, size_ snprintf(description, description_size, "%s", props.deviceName.data()); } +static std::string ggml_vk_get_device_id(int device) { + ggml_vk_instance_init(); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + vk::PhysicalDeviceProperties props; + devices[device].getProperties(&props); + + const auto& uuid = props.pipelineCacheUUID; + char id[64]; + snprintf(id, sizeof(id), + "GPU-%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X", + uuid[0], uuid[1], uuid[2], uuid[3], + uuid[4], uuid[5], + uuid[6], uuid[7], + uuid[8], uuid[9], + uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15] + ); + return std::string(id); +} + // backend interface #define UNUSED GGML_UNUSED @@ -10790,6 +10811,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size ggml_vk_get_device_description(dev_idx, description, description_size); } +std::string ggml_backend_vk_get_device_id(int device) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + int dev_idx = vk_instance.device_indices[device]; + return ggml_vk_get_device_id(dev_idx); +} + void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { GGML_ASSERT(device < (int) vk_instance.device_indices.size()); @@ -10812,6 +10839,7 @@ struct ggml_backend_vk_device_context { size_t device; std::string name; std::string description; + std::string id; }; static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { @@ -10824,6 +10852,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de return ctx->description.c_str(); } +static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->id.c_str(); +} + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; ggml_backend_vk_get_device_memory(ctx->device, free, total); @@ -10847,6 +10880,7 @@ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_d static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); + props->id = ggml_backend_vk_device_get_id(dev); props->type = ggml_backend_vk_device_get_type(dev); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { @@ -11265,6 +11299,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->device = i; ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; + ctx->id = ggml_backend_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, From 8880174a8e5483379a6a2406ae0c5792503ef589 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 6 Sep 2025 20:33:56 +0200 Subject: [PATCH 063/172] disable mmap for vulkan --- llm/server.go | 1 + 1 file changed, 1 insertion(+) diff --git a/llm/server.go b/llm/server.go index 664a69fb3..e7e8b4da8 100644 --- a/llm/server.go +++ b/llm/server.go @@ -564,6 
+564,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi // For CPU loads we want the memory to be allocated, not FS cache if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && s.options.UseMMap == nil) || (runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) || + (gpus[0].Library == "vulkan" && s.options.UseMMap == nil) || (gpus[0].Library == "cpu" && s.options.UseMMap == nil) || (s.options.UseMMap != nil && !*s.options.UseMMap) { s.loadRequest.UseMmap = false From 80873ca49e54b008616e9cd7929146cb6279b1ce Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 6 Sep 2025 20:58:35 +0200 Subject: [PATCH 064/172] Reduce Changes remove TestHomogeneousGPUs (doesn't exist on master) --- server/sched_test.go | 41 ----------------------------------------- 1 file changed, 41 deletions(-) diff --git a/server/sched_test.go b/server/sched_test.go index 2f3a9caf8..0acd59118 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -676,47 +676,6 @@ func TestAlreadyCanceled(t *testing.T) { require.Empty(t, scenario1a.req.successCh) } -func TestHomogeneousGPUs(t *testing.T) { - ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) - defer done() - s := InitScheduler(ctx) - - s.getGpuFn = func() discover.GpuInfoList { - // Set memory values to require the model to be spread - gpus := []discover.GpuInfo{ - {Library: "cuda"}, - {Library: "rocm"}, - {Library: "vulkan"}, - } - gpus[0].TotalMemory = 1 * format.GibiByte - gpus[0].FreeMemory = 256 * format.MebiByte - gpus[1].TotalMemory = 1 * format.GibiByte - gpus[1].FreeMemory = 256 * format.MebiByte - gpus[2].TotalMemory = 1 * format.GibiByte - gpus[2].FreeMemory = 256 * format.MebiByte - return gpus - } - s.getCpuFn = getCpuFn - a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) - s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { - require.Len(t, gpus, 1) - return a.newServer(gpus, model, f, adapters, projectors, opts, numParallel) - } - slog.Info("a") - s.pendingReqCh <- a.req - require.Len(t, s.pendingReqCh, 1) - s.Run(ctx) - select { - case resp := <-a.req.successCh: - require.Equal(t, resp.llama, a.srv) - require.Empty(t, s.pendingReqCh) - require.Empty(t, a.req.errCh) - case err := <-a.req.errCh: - t.Fatal(err.Error()) - case <-ctx.Done(): - t.Fatal("timeout") - } -} type mockLlm struct { modelPath string pingResp error From 8687e30bb54b0ded232a14f4e1321aa0b09662fa Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 6 Sep 2025 21:03:31 +0200 Subject: [PATCH 065/172] Update vulkan version to the version used in llama.cpp --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 89f1d1cc9..83e2f89e4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ ARG ROCMVERSION=6.3.3 ARG JETPACK5VERSION=r35.4.1 ARG JETPACK6VERSION=r36.4.0 ARG CMAKEVERSION=3.31.2 -ARG VULKANVERSION=1.4.304.1 +ARG VULKANVERSION=1.4.313.2 # We require gcc v10 minimum. 
v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 From ab7f456cf6e689c52f3388fdc01a413a49cf52dd Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 7 Sep 2025 01:05:00 +0200 Subject: [PATCH 066/172] rename gpu patch to correct number --- ...-v0.11.5.patch => 0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename llama/patches/{0023-vulkan-get-GPU-ID-ollama-v0.11.5.patch => 0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch} (100%) diff --git a/llama/patches/0023-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch similarity index 100% rename from llama/patches/0023-vulkan-get-GPU-ID-ollama-v0.11.5.patch rename to llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch From ec7628f8535a2ef678e20add358de08c0289000a Mon Sep 17 00:00:00 2001 From: Masato Nakasaka Date: Tue, 9 Sep 2025 17:11:50 +0900 Subject: [PATCH 067/172] added Vulkan API to get correct Device UUID current UUID from pipelineCacheUUID does not match CUDA --- discover/gpu_info_vulkan.c | 14 ++++++++++++-- discover/gpu_info_vulkan.h | 3 +++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index a85754a29..9ebe79e82 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -41,6 +41,7 @@ void vk_init(char* vk_lib_path, vk_init_resp_t *resp) { void **p; } l[] = { {"vkGetPhysicalDeviceProperties", (void *)&resp->ch.vkGetPhysicalDeviceProperties}, + {"vkGetPhysicalDeviceProperties2", (void *)&resp->ch.vkGetPhysicalDeviceProperties2}, {"vkEnumerateDeviceExtensionProperties", (void *)&resp->ch.vkEnumerateDeviceExtensionProperties}, {"vkCreateInstance", (void *)&resp->ch.vkCreateInstance}, {"vkEnumeratePhysicalDevices", (void *)&resp->ch.vkEnumeratePhysicalDevices}, @@ -176,6 +177,15 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { return; } + VkPhysicalDeviceProperties2 device_props2 = {}; + device_props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; + + VkPhysicalDeviceIDProperties id_props = {}; + id_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES; + + device_props2.pNext = &id_props; + (*rh.vkGetPhysicalDeviceProperties2)(devices[i], &device_props2); + VkPhysicalDeviceMemoryBudgetPropertiesEXT physical_device_memory_budget_properties; physical_device_memory_budget_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT; physical_device_memory_budget_properties.pNext = NULL; @@ -204,9 +214,9 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; strncpy(&resp->gpu_name[0], properties.deviceName, GPU_NAME_LEN - 1); resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; - const uint8_t *uuid = properties.pipelineCacheUUID; + const uint8_t *uuid = id_props.deviceUUID; snprintf(&resp->gpu_id[0], GPU_ID_LEN, - "GPU-%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X", + "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", uuid[0], uuid[1], uuid[2], uuid[3], uuid[4], uuid[5], uuid[6], uuid[7], diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index 245d58f1f..665f07fb4 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -16,6 +16,9 @@ typedef struct { void (*vkGetPhysicalDeviceProperties)( VkPhysicalDevice physicalDevice, VkPhysicalDeviceProperties* 
pProperties); + void (*vkGetPhysicalDeviceProperties2)( + VkPhysicalDevice physicalDevice, + VkPhysicalDeviceProperties2* pProperties); VkResult (*vkEnumerateDeviceExtensionProperties)( VkPhysicalDevice physicalDevice, const char* pLayerName, From d5cecee907e1951b3d310f73558066815c2540f5 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 9 Sep 2025 23:47:08 +0200 Subject: [PATCH 068/172] Fix GPU ID Patch --- ...026-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 33 ++++--------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch index f5b8f428d..c8d09ad01 100644 --- a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch +++ b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -1,38 +1,17 @@ -From e0ba120c913a2931010a31e0fdf160697a15b9f1 Mon Sep 17 00:00:00 2001 +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Mon, 18 Aug 2025 12:48:07 +0800 Subject: [PATCH] vulkan: get GPU ID (ollama v0.11.5) Signed-off-by: Xiaodong Ye --- - discover/gpu_info_vulkan.c | 9 +++++ - .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 35 +++++++++++++++++++ - 2 files changed, 44 insertions(+) + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 35 ++++++++++++++++++++++++++++ + 1 file changed, 35 insertions(+) -diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c -index 6d67353d..afac97dd 100644 ---- a/discover/gpu_info_vulkan.c -+++ b/discover/gpu_info_vulkan.c -@@ -171,6 +171,15 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - strncpy(&resp->gpu_name[0], properties.deviceName, GPU_NAME_LEN - 1); - resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; -+ const uint8_t *uuid = properties.pipelineCacheUUID; -+ snprintf(&resp->gpu_id[0], GPU_ID_LEN, -+ "GPU-%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X", -+ uuid[0], uuid[1], uuid[2], uuid[3], -+ uuid[4], uuid[5], -+ uuid[6], uuid[7], -+ uuid[8], uuid[9], -+ uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15] -+ ); - resp->total = (uint64_t) device_memory_total_size; - resp->free = (uint64_t) device_memory_heap_budget; - resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); -diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 4070e248..1c8c15d5 100644 ---- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp -+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10194,6 +10194,27 @@ static void ggml_vk_get_device_description(int device, char * description, size_ snprintf(description, description_size, "%s", props.deviceName.data()); } From 08bec121eb0d4ffa400a67978b69db753b68c7f4 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Wed, 10 Sep 2025 00:09:17 +0200 Subject: [PATCH 069/172] Remove Code not in llama.cpp --- .../src/ggml-vulkan/cmake/host-toolchain.cmake.in | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in b/ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in deleted file mode 100644 index 2d8a85696..000000000 --- 
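
The GPU-ID patches above replace the pipelineCacheUUID-based identifier with the deviceUUID reported through VK_KHR_get_physical_device_properties2 / Vulkan 1.1, so the Vulkan ID can line up with the GPU-UUID format other backends report. A minimal sketch of that query pattern, assuming standard Vulkan headers (or the local header copy) and a physical device from vkEnumeratePhysicalDevices; it calls the Vulkan API directly rather than the repo's dynamically loaded function pointers.

/* Sketch only: derive a CUDA-style "GPU-..." identifier from the Vulkan
 * device UUID by chaining VkPhysicalDeviceIDProperties through pNext. */
#include <stdio.h>
#include <vulkan/vulkan.h>

static void format_gpu_id(VkPhysicalDevice dev, char *out, size_t out_len) {
    VkPhysicalDeviceIDProperties id_props = {0};
    id_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES;

    VkPhysicalDeviceProperties2 props2 = {0};
    props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    props2.pNext = &id_props;               /* chain the ID query */

    vkGetPhysicalDeviceProperties2(dev, &props2);

    const uint8_t *u = id_props.deviceUUID; /* stable across APIs, unlike pipelineCacheUUID */
    snprintf(out, out_len,
             "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
             u[0], u[1], u[2], u[3], u[4], u[5], u[6], u[7],
             u[8], u[9], u[10], u[11], u[12], u[13], u[14], u[15]);
}

Formatting the 16 UUID bytes in the same "GPU-..." shape appears to be the point of the change: the same physical device can then be recognised whether it was enumerated through CUDA or through Vulkan.
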
a/ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +++ /dev/null @@ -1,15 +0,0 @@ -set(CMAKE_BUILD_TYPE Release) -set(CMAKE_C_FLAGS -O2) -set(CMAKE_CXX_FLAGS -O2) -set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) -set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) -set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) -set(CMAKE_C_COMPILER "@HOST_C_COMPILER@") -set(CMAKE_CXX_COMPILER "@HOST_CXX_COMPILER@") -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@) - -if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC") - foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO) - set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) - endforeach() -endif() From dd853c4040131f0223c87a43992b0f31de75425c Mon Sep 17 00:00:00 2001 From: Masato Nakasaka Date: Wed, 10 Sep 2025 14:45:12 +0900 Subject: [PATCH 070/172] modified UUID code inside ggml --- .../0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 10 ++++++---- ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch index c8d09ad01..4af0e1b7d 100644 --- a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch +++ b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -21,13 +21,15 @@ index 4070e248..1c8c15d5 100644 + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + -+ vk::PhysicalDeviceProperties props; -+ devices[device].getProperties(&props); ++ vk::PhysicalDeviceProperties2 props; ++ vk::PhysicalDeviceIDProperties deviceIDProps; ++ props.pNext = &deviceIDProps; ++ devices[device].getProperties2(&props); + -+ const auto& uuid = props.pipelineCacheUUID; ++ const auto& uuid = deviceIDProps.deviceUUID; + char id[64]; + snprintf(id, sizeof(id), -+ "GPU-%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X", ++ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + uuid[0], uuid[1], uuid[2], uuid[3], + uuid[4], uuid[5], + uuid[6], uuid[7], diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1c8c15d52..671323ad0 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10199,13 +10199,15 @@ static std::string ggml_vk_get_device_id(int device) { std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); - vk::PhysicalDeviceProperties props; - devices[device].getProperties(&props); + vk::PhysicalDeviceProperties2 props; + vk::PhysicalDeviceIDProperties deviceIDProps; + props.pNext = &deviceIDProps; + devices[device].getProperties2(&props); - const auto& uuid = props.pipelineCacheUUID; + const auto& uuid = deviceIDProps.deviceUUID; char id[64]; snprintf(id, sizeof(id), - "GPU-%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X", + "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", uuid[0], uuid[1], uuid[2], uuid[3], uuid[4], uuid[5], uuid[6], uuid[7], From 5053b2e351516baadc4860694130eed8896a914e Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Fri, 12 Sep 2025 08:13:17 +0200 Subject: [PATCH 071/172] Fix Patch --- ...026-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch 
b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch index 4af0e1b7d..928e85d5e 100644 --- a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch +++ b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -5,14 +5,14 @@ Subject: [PATCH] vulkan: get GPU ID (ollama v0.11.5) Signed-off-by: Xiaodong Ye --- - ggml/src/ggml-vulkan/ggml-vulkan.cpp | 35 ++++++++++++++++++++++++++++ - 1 file changed, 35 insertions(+) + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 37 ++++++++++++++++++++++++++++ + 1 file changed, 37 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 4070e248..1c8c15d5 100644 +index 4070e248..671323ad 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -10194,6 +10194,27 @@ static void ggml_vk_get_device_description(int device, char * description, size_ +@@ -10194,6 +10194,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ snprintf(description, description_size, "%s", props.deviceName.data()); } @@ -42,7 +42,7 @@ index 4070e248..1c8c15d5 100644 // backend interface #define UNUSED GGML_UNUSED -@@ -10790,6 +10811,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size +@@ -10790,6 +10813,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size ggml_vk_get_device_description(dev_idx, description, description_size); } @@ -55,7 +55,7 @@ index 4070e248..1c8c15d5 100644 void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { GGML_ASSERT(device < (int) vk_instance.device_indices.size()); -@@ -10812,6 +10839,7 @@ struct ggml_backend_vk_device_context { +@@ -10812,6 +10841,7 @@ struct ggml_backend_vk_device_context { size_t device; std::string name; std::string description; @@ -63,7 +63,7 @@ index 4070e248..1c8c15d5 100644 }; static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { -@@ -10824,6 +10852,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de +@@ -10824,6 +10854,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de return ctx->description.c_str(); } @@ -75,7 +75,7 @@ index 4070e248..1c8c15d5 100644 static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; ggml_backend_vk_get_device_memory(ctx->device, free, total); -@@ -10847,6 +10880,7 @@ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_d +@@ -10847,6 +10882,7 @@ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_d static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); @@ -83,7 +83,7 @@ index 4070e248..1c8c15d5 100644 props->type = ggml_backend_vk_device_get_type(dev); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { -@@ -11265,6 +11299,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -11265,6 +11301,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->device = i; ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; @@ -92,5 +92,4 @@ index 4070e248..1c8c15d5 100644 /* .iface = */ 
ggml_backend_vk_device_i, /* .reg = */ reg, -- -2.25.1 - +2.51.0 \ No newline at end of file From da466f4f868aee8b9c17835c40e6e9a38edb3e67 Mon Sep 17 00:00:00 2001 From: "Nakasaka, Masato" Date: Tue, 16 Sep 2025 15:05:54 +0900 Subject: [PATCH 072/172] Copied minimal definition from vulkan header --- discover/gpu_info_vulkan.h | 205 ++++++++++++++++++++++++++++++++++++- 1 file changed, 204 insertions(+), 1 deletion(-) diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index 665f07fb4..470b230c5 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -4,7 +4,210 @@ #include "gpu_info.h" -#include +#define VK_DEFINE_HANDLE(object) typedef struct object##_T* object; +VK_DEFINE_HANDLE(VkInstance) +VK_DEFINE_HANDLE(VkPhysicalDevice) + +typedef uint32_t VkFlags; +typedef uint32_t VkBool32; +typedef uint64_t VkDeviceSize; +typedef uint32_t VkSampleMask; + +#define VK_MAX_EXTENSION_NAME_SIZE 256 +#define VK_MAX_DESCRIPTION_SIZE 256 +#define VK_UUID_SIZE 16 +#define VK_MAX_MEMORY_TYPES 32 +#define VK_MAX_MEMORY_HEAPS 16 + +#define VK_MAKE_VERSION(major, minor, patch) (((major) << 22) | ((minor) << 12) | (patch)) +#define VK_API_VERSION_1_0 VK_MAKE_VERSION(1, 0, 0) +#define VK_API_VERSION_1_1 VK_MAKE_VERSION(1, 1, 0) +#define VK_API_VERSION_1_2 VK_MAKE_VERSION(1, 2, 0) +#define VK_API_VERSION_1_3 VK_MAKE_VERSION(1, 3, 0) +#define VK_API_VERSION_MAJOR(version) ((uint32_t)(version) >> 22) +#define VK_API_VERSION_MINOR(version) (((uint32_t)(version) >> 12) & 0x3FF) +#define VK_API_VERSION_PATCH(version) ((uint32_t)(version) & 0xFFF) + +#define VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME "VK_KHR_get_physical_device_properties2" +#define VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME "VK_NV_cooperative_matrix2" +#define VK_EXT_MEMORY_BUDGET_EXTENSION_NAME "VK_EXT_memory_budget" + + +typedef enum VkResult { + VK_SUCCESS = 0, + VK_NOT_READY = 1, + VK_TIMEOUT = 2, + VK_EVENT_SET = 3, + VK_EVENT_RESET = 4, + VK_INCOMPLETE = 5, + VK_ERROR_OUT_OF_HOST_MEMORY = -1, + VK_ERROR_OUT_OF_DEVICE_MEMORY = -2, + VK_ERROR_INITIALIZATION_FAILED = -3, + VK_ERROR_DEVICE_LOST = -4, + VK_ERROR_MEMORY_MAP_FAILED = -5, + VK_ERROR_LAYER_NOT_PRESENT = -6, + VK_ERROR_EXTENSION_NOT_PRESENT = -7, + VK_ERROR_FEATURE_NOT_PRESENT = -8, + VK_ERROR_INCOMPATIBLE_DRIVER = -9, + VK_ERROR_TOO_MANY_OBJECTS = -10, + VK_ERROR_FORMAT_NOT_SUPPORTED = -11, + VK_ERROR_FRAGMENTED_POOL = -12, + VK_ERROR_UNKNOWN = -13, + VK_ERROR_OUT_OF_POOL_MEMORY = -1000069000, + VK_ERROR_INVALID_EXTERNAL_HANDLE = -1000072003, + VK_ERROR_FRAGMENTATION = -1000168000, + VK_ERROR_INVALID_OPAQUE_CAPTURE_ADDRESS = -1000257000, + VK_PIPELINE_COMPILE_REQUIRED = 1000297000, + VK_ERROR_SURFACE_LOST_KHR = -1000000000, + VK_ERROR_NATIVE_WINDOW_IN_USE_KHR = -1000000001, + VK_SUBOPTIMAL_KHR = 1000001003, + VK_ERROR_OUT_OF_DATE_KHR = -1000001004, + VK_ERROR_INCOMPATIBLE_DISPLAY_KHR = -1000003001, + VK_ERROR_VALIDATION_FAILED_EXT = -1000011001, + VK_ERROR_INVALID_SHADER_NV = -1000012000, + VK_ERROR_IMAGE_USAGE_NOT_SUPPORTED_KHR = -1000158000, + VK_ERROR_VIDEO_PICTURE_LAYOUT_NOT_SUPPORTED_KHR = -1000158001, + VK_ERROR_VIDEO_PROFILE_OPERATION_NOT_SUPPORTED_KHR = -1000158002, + VK_ERROR_VIDEO_PROFILE_FORMAT_NOT_SUPPORTED_KHR = -1000158003, + VK_ERROR_VIDEO_PROFILE_CODEC_NOT_SUPPORTED_KHR = -1000158004, + VK_ERROR_VIDEO_STD_VERSION_NOT_SUPPORTED_KHR = -1000158005, + VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT = -1000158006, + VK_ERROR_NOT_PERMITTED_KHR = -1000174001, + VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT = -1000255000, + 
VK_THREAD_IDLE_KHR = 1000268000, + VK_THREAD_DONE_KHR = 1000268001, + VK_OPERATION_DEFERRED_KHR = 1000268002, + VK_OPERATION_NOT_DEFERRED_KHR = 1000268003, + VK_ERROR_COMPRESSION_EXHAUSTED_EXT = -1000338000, + VK_RESULT_MAX_ENUM = 0x7FFFFFFF +} VkResult; + +typedef enum VkStructureType { + VK_STRUCTURE_TYPE_APPLICATION_INFO = 0, + VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO = 1, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2 = 1000059000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2 = 1000059006, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES = 1000071004, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT = 1000237002, + VK_STRUCTURE_TYPE_MAX_ENUM = 0x7FFFFFFF +} VkStructureType; + +typedef enum VkPhysicalDeviceType { + VK_PHYSICAL_DEVICE_TYPE_OTHER = 0, + VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU = 1, + VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU = 2, + VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU = 3, + VK_PHYSICAL_DEVICE_TYPE_CPU = 4, + VK_PHYSICAL_DEVICE_TYPE_MAX_ENUM = 0x7FFFFFFF +} VkPhysicalDeviceType; + +typedef enum VkSystemAllocationScope { + VK_SYSTEM_ALLOCATION_SCOPE_COMMAND = 0, + VK_SYSTEM_ALLOCATION_SCOPE_OBJECT = 1, + VK_SYSTEM_ALLOCATION_SCOPE_CACHE = 2, + VK_SYSTEM_ALLOCATION_SCOPE_DEVICE = 3, + VK_SYSTEM_ALLOCATION_SCOPE_INSTANCE = 4, + VK_SYSTEM_ALLOCATION_SCOPE_MAX_ENUM = 0x7FFFFFFF +} VkSystemAllocationScope; + +typedef enum VkInternalAllocationType { + VK_INTERNAL_ALLOCATION_TYPE_EXECUTABLE = 0, + VK_INTERNAL_ALLOCATION_TYPE_NON_EXECUTABLE = 1, + VK_INTERNAL_ALLOCATION_TYPE_MAX_ENUM = 0x7FFFFFFF +} VkInternalAllocationType; + +#define VK_MEMORY_HEAP_DEVICE_LOCAL_BIT 0x00000001 + +typedef struct VkExtensionProperties { + char extensionName[VK_MAX_EXTENSION_NAME_SIZE]; + uint32_t specVersion; +} VkExtensionProperties; + +typedef struct VkPhysicalDeviceProperties { + uint32_t apiVersion; + uint32_t driverVersion; + uint32_t vendorID; + uint32_t deviceID; + uint32_t deviceType; + char deviceName[VK_MAX_DESCRIPTION_SIZE]; + uint8_t pipelineCacheUUID[VK_UUID_SIZE]; +} VkPhysicalDeviceProperties; + +typedef struct VkPhysicalDeviceProperties2 { + VkStructureType sType; + void* pNext; + VkPhysicalDeviceProperties properties; +} VkPhysicalDeviceProperties2; + +typedef struct VkPhysicalDeviceIDProperties { + VkStructureType sType; + void* pNext; + uint8_t deviceUUID[VK_UUID_SIZE]; + uint8_t driverUUID[VK_UUID_SIZE]; + uint8_t deviceLUID[8]; + uint32_t deviceNodeMask; + VkBool32 deviceLUIDValid; +} VkPhysicalDeviceIDProperties; + +typedef struct VkMemoryType { + uint32_t propertyFlags; + uint32_t heapIndex; +} VkMemoryType; + +typedef struct VkMemoryHeap { + VkDeviceSize size; + uint32_t flags; +} VkMemoryHeap; + +typedef struct VkPhysicalDeviceMemoryProperties { + uint32_t memoryTypeCount; + VkMemoryType memoryTypes[VK_MAX_MEMORY_TYPES]; + uint32_t memoryHeapCount; + VkMemoryHeap memoryHeaps[VK_MAX_MEMORY_HEAPS]; +} VkPhysicalDeviceMemoryProperties; + +typedef struct VkPhysicalDeviceMemoryProperties2 { + VkStructureType sType; + void* pNext; + VkPhysicalDeviceMemoryProperties memoryProperties; +} VkPhysicalDeviceMemoryProperties2; + +typedef struct VkPhysicalDeviceMemoryBudgetPropertiesEXT { + VkStructureType sType; + void* pNext; + VkDeviceSize heapBudget[VK_MAX_MEMORY_HEAPS]; + VkDeviceSize heapUsage[VK_MAX_MEMORY_HEAPS]; +} VkPhysicalDeviceMemoryBudgetPropertiesEXT; + +typedef struct VkApplicationInfo { + VkStructureType sType; + const void* pNext; + const char* pApplicationName; + uint32_t applicationVersion; + const char* pEngineName; + uint32_t engineVersion; + 
uint32_t apiVersion; +} VkApplicationInfo; + +typedef struct VkInstanceCreateInfo { + VkStructureType sType; + const void* pNext; + uint32_t flags; + const VkApplicationInfo* pApplicationInfo; + uint32_t enabledLayerCount; + const char* const* ppEnabledLayerNames; + uint32_t enabledExtensionCount; + const char* const* ppEnabledExtensionNames; +} VkInstanceCreateInfo; + +typedef struct VkAllocationCallbacks { + void* pUserData; + void* (*pfnAllocation)(void* pUserData, size_t size, size_t alignment, VkSystemAllocationScope allocationScope); + void* (*pfnReallocation)(void* pUserData, void* pOriginal, size_t size, size_t alignment, VkSystemAllocationScope allocationScope); + void (*pfnFree)(void* pUserData, void* pMemory); + void (*pfnInternalAllocation)(void* pUserData, size_t size, VkInternalAllocationType allocationType, VkSystemAllocationScope allocationScope); + void (*pfnInternalFree)(void* pUserData, size_t size, VkInternalAllocationType allocationType, VkSystemAllocationScope allocationScope); +} VkAllocationCallbacks; typedef struct { void* vk_handle; From ede4081253337eb420aad0e0ef77b27091e861d3 Mon Sep 17 00:00:00 2001 From: Masato Nakasaka Date: Tue, 16 Sep 2025 17:00:17 +0900 Subject: [PATCH 073/172] Fix compile error in Mac Metal is preferred so we're disabling Vulkan for now --- discover/gpu_info_vulkan.c | 3 +++ 1 file changed, 3 insertions(+) diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index 9ebe79e82..0ff0ae581 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -1,3 +1,4 @@ +#ifndef __APPLE__ #include "gpu_info_vulkan.h" #include @@ -236,3 +237,5 @@ void vk_release(vk_handle_t rh) { UNLOAD_LIBRARY(rh.vk_handle); rh.vk_handle = NULL; } + +#endif // __APPLE__ \ No newline at end of file From 7a6b09ebae5bfaa687ffe006399ddbefc9eaa44c Mon Sep 17 00:00:00 2001 From: "Nakasaka, Masato" Date: Tue, 16 Sep 2025 17:18:49 +0900 Subject: [PATCH 074/172] Removed unused code Fix linter error in CI --- discover/vulkan_common.go | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 discover/vulkan_common.go diff --git a/discover/vulkan_common.go b/discover/vulkan_common.go deleted file mode 100644 index 4dccbade9..000000000 --- a/discover/vulkan_common.go +++ /dev/null @@ -1,19 +0,0 @@ -package discover - -import ( - "log/slog" - "strings" -) - -func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { - ids := []string{} - for _, info := range gpuInfo { - if info.Library != "vulkan" { - // TODO shouldn't happen if things are wired correctly... 
- slog.Debug("vkGetVisibleDevicesEnv skipping over non-vulkan device", "library", info.Library) - continue - } - ids = append(ids, info.ID) - } - return "GGML_VK_VISIBLE_DEVICES", strings.Join(ids, ",") -} From eb7b5ce9f44ca8e62ef44c182e6221d21a4a0fb7 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 16 Sep 2025 22:14:05 +0200 Subject: [PATCH 075/172] Fix patches apply --- ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 671323ad0..d73cdf176 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10232,6 +10232,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ggml_vk_destroy_buffer(ctx->dev_buffer); delete ctx; + delete buffer; } static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { @@ -10375,6 +10376,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); + delete buffer; } static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { From 176d30744e38a6af43dc05c8249b305abb8fb4b8 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 16 Sep 2025 22:48:24 +0200 Subject: [PATCH 076/172] fixing lint error --- discover/types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/types.go b/discover/types.go index cde23d01b..b0c607a87 100644 --- a/discover/types.go +++ b/discover/types.go @@ -95,7 +95,7 @@ type OneapiGPUInfoList []OneapiGPUInfo type VulkanGPUInfo struct { GpuInfo - index int + index int //nolint:unused,nolintlint } type VulkanGPUInfoList []VulkanGPUInfo From 73441c97806b70609b597ae7966b43a4544fe507 Mon Sep 17 00:00:00 2001 From: "Nakasaka, Masato" Date: Wed, 17 Sep 2025 15:11:13 +0900 Subject: [PATCH 077/172] Removed unneeded function call Somehow removing this call fixed the crashing when Vulkan header was removed --- discover/gpu_info_vulkan.c | 3 --- 1 file changed, 3 deletions(-) diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index 0ff0ae581..b9599978d 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -4,9 +4,6 @@ #include int is_extension_supported(vk_handle_t* rh, VkPhysicalDevice device, char* extension) { - VkPhysicalDeviceProperties properties; - (*rh->vkGetPhysicalDeviceProperties)(device, &properties); - uint32_t extensionCount; (*rh->vkEnumerateDeviceExtensionProperties)(device, NULL, &extensionCount, NULL); From 6cf4e0a7c8c4b2b6b7f0b0b6e86fd59e9b67c8bf Mon Sep 17 00:00:00 2001 From: "Nakasaka, Masato" Date: Wed, 17 Sep 2025 15:21:24 +0900 Subject: [PATCH 078/172] added missing NL --- discover/gpu_info_vulkan.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index b9599978d..800da6ee0 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -235,4 +235,4 @@ void vk_release(vk_handle_t rh) { rh.vk_handle = NULL; } -#endif // __APPLE__ \ No newline at end of file +#endif // __APPLE__ From 
45430ded4be90ce7347d7e6147279dc73da470d9 Mon Sep 17 00:00:00 2001 From: "Nakasaka, Masato" Date: Wed, 17 Sep 2025 16:04:43 +0900 Subject: [PATCH 079/172] Fixed missing members in Vulkan header also added zero clear for some structs --- discover/gpu_info_vulkan.c | 11 ++- discover/gpu_info_vulkan.h | 145 ++++++++++++++++++++++++++++++++++--- 2 files changed, 142 insertions(+), 14 deletions(-) diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index 800da6ee0..65033ad8a 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -4,6 +4,9 @@ #include int is_extension_supported(vk_handle_t* rh, VkPhysicalDevice device, char* extension) { + VkPhysicalDeviceProperties properties = {}; + (*rh->vkGetPhysicalDeviceProperties)(device, &properties); + uint32_t extensionCount; (*rh->vkEnumerateDeviceExtensionProperties)(device, NULL, &extensionCount, NULL); @@ -129,7 +132,7 @@ int vk_check_flash_attention(vk_handle_t rh, int i) { return 0; } - VkPhysicalDeviceProperties properties; + VkPhysicalDeviceProperties properties = {}; (*rh.vkGetPhysicalDeviceProperties)(devices[i], &properties); int supports_nv_coopmat2 = is_extension_supported(&rh, devices[i], VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME); @@ -159,7 +162,7 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { return; } - VkPhysicalDeviceProperties properties; + VkPhysicalDeviceProperties properties = {}; (*rh.vkGetPhysicalDeviceProperties)(devices[i], &properties); int supports_budget = is_extension_supported(&rh, devices[i], VK_EXT_MEMORY_BUDGET_EXTENSION_NAME); @@ -184,11 +187,11 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { device_props2.pNext = &id_props; (*rh.vkGetPhysicalDeviceProperties2)(devices[i], &device_props2); - VkPhysicalDeviceMemoryBudgetPropertiesEXT physical_device_memory_budget_properties; + VkPhysicalDeviceMemoryBudgetPropertiesEXT physical_device_memory_budget_properties = {}; physical_device_memory_budget_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT; physical_device_memory_budget_properties.pNext = NULL; - VkPhysicalDeviceMemoryProperties2 device_memory_properties; + VkPhysicalDeviceMemoryProperties2 device_memory_properties = {}; device_memory_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2; device_memory_properties.pNext = &physical_device_memory_budget_properties; diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index 470b230c5..5c6ab85e6 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -12,12 +12,18 @@ typedef uint32_t VkFlags; typedef uint32_t VkBool32; typedef uint64_t VkDeviceSize; typedef uint32_t VkSampleMask; +typedef VkFlags VkSampleCountFlags; +typedef VkFlags VkMemoryPropertyFlags; +typedef VkFlags VkMemoryHeapFlags; +typedef VkFlags VkInstanceCreateFlags; -#define VK_MAX_EXTENSION_NAME_SIZE 256 -#define VK_MAX_DESCRIPTION_SIZE 256 -#define VK_UUID_SIZE 16 -#define VK_MAX_MEMORY_TYPES 32 -#define VK_MAX_MEMORY_HEAPS 16 +#define VK_MAX_EXTENSION_NAME_SIZE 256U +#define VK_MAX_DESCRIPTION_SIZE 256U +#define VK_LUID_SIZE 8U +#define VK_UUID_SIZE 16U +#define VK_MAX_MEMORY_TYPES 32U +#define VK_MAX_MEMORY_HEAPS 16U +#define VK_MAX_PHYSICAL_DEVICE_NAME_SIZE 256U #define VK_MAKE_VERSION(major, minor, patch) (((major) << 22) | ((minor) << 12) | (patch)) #define VK_API_VERSION_1_0 VK_MAKE_VERSION(1, 0, 0) @@ -123,14 +129,133 @@ typedef struct VkExtensionProperties { uint32_t specVersion; } VkExtensionProperties; +typedef struct 
VkPhysicalDeviceLimits { + uint32_t maxImageDimension1D; + uint32_t maxImageDimension2D; + uint32_t maxImageDimension3D; + uint32_t maxImageDimensionCube; + uint32_t maxImageArrayLayers; + uint32_t maxTexelBufferElements; + uint32_t maxUniformBufferRange; + uint32_t maxStorageBufferRange; + uint32_t maxPushConstantsSize; + uint32_t maxMemoryAllocationCount; + uint32_t maxSamplerAllocationCount; + VkDeviceSize bufferImageGranularity; + VkDeviceSize sparseAddressSpaceSize; + uint32_t maxBoundDescriptorSets; + uint32_t maxPerStageDescriptorSamplers; + uint32_t maxPerStageDescriptorUniformBuffers; + uint32_t maxPerStageDescriptorStorageBuffers; + uint32_t maxPerStageDescriptorSampledImages; + uint32_t maxPerStageDescriptorStorageImages; + uint32_t maxPerStageDescriptorInputAttachments; + uint32_t maxPerStageResources; + uint32_t maxDescriptorSetSamplers; + uint32_t maxDescriptorSetUniformBuffers; + uint32_t maxDescriptorSetUniformBuffersDynamic; + uint32_t maxDescriptorSetStorageBuffers; + uint32_t maxDescriptorSetStorageBuffersDynamic; + uint32_t maxDescriptorSetSampledImages; + uint32_t maxDescriptorSetStorageImages; + uint32_t maxDescriptorSetInputAttachments; + uint32_t maxVertexInputAttributes; + uint32_t maxVertexInputBindings; + uint32_t maxVertexInputAttributeOffset; + uint32_t maxVertexInputBindingStride; + uint32_t maxVertexOutputComponents; + uint32_t maxTessellationGenerationLevel; + uint32_t maxTessellationPatchSize; + uint32_t maxTessellationControlPerVertexInputComponents; + uint32_t maxTessellationControlPerVertexOutputComponents; + uint32_t maxTessellationControlPerPatchOutputComponents; + uint32_t maxTessellationControlTotalOutputComponents; + uint32_t maxTessellationEvaluationInputComponents; + uint32_t maxTessellationEvaluationOutputComponents; + uint32_t maxGeometryShaderInvocations; + uint32_t maxGeometryInputComponents; + uint32_t maxGeometryOutputComponents; + uint32_t maxGeometryOutputVertices; + uint32_t maxGeometryTotalOutputComponents; + uint32_t maxFragmentInputComponents; + uint32_t maxFragmentOutputAttachments; + uint32_t maxFragmentDualSrcAttachments; + uint32_t maxFragmentCombinedOutputResources; + uint32_t maxComputeSharedMemorySize; + uint32_t maxComputeWorkGroupCount[3]; + uint32_t maxComputeWorkGroupInvocations; + uint32_t maxComputeWorkGroupSize[3]; + uint32_t subPixelPrecisionBits; + uint32_t subTexelPrecisionBits; + uint32_t mipmapPrecisionBits; + uint32_t maxDrawIndexedIndexValue; + uint32_t maxDrawIndirectCount; + float maxSamplerLodBias; + float maxSamplerAnisotropy; + uint32_t maxViewports; + uint32_t maxViewportDimensions[2]; + float viewportBoundsRange[2]; + uint32_t viewportSubPixelBits; + size_t minMemoryMapAlignment; + VkDeviceSize minTexelBufferOffsetAlignment; + VkDeviceSize minUniformBufferOffsetAlignment; + VkDeviceSize minStorageBufferOffsetAlignment; + int32_t minTexelOffset; + uint32_t maxTexelOffset; + int32_t minTexelGatherOffset; + uint32_t maxTexelGatherOffset; + float minInterpolationOffset; + float maxInterpolationOffset; + uint32_t subPixelInterpolationOffsetBits; + uint32_t maxFramebufferWidth; + uint32_t maxFramebufferHeight; + uint32_t maxFramebufferLayers; + VkSampleCountFlags framebufferColorSampleCounts; + VkSampleCountFlags framebufferDepthSampleCounts; + VkSampleCountFlags framebufferStencilSampleCounts; + VkSampleCountFlags framebufferNoAttachmentsSampleCounts; + uint32_t maxColorAttachments; + VkSampleCountFlags sampledImageColorSampleCounts; + VkSampleCountFlags sampledImageIntegerSampleCounts; + VkSampleCountFlags 
sampledImageDepthSampleCounts; + VkSampleCountFlags sampledImageStencilSampleCounts; + VkSampleCountFlags storageImageSampleCounts; + uint32_t maxSampleMaskWords; + VkBool32 timestampComputeAndGraphics; + float timestampPeriod; + uint32_t maxClipDistances; + uint32_t maxCullDistances; + uint32_t maxCombinedClipAndCullDistances; + uint32_t discreteQueuePriorities; + float pointSizeRange[2]; + float lineWidthRange[2]; + float pointSizeGranularity; + float lineWidthGranularity; + VkBool32 strictLines; + VkBool32 standardSampleLocations; + VkDeviceSize optimalBufferCopyOffsetAlignment; + VkDeviceSize optimalBufferCopyRowPitchAlignment; + VkDeviceSize nonCoherentAtomSize; +} VkPhysicalDeviceLimits; + +typedef struct VkPhysicalDeviceSparseProperties { + VkBool32 residencyStandard2DBlockShape; + VkBool32 residencyStandard2DMultisampleBlockShape; + VkBool32 residencyStandard3DBlockShape; + VkBool32 residencyAlignedMipSize; + VkBool32 residencyNonResidentStrict; +} VkPhysicalDeviceSparseProperties; + typedef struct VkPhysicalDeviceProperties { uint32_t apiVersion; uint32_t driverVersion; uint32_t vendorID; uint32_t deviceID; uint32_t deviceType; - char deviceName[VK_MAX_DESCRIPTION_SIZE]; + char deviceName[VK_MAX_PHYSICAL_DEVICE_NAME_SIZE]; uint8_t pipelineCacheUUID[VK_UUID_SIZE]; + VkPhysicalDeviceLimits limits; + VkPhysicalDeviceSparseProperties sparseProperties; } VkPhysicalDeviceProperties; typedef struct VkPhysicalDeviceProperties2 { @@ -144,19 +269,19 @@ typedef struct VkPhysicalDeviceIDProperties { void* pNext; uint8_t deviceUUID[VK_UUID_SIZE]; uint8_t driverUUID[VK_UUID_SIZE]; - uint8_t deviceLUID[8]; + uint8_t deviceLUID[VK_LUID_SIZE]; uint32_t deviceNodeMask; VkBool32 deviceLUIDValid; } VkPhysicalDeviceIDProperties; typedef struct VkMemoryType { - uint32_t propertyFlags; + VkMemoryPropertyFlags propertyFlags; uint32_t heapIndex; } VkMemoryType; typedef struct VkMemoryHeap { VkDeviceSize size; - uint32_t flags; + VkMemoryHeapFlags flags; } VkMemoryHeap; typedef struct VkPhysicalDeviceMemoryProperties { @@ -192,7 +317,7 @@ typedef struct VkApplicationInfo { typedef struct VkInstanceCreateInfo { VkStructureType sType; const void* pNext; - uint32_t flags; + VkInstanceCreateFlags flags; const VkApplicationInfo* pApplicationInfo; uint32_t enabledLayerCount; const char* const* ppEnabledLayerNames; From ac9d59cf6900057a18e0a4b64925d605f4250546 Mon Sep 17 00:00:00 2001 From: "Nakasaka, Masato" Date: Wed, 17 Sep 2025 16:59:23 +0900 Subject: [PATCH 080/172] Fixed wrong structure ID --- discover/gpu_info_vulkan.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index 5c6ab85e6..b85f18e6b 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -91,10 +91,10 @@ typedef enum VkResult { typedef enum VkStructureType { VK_STRUCTURE_TYPE_APPLICATION_INFO = 0, VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO = 1, - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2 = 1000059000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2 = 1000059001, VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2 = 1000059006, VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES = 1000071004, - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT = 1000237002, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT = 1000237000, VK_STRUCTURE_TYPE_MAX_ENUM = 0x7FFFFFFF } VkStructureType; From d0b5247084fc1f8ad2cc7c16b9c2d08196371498 Mon Sep 17 00:00:00 2001 From: "Nakasaka, Masato" Date: Thu, 18 Sep 2025 08:40:52 +0900 Subject: 
[PATCH 081/172] Fixed Vulkan header More aligned with official header definition now --- discover/gpu_info_vulkan.h | 46 +++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index b85f18e6b..42e4b1610 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -8,15 +8,6 @@ VK_DEFINE_HANDLE(VkInstance) VK_DEFINE_HANDLE(VkPhysicalDevice) -typedef uint32_t VkFlags; -typedef uint32_t VkBool32; -typedef uint64_t VkDeviceSize; -typedef uint32_t VkSampleMask; -typedef VkFlags VkSampleCountFlags; -typedef VkFlags VkMemoryPropertyFlags; -typedef VkFlags VkMemoryHeapFlags; -typedef VkFlags VkInstanceCreateFlags; - #define VK_MAX_EXTENSION_NAME_SIZE 256U #define VK_MAX_DESCRIPTION_SIZE 256U #define VK_LUID_SIZE 8U @@ -25,19 +16,32 @@ typedef VkFlags VkInstanceCreateFlags; #define VK_MAX_MEMORY_HEAPS 16U #define VK_MAX_PHYSICAL_DEVICE_NAME_SIZE 256U -#define VK_MAKE_VERSION(major, minor, patch) (((major) << 22) | ((minor) << 12) | (patch)) -#define VK_API_VERSION_1_0 VK_MAKE_VERSION(1, 0, 0) -#define VK_API_VERSION_1_1 VK_MAKE_VERSION(1, 1, 0) -#define VK_API_VERSION_1_2 VK_MAKE_VERSION(1, 2, 0) -#define VK_API_VERSION_1_3 VK_MAKE_VERSION(1, 3, 0) -#define VK_API_VERSION_MAJOR(version) ((uint32_t)(version) >> 22) -#define VK_API_VERSION_MINOR(version) (((uint32_t)(version) >> 12) & 0x3FF) -#define VK_API_VERSION_PATCH(version) ((uint32_t)(version) & 0xFFF) +#define VK_MAKE_VERSION(major, minor, patch) \ + ((((uint32_t)(major)) << 22U) | (((uint32_t)(minor)) << 12U) | ((uint32_t)(patch))) + +#define VK_MAKE_API_VERSION(variant, major, minor, patch) \ + ((((uint32_t)(variant)) << 29U) | (((uint32_t)(major)) << 22U) | (((uint32_t)(minor)) << 12U) | ((uint32_t)(patch))) + +#define VK_API_VERSION_1_0 VK_MAKE_API_VERSION(0, 1, 0, 0)// Patch version should always be set to 0 +#define VK_API_VERSION_1_1 VK_MAKE_API_VERSION(0, 1, 1, 0)// Patch version should always be set to 0 +#define VK_API_VERSION_1_2 VK_MAKE_API_VERSION(0, 1, 2, 0)// Patch version should always be set to 0 +#define VK_API_VERSION_1_3 VK_MAKE_API_VERSION(0, 1, 3, 0)// Patch version should always be set to 0 +#define VK_API_VERSION_MAJOR(version) (((uint32_t)(version) >> 22U) & 0x7FU) +#define VK_API_VERSION_MINOR(version) (((uint32_t)(version) >> 12U) & 0x3FFU) +#define VK_API_VERSION_PATCH(version) ((uint32_t)(version) & 0xFFFU) #define VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME "VK_KHR_get_physical_device_properties2" #define VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME "VK_NV_cooperative_matrix2" #define VK_EXT_MEMORY_BUDGET_EXTENSION_NAME "VK_EXT_memory_budget" +typedef uint32_t VkFlags; +typedef uint32_t VkBool32; +typedef uint64_t VkDeviceSize; +typedef uint32_t VkSampleMask; +typedef VkFlags VkSampleCountFlags; +typedef VkFlags VkMemoryPropertyFlags; +typedef VkFlags VkMemoryHeapFlags; +typedef VkFlags VkInstanceCreateFlags; typedef enum VkResult { VK_SUCCESS = 0, @@ -122,7 +126,13 @@ typedef enum VkInternalAllocationType { VK_INTERNAL_ALLOCATION_TYPE_MAX_ENUM = 0x7FFFFFFF } VkInternalAllocationType; -#define VK_MEMORY_HEAP_DEVICE_LOCAL_BIT 0x00000001 +typedef enum VkMemoryHeapFlagBits { + VK_MEMORY_HEAP_DEVICE_LOCAL_BIT = 0x00000001, + VK_MEMORY_HEAP_MULTI_INSTANCE_BIT = 0x00000002, + VK_MEMORY_HEAP_TILE_MEMORY_BIT_QCOM = 0x00000008, + VK_MEMORY_HEAP_MULTI_INSTANCE_BIT_KHR = VK_MEMORY_HEAP_MULTI_INSTANCE_BIT, + VK_MEMORY_HEAP_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VkMemoryHeapFlagBits; typedef struct 
VkExtensionProperties { char extensionName[VK_MAX_EXTENSION_NAME_SIZE]; From 62b2265f9db818de0590a7ff98da259fbed07879 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Fri, 19 Sep 2025 06:52:05 +0200 Subject: [PATCH 082/172] buildvulkanAsSeperateFunction --- scripts/build_windows.ps1 | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index ff853b5a3..5cb8faf2f 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -180,15 +180,18 @@ function buildROCm() { & cmake --install build --component "HIP" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } - if ($env:VULKAN_SDK) { - write-host "Building Vulkan backend libraries" - & cmake --fresh --preset Vulkan --install-prefix $script:DIST_DIR - if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - & cmake --build --preset Vulkan --config Release --parallel $script:JOBS - if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - & cmake --install build --component Vulkan --strip - if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - } + } +} + +function buildVulkan(){ + if ($env:VULKAN_SDK) { + write-host "Building Vulkan backend libraries" + & cmake --fresh --preset Vulkan --install-prefix $script:DIST_DIR + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset Vulkan --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component Vulkan --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } } @@ -305,6 +308,7 @@ try { buildCUDA12 buildCUDA13 buildROCm + buildVulkan buildOllama buildApp gatherDependencies From 0f543fdb1eaf954cc5ab6be0d434b4fcfa7ed977 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 08:04:11 +0200 Subject: [PATCH 083/172] Vulkan on Windows Test --- .github/workflows/test.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e470540a2..9504eaee1 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -92,6 +92,8 @@ jobs: - preset: ROCm install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' + - preset: Vulkan + install: https://sdk.lunarg.com/sdk/download/1.4.313.2/windows/vulkansdk-windows-X64-1.4.313.2.exe runs-on: windows steps: - run: | From 6bbc054705cd0380dc917fc1a2ee09c6be79f6ec Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 08:35:58 +0200 Subject: [PATCH 084/172] temporarly comment out gate to run windows task --- .github/workflows/test.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9504eaee1..0f9e99189 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -71,8 +71,8 @@ jobs: cmake --build --preset ${{ matrix.preset }} --parallel windows: - needs: [changes] - if: needs.changes.outputs.changed == 'True' + # needs: [changes] + # if: needs.changes.outputs.changed == 'True' strategy: matrix: include: From a4461bc0d4e99686b30f25f8dd4cadba5229a048 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 08:46:59 +0200 Subject: [PATCH 085/172] use temporarly windows-latest for build --- 
.github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 0f9e99189..feb81aafa 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -94,7 +94,7 @@ jobs: flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' - preset: Vulkan install: https://sdk.lunarg.com/sdk/download/1.4.313.2/windows/vulkansdk-windows-X64-1.4.313.2.exe - runs-on: windows + runs-on: windows-latest steps: - run: | choco install -y --no-progress ccache ninja From c84ac535790ab49703bd3c9df1562a0e855b48e8 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:00:26 +0200 Subject: [PATCH 086/172] Commenting out other presets to build vulkan --- .github/workflows/test.yaml | 46 +++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index feb81aafa..95784a0a4 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -76,22 +76,22 @@ jobs: strategy: matrix: include: - - preset: CPU - - preset: CUDA - install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe - flags: '-DCMAKE_CUDA_ARCHITECTURES=80' - cuda-components: - - '"cudart"' - - '"nvcc"' - - '"cublas"' - - '"cublas_dev"' - - '"crt"' - - '"nvvm"' - - '"nvptxcompiler"' - cuda-version: '13.0' - - preset: ROCm - install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe - flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' + #- preset: CPU + #- preset: CUDA + # install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + # flags: '-DCMAKE_CUDA_ARCHITECTURES=80' + # cuda-components: + # - '"cudart"' + # - '"nvcc"' + # - '"cublas"' + # - '"cublas_dev"' + # - '"crt"' + # - '"nvvm"' + # - '"nvptxcompiler"' + # cuda-version: '13.0' + #- preset: ROCm + # install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe + # flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' - preset: Vulkan install: https://sdk.lunarg.com/sdk/download/1.4.313.2/windows/vulkansdk-windows-X64-1.4.313.2.exe runs-on: windows-latest @@ -99,13 +99,14 @@ jobs: - run: | choco install -y --no-progress ccache ninja ccache -o cache_dir=${{ github.workspace }}\.ccache - - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' + - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || atrix.preset == 'Vulkan' id: cache-install uses: actions/cache/restore@v4 with: path: | C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm + C:\VulkanSDK key: ${{ matrix.install }} - if: matrix.preset == 'CUDA' name: Install CUDA ${{ matrix.cuda-version }} @@ -135,6 +136,17 @@ jobs: echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV 
-Append echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append + - if: matrix.preset == 'Vulkan' + name: Install Vulkan ${{ matrix.rocm-version }} + run: | + $ErrorActionPreference = "Stop" + if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { + Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" + Start-Process -FilePath .\install.exe -ArgumentList '-install' -NoNewWindow -Wait + } + + $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path + echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} uses: actions/cache/save@v4 with: From ed03bb7928dd3e0056dc3c4fa7ee04e27bc45b65 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:01:25 +0200 Subject: [PATCH 087/172] reenable cpu --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 95784a0a4..25994c118 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -76,7 +76,7 @@ jobs: strategy: matrix: include: - #- preset: CPU + - preset: CPU #- preset: CUDA # install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe # flags: '-DCMAKE_CUDA_ARCHITECTURES=80' From e2b38c391b1beb7a9ef8ae52b97ea5a998460920 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:02:55 +0200 Subject: [PATCH 088/172] commenting out error action stop --- .github/workflows/test.yaml | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 25994c118..cf3c040fe 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -77,21 +77,21 @@ jobs: matrix: include: - preset: CPU - #- preset: CUDA - # install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe - # flags: '-DCMAKE_CUDA_ARCHITECTURES=80' - # cuda-components: - # - '"cudart"' - # - '"nvcc"' - # - '"cublas"' - # - '"cublas_dev"' - # - '"crt"' - # - '"nvvm"' - # - '"nvptxcompiler"' - # cuda-version: '13.0' - #- preset: ROCm - # install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe - # flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' + - preset: CUDA + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + flags: '-DCMAKE_CUDA_ARCHITECTURES=80' + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' + - '"crt"' + - '"nvvm"' + - '"nvptxcompiler"' + cuda-version: '13.0' + - preset: ROCm + install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe + flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' - preset: Vulkan install: https://sdk.lunarg.com/sdk/download/1.4.313.2/windows/vulkansdk-windows-X64-1.4.313.2.exe runs-on: windows-latest 
@@ -122,8 +122,7 @@ jobs: echo "$cudaPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - if: matrix.preset == 'ROCm' name: Install ROCm ${{ matrix.rocm-version }} - run: | - $ErrorActionPreference = "Stop" + run: | if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" Start-Process -FilePath .\install.exe -ArgumentList '-install' -NoNewWindow -Wait From 45f7850e75f19f6691cee22f0b880fa13fcd114e Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:04:30 +0200 Subject: [PATCH 089/172] temporarly commenting out rocm --- .github/workflows/test.yaml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index cf3c040fe..a81abd5d2 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -89,9 +89,9 @@ jobs: - '"nvvm"' - '"nvptxcompiler"' cuda-version: '13.0' - - preset: ROCm - install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe - flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' + #- preset: ROCm + # install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe + # flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' - preset: Vulkan install: https://sdk.lunarg.com/sdk/download/1.4.313.2/windows/vulkansdk-windows-X64-1.4.313.2.exe runs-on: windows-latest @@ -99,7 +99,7 @@ jobs: - run: | choco install -y --no-progress ccache ninja ccache -o cache_dir=${{ github.workspace }}\.ccache - - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || atrix.preset == 'Vulkan' + - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' id: cache-install uses: actions/cache/restore@v4 with: @@ -122,7 +122,8 @@ jobs: echo "$cudaPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - if: matrix.preset == 'ROCm' name: Install ROCm ${{ matrix.rocm-version }} - run: | + run: | + $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" Start-Process -FilePath .\install.exe -ArgumentList '-install' -NoNewWindow -Wait From c972cf6d46817a01a24fefec81b2f47c29da86c8 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:12:14 +0200 Subject: [PATCH 090/172] set vulkan path --- .github/workflows/test.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a81abd5d2..1ba1afbf6 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -142,11 +142,12 @@ jobs: $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" - Start-Process -FilePath .\install.exe -ArgumentList '-install' -NoNewWindow -Wait + Start-Process -FilePath .\install.exe /S } $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path echo "$vulkanPath\bin" | Out-File 
-FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} uses: actions/cache/save@v4 with: From d1125ea3498c5d1339275a3aaeacfa26970caf80 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:14:02 +0200 Subject: [PATCH 091/172] comment out cude for faster turnaround --- .github/workflows/test.yaml | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1ba1afbf6..b5c5b0f51 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -77,18 +77,18 @@ jobs: matrix: include: - preset: CPU - - preset: CUDA - install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe - flags: '-DCMAKE_CUDA_ARCHITECTURES=80' - cuda-components: - - '"cudart"' - - '"nvcc"' - - '"cublas"' - - '"cublas_dev"' - - '"crt"' - - '"nvvm"' - - '"nvptxcompiler"' - cuda-version: '13.0' + #- preset: CUDA + # install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + # flags: '-DCMAKE_CUDA_ARCHITECTURES=80' + # cuda-components: + # - '"cudart"' + # - '"nvcc"' + # - '"cublas"' + # - '"cublas_dev"' + # - '"crt"' + # - '"nvvm"' + # - '"nvptxcompiler"' + # cuda-version: '13.0' #- preset: ROCm # install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe # flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' From 7e161f1dbfc199856fb521672cca53f72191e511 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:16:54 +0200 Subject: [PATCH 092/172] correct vulkan install --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b5c5b0f51..7dca3da76 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -142,7 +142,7 @@ jobs: $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" - Start-Process -FilePath .\install.exe /S + Start-Process -FilePath .\install.exe -ArgumentList '/S' } $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path From b4595f002271f83536271f645244d725d46163ad Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:31:58 +0200 Subject: [PATCH 093/172] correct vulkan silent install --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7dca3da76..7d36c9ea8 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -142,7 +142,7 @@ jobs: $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" - Start-Process -FilePath .\install.exe -ArgumentList '/S' + Start-Process -FilePath .\install.exe -ArgumentList '-c --am --al in' } $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path From 6e310d1cb6f716b24c990166f4f97e0b69fcfd94 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:37:25 +0200 Subject: [PATCH 094/172] fixed 
install command --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7d36c9ea8..eeeb74748 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -142,7 +142,7 @@ jobs: $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" - Start-Process -FilePath .\install.exe -ArgumentList '-c --am --al in' + Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait } $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path From b244c9f9f3db6d52c808695d72652f5cd9d98b1c Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:44:09 +0200 Subject: [PATCH 095/172] revert debugging changes (vulkan builds on windows) --- .github/workflows/test.yaml | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index eeeb74748..6a77f909b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -71,27 +71,27 @@ jobs: cmake --build --preset ${{ matrix.preset }} --parallel windows: - # needs: [changes] - # if: needs.changes.outputs.changed == 'True' + needs: [changes] + if: needs.changes.outputs.changed == 'True' strategy: matrix: include: - preset: CPU - #- preset: CUDA - # install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe - # flags: '-DCMAKE_CUDA_ARCHITECTURES=80' - # cuda-components: - # - '"cudart"' - # - '"nvcc"' - # - '"cublas"' - # - '"cublas_dev"' - # - '"crt"' - # - '"nvvm"' - # - '"nvptxcompiler"' - # cuda-version: '13.0' - #- preset: ROCm - # install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe - # flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' + - preset: CUDA + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + flags: '-DCMAKE_CUDA_ARCHITECTURES=80' + cuda-components: + - '"cudart"' + - '"nvcc"' + - '"cublas"' + - '"cublas_dev"' + - '"crt"' + - '"nvvm"' + - '"nvptxcompiler"' + cuda-version: '13.0' + - preset: ROCm + install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe + flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' - preset: Vulkan install: https://sdk.lunarg.com/sdk/download/1.4.313.2/windows/vulkansdk-windows-X64-1.4.313.2.exe runs-on: windows-latest From a0389785c760cea52a8887538317ea1cccc12fe1 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:45:36 +0200 Subject: [PATCH 096/172] revert windows-latest --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 6a77f909b..f9f812520 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -94,7 +94,7 @@ jobs: flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang 
-DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' - preset: Vulkan install: https://sdk.lunarg.com/sdk/download/1.4.313.2/windows/vulkansdk-windows-X64-1.4.313.2.exe - runs-on: windows-latest + runs-on: windows steps: - run: | choco install -y --no-progress ccache ninja From e29bb17613c76fa9afc308e68e7126208f256323 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 09:58:31 +0200 Subject: [PATCH 097/172] trying to build vulkan for linux --- .github/workflows/test.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f9f812520..66a1b9fb6 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -39,8 +39,8 @@ jobs: echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT linux: - needs: [changes] - if: needs.changes.outputs.changed == 'True' + #needs: [changes] + #if: needs.changes.outputs.changed == 'True' strategy: matrix: include: @@ -52,6 +52,8 @@ jobs: container: rocm/dev-ubuntu-22.04:6.1.2 extra-packages: rocm-libs flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm' + - preset: Vulkan + container: nvidia/vulkan:1.3-470 runs-on: linux container: ${{ matrix.container }} steps: From 236c2740171bee72d24d342264209a7681d4760a Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:00:14 +0200 Subject: [PATCH 098/172] temporarly disable cuda and rocm --- .github/workflows/test.yaml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 66a1b9fb6..d53659a3c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -45,16 +45,16 @@ jobs: matrix: include: - preset: CPU - - preset: CUDA - container: nvidia/cuda:13.0.0-devel-ubuntu22.04 - flags: '-DCMAKE_CUDA_ARCHITECTURES=87' - - preset: ROCm - container: rocm/dev-ubuntu-22.04:6.1.2 - extra-packages: rocm-libs - flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm' + #- preset: CUDA + # container: nvidia/cuda:13.0.0-devel-ubuntu22.04 + # flags: '-DCMAKE_CUDA_ARCHITECTURES=87' + #- preset: ROCm + # container: rocm/dev-ubuntu-22.04:6.1.2 + # extra-packages: rocm-libs + # flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm' - preset: Vulkan container: nvidia/vulkan:1.3-470 - runs-on: linux + runs-on: ubuntu-latest container: ${{ matrix.container }} steps: - uses: actions/checkout@v4 From af50fd5af7d4e10e7ebb2a76630513a4c603f49a Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:08:24 +0200 Subject: [PATCH 099/172] try again linux build --- .github/workflows/test.yaml | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d53659a3c..adfac78f6 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -53,7 +53,11 @@ jobs: # extra-packages: rocm-libs # flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm' - preset: Vulkan - container: nvidia/vulkan:1.3-470 + container: ubuntu:22.04 + extra-packages: > + mesa-vulkan-drivers vulkan-tools + libvulkan1 libvulkan-dev + lunarg-vulkan-sdk runs-on: ubuntu-latest container: ${{ matrix.container }} steps: @@ -61,7 +65,18 @@ jobs: - run: | [ -n "${{ matrix.container }}" ] || sudo=sudo $sudo apt-get update + # 
Add LunarG Vulkan SDK apt repo for Ubuntu 22.04 + if [ "${{ matrix.preset }}" = "Vulkan" ]; then + $sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common + wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg + echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313.2/ubuntu jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313.2-jammy.list + $sudo apt-get update + fi $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} + # Export VULKAN_SDK if provided by LunarG package (defensive) + if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then + echo "VULKAN_SDK=/usr" >> $GITHUB_ENV + fi env: DEBIAN_FRONTEND: noninteractive - uses: actions/cache@v4 From c91b494a8b842c44b2f85618628140e30377ffb8 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:10:10 +0200 Subject: [PATCH 100/172] fix version --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index adfac78f6..fb74891c8 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -69,7 +69,7 @@ jobs: if [ "${{ matrix.preset }}" = "Vulkan" ]; then $sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg - echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313.2/ubuntu jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313.2-jammy.list + echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313/ubuntu jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list $sudo apt-get update fi $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} From 475d2c2583544f8c2fa223a693ca8cce5217b1d2 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:15:29 +0200 Subject: [PATCH 101/172] trying to fix --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index fb74891c8..5d03b9579 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -69,7 +69,7 @@ jobs: if [ "${{ matrix.preset }}" = "Vulkan" ]; then $sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg - echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313/ubuntu jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list + echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313/dists jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list $sudo apt-get update fi $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} From 26df69a025dbc502378df715aa55f830380f900b Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:24:31 +0200 Subject: [PATCH 102/172] trying again --- 
.github/workflows/test.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 5d03b9579..a05e7f7bf 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -57,7 +57,7 @@ jobs: extra-packages: > mesa-vulkan-drivers vulkan-tools libvulkan1 libvulkan-dev - lunarg-vulkan-sdk + lunarg-vulkan-sdk=1.4.313.2 runs-on: ubuntu-latest container: ${{ matrix.container }} steps: @@ -67,9 +67,9 @@ jobs: $sudo apt-get update # Add LunarG Vulkan SDK apt repo for Ubuntu 22.04 if [ "${{ matrix.preset }}" = "Vulkan" ]; then - $sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common + $sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg - echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313/dists jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list + wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list $sudo apt-get update fi $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} From 62a8d66002955474dfac5c9a9512065ada7339e3 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:30:31 +0200 Subject: [PATCH 103/172] trying again --- .github/workflows/test.yaml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a05e7f7bf..589940a05 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -67,9 +67,10 @@ jobs: $sudo apt-get update # Add LunarG Vulkan SDK apt repo for Ubuntu 22.04 if [ "${{ matrix.preset }}" = "Vulkan" ]; then - $sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common + $sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg - wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list + # Use signed-by to bind the repo to the installed keyring to avoid NO_PUBKEY + echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-jammy.list > /dev/null $sudo apt-get update fi $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} From 0f86789808106235bdedf2d3a55f18fe6050a25b Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:31:44 +0200 Subject: [PATCH 104/172] fix version --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 589940a05..7b08e4c8d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -57,7 +57,7 @@ jobs: extra-packages: > mesa-vulkan-drivers vulkan-tools libvulkan1 libvulkan-dev - lunarg-vulkan-sdk=1.4.313.2 + lunarg-vulkan-sdk=1.4.313 runs-on: ubuntu-latest container: ${{ matrix.container }} steps: From 79a0f526b1faef6037281ed0101a4466a30b6dbf Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:33:23 +0200 
Subject: [PATCH 105/172] fixed vulkan-sdk name --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7b08e4c8d..bd0cd5b75 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -57,7 +57,7 @@ jobs: extra-packages: > mesa-vulkan-drivers vulkan-tools libvulkan1 libvulkan-dev - lunarg-vulkan-sdk=1.4.313 + vulkan-sdk=1.4.313 runs-on: ubuntu-latest container: ${{ matrix.container }} steps: From 3ccc18f1e16282eb0e231307953b00ab121f920b Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:36:48 +0200 Subject: [PATCH 106/172] try again --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index bd0cd5b75..7b08e4c8d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -57,7 +57,7 @@ jobs: extra-packages: > mesa-vulkan-drivers vulkan-tools libvulkan1 libvulkan-dev - vulkan-sdk=1.4.313 + lunarg-vulkan-sdk=1.4.313 runs-on: ubuntu-latest container: ${{ matrix.container }} steps: From a7557cf1a84beea97e9106490c812af24778d45e Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:39:05 +0200 Subject: [PATCH 107/172] trying again --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7b08e4c8d..9d82ae396 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -57,7 +57,7 @@ jobs: extra-packages: > mesa-vulkan-drivers vulkan-tools libvulkan1 libvulkan-dev - lunarg-vulkan-sdk=1.4.313 + vulkan-sdk=1.4.313.0 runs-on: ubuntu-latest container: ${{ matrix.container }} steps: From 19bc49de5f2cde07316986afaaa4d59356e6fa06 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:48:18 +0200 Subject: [PATCH 108/172] try without version number --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9d82ae396..369468d0a 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -57,7 +57,7 @@ jobs: extra-packages: > mesa-vulkan-drivers vulkan-tools libvulkan1 libvulkan-dev - vulkan-sdk=1.4.313.0 + lunarg-vulkan-sdk runs-on: ubuntu-latest container: ${{ matrix.container }} steps: From 6f546457de5093028069c2c97cd7743d6c950d1f Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:49:24 +0200 Subject: [PATCH 109/172] try again --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 369468d0a..f11262431 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -57,7 +57,7 @@ jobs: extra-packages: > mesa-vulkan-drivers vulkan-tools libvulkan1 libvulkan-dev - lunarg-vulkan-sdk + vulkan-sdk runs-on: ubuntu-latest container: ${{ matrix.container }} steps: From fe471917201e8ff298a80833b7b362711e151bb8 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 10:53:43 +0200 Subject: [PATCH 110/172] add some more extra --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f11262431..7c00f869c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -57,7 +57,7 @@ jobs: extra-packages: > mesa-vulkan-drivers 
vulkan-tools libvulkan1 libvulkan-dev - vulkan-sdk + vulkan-sdk cmake ccache g++ make runs-on: ubuntu-latest container: ${{ matrix.container }} steps: From 2098e6a8e3ab739999c93289e8291df0cd1a2783 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 11:00:37 +0200 Subject: [PATCH 111/172] trying to use version 1.4.313 --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7c00f869c..02178d32d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -70,7 +70,7 @@ jobs: $sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg # Use signed-by to bind the repo to the installed keyring to avoid NO_PUBKEY - echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-jammy.list > /dev/null + echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313 jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list > /dev/null $sudo apt-get update fi $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} From 04fba9ba09b8001726ac77505badec36a1b05451 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 11:03:09 +0200 Subject: [PATCH 112/172] revert debugging changes --- .github/workflows/test.yaml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 02178d32d..c9e0b917c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -39,26 +39,26 @@ jobs: echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT linux: - #needs: [changes] - #if: needs.changes.outputs.changed == 'True' + needs: [changes] + if: needs.changes.outputs.changed == 'True' strategy: matrix: include: - preset: CPU - #- preset: CUDA - # container: nvidia/cuda:13.0.0-devel-ubuntu22.04 - # flags: '-DCMAKE_CUDA_ARCHITECTURES=87' - #- preset: ROCm - # container: rocm/dev-ubuntu-22.04:6.1.2 - # extra-packages: rocm-libs - # flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm' + - preset: CUDA + container: nvidia/cuda:13.0.0-devel-ubuntu22.04 + flags: '-DCMAKE_CUDA_ARCHITECTURES=87' + - preset: ROCm + container: rocm/dev-ubuntu-22.04:6.1.2 + extra-packages: rocm-libs + flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm' - preset: Vulkan container: ubuntu:22.04 extra-packages: > mesa-vulkan-drivers vulkan-tools libvulkan1 libvulkan-dev vulkan-sdk cmake ccache g++ make - runs-on: ubuntu-latest + runs-on: linux container: ${{ matrix.container }} steps: - uses: actions/checkout@v4 From d26d920fb26575448ebc4d77ab42e6b15a2b7e54 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 15:18:39 +0200 Subject: [PATCH 113/172] Filter out already supported gpus --- discover/gpu.go | 56 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/discover/gpu.go b/discover/gpu.go index bce12d2d1..a96e5de5b 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -214,6 +214,30 @@ func GetCPUInfo() GpuInfoList { return GpuInfoList{cpus[0].GpuInfo} } +// gpuIDExistsInOtherBackends returns true 
if the given ID exists in CUDA, ROCm, or OneAPI lists. +// Note: It intentionally does not check Vulkan to avoid self-comparison during Vulkan discovery. +func gpuInfoExistsInOtherBackends(gpu VulkanGPUInfo) string { + for _, g := range cudaGPUs { + if g.ID == gpu.ID { + return "cuda" + } + } + + // ID is not always filled, so use the gpu Name for duplicate detection + for _, g := range rocmGPUs { + if g.ID == gpu.ID || g.Name == gpu.Name { + return "rocm" + } + } + + for _, g := range oneapiGPUs { + if g.ID == gpu.ID { + return "oneapi" + } + } + return "" +} + func GetGPUInfo() GpuInfoList { // TODO - consider exploring lspci (and equivalent on windows) to check for // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries @@ -413,6 +437,17 @@ func GetGPUInfo() GpuInfoList { } } + //rocmGPUs, err = AMDGetGPUInfo() + + // The ID field is used in context of the filtered set of GPUS + // so we have to replace any of these numeric IDs with their + // placement in this set of GPUs + //for i := range rocmGPUs { + // if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil { + // rocmGPUs[i].ID = strconv.Itoa(i) + // } + //} + // Vulkan vHandles = initVulkanHandles() for i := range vHandles.deviceCount { @@ -442,20 +477,19 @@ func GetGPUInfo() GpuInfoList { gpuInfo.DriverMinor = int(memInfo.minor) // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... - vulkanGPUs = append(vulkanGPUs, gpuInfo) + var backend = gpuInfoExistsInOtherBackends(gpuInfo) + if backend != "" { + unsupportedGPUs = append(unsupportedGPUs, + UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + }) + slog.Info(fmt.Sprintf("[%-s] Vulkan GPU is supported by [%-s]", gpuInfo.ID, backend)) + } else { + vulkanGPUs = append(vulkanGPUs, gpuInfo) + } } } - rocmGPUs, err = AMDGetGPUInfo() - - // The ID field is used in context of the filtered set of GPUS - // so we have to replace any of these numeric IDs with their - // placement in this set of GPUs - for i := range rocmGPUs { - if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil { - rocmGPUs[i].ID = strconv.Itoa(i) - } - } if err != nil { bootstrapErrors = append(bootstrapErrors, err) } From 1cb70716bf0d4304deb31f91d817c5e41095f56d Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 20 Sep 2025 15:26:24 +0200 Subject: [PATCH 114/172] revert debug code --- discover/gpu.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/discover/gpu.go b/discover/gpu.go index a96e5de5b..f6152bf0d 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -437,16 +437,16 @@ func GetGPUInfo() GpuInfoList { } } - //rocmGPUs, err = AMDGetGPUInfo() + rocmGPUs, err = AMDGetGPUInfo() // The ID field is used in context of the filtered set of GPUS // so we have to replace any of these numeric IDs with their // placement in this set of GPUs - //for i := range rocmGPUs { - // if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil { - // rocmGPUs[i].ID = strconv.Itoa(i) - // } - //} + for i := range rocmGPUs { + if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil { + rocmGPUs[i].ID = strconv.Itoa(i) + } + } // Vulkan vHandles = initVulkanHandles() From f761292516151c7cb3921cedbb40a82377bc6cab Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 22 Aug 2025 15:08:59 -0700 Subject: [PATCH 115/172] Use runners for GPU discovery This revamps how we discover GPUs in the system by leveraging the Ollama runner. 
This should eliminate inconsistency between our GPU discovery and the runners capabilities at runtime, particularly for cases where we try to filter out unsupported GPUs. Now the runner does that implicitly based on the actual device list. In some cases free VRAM reporting can be unreliable which can leaad to scheduling mistakes, so this also includes a patch to leverage more reliable VRAM reporting libraries if available. Automatic workarounds have been removed as only one GPU leveraged this, which is now documented. This GPU will soon fall off the support matrix with the next ROCm bump. Additional cleanup of the scheduler and discovery packages can be done in the future once we have switched on the new memory management code, and removed support for the llama runner. --- discover/amd_common.go | 83 -- discover/amd_hip_windows.go | 147 --- discover/amd_linux.go | 549 ----------- discover/amd_windows.go | 226 ----- discover/cpu_common.go | 24 - discover/{gpu_linux.go => cpu_linux.go} | 64 +- .../{gpu_linux_test.go => cpu_linux_test.go} | 5 +- discover/{gpu_windows.go => cpu_windows.go} | 35 +- ...pu_windows_test.go => cpu_windows_test.go} | 0 discover/cuda_common.go | 64 -- discover/gpu.go | 798 +++------------- discover/gpu_darwin.go | 67 +- discover/gpu_info.h | 72 -- discover/gpu_info_cudart.c | 181 ---- discover/gpu_info_cudart.h | 145 --- discover/gpu_info_nvcuda.c | 251 ----- discover/gpu_info_nvcuda.h | 79 -- discover/gpu_info_nvml.c | 104 --- discover/gpu_info_nvml.h | 48 - discover/gpu_info_oneapi.c | 259 ------ discover/gpu_info_oneapi.h | 203 ----- discover/gpu_test.go | 60 -- discover/runner.go | 543 +++++++++++ discover/types.go | 131 +-- docs/gpu.md | 3 + .../0026-GPU-discovery-enhancements.patch | 860 ++++++++++++++++++ llm/memory.go | 2 +- llm/server.go | 34 +- ml/backend.go | 114 +++ ml/backend/ggml/ggml.go | 74 ++ ml/backend/ggml/ggml/include/ggml-backend.h | 9 + ml/backend/ggml/ggml/src/CMakeLists.txt | 2 + .../ggml/ggml/src/ggml-cuda/ggml-cuda.cu | 79 +- .../ggml/ggml/src/ggml-cuda/vendors/hip.h | 1 + ml/backend/ggml/ggml/src/ggml-impl.h | 8 + ml/backend/ggml/ggml/src/mem_hip.cpp | 449 +++++++++ ml/backend/ggml/ggml/src/mem_nvml.cpp | 172 ++++ ml/nn/pooling/pooling_test.go | 17 +- runner/ollamarunner/runner.go | 42 + scripts/build_windows.ps1 | 2 +- server/routes.go | 4 +- server/routes_debug_test.go | 8 +- server/routes_generate_test.go | 12 +- server/routes_harmony_streaming_test.go | 12 +- server/sched.go | 58 +- server/sched_test.go | 21 +- 46 files changed, 2694 insertions(+), 3427 deletions(-) delete mode 100644 discover/amd_common.go delete mode 100644 discover/amd_hip_windows.go delete mode 100644 discover/amd_linux.go delete mode 100644 discover/amd_windows.go delete mode 100644 discover/cpu_common.go rename discover/{gpu_linux.go => cpu_linux.go} (75%) rename discover/{gpu_linux_test.go => cpu_linux_test.go} (99%) rename discover/{gpu_windows.go => cpu_windows.go} (91%) rename discover/{gpu_windows_test.go => cpu_windows_test.go} (100%) delete mode 100644 discover/cuda_common.go delete mode 100644 discover/gpu_info.h delete mode 100644 discover/gpu_info_cudart.c delete mode 100644 discover/gpu_info_cudart.h delete mode 100644 discover/gpu_info_nvcuda.c delete mode 100644 discover/gpu_info_nvcuda.h delete mode 100644 discover/gpu_info_nvml.c delete mode 100644 discover/gpu_info_nvml.h delete mode 100644 discover/gpu_info_oneapi.c delete mode 100644 discover/gpu_info_oneapi.h delete mode 100644 discover/gpu_test.go create mode 100644 discover/runner.go create 
mode 100644 llama/patches/0026-GPU-discovery-enhancements.patch create mode 100644 ml/backend/ggml/ggml/src/mem_hip.cpp create mode 100644 ml/backend/ggml/ggml/src/mem_nvml.cpp diff --git a/discover/amd_common.go b/discover/amd_common.go deleted file mode 100644 index 08834b22d..000000000 --- a/discover/amd_common.go +++ /dev/null @@ -1,83 +0,0 @@ -//go:build linux || windows - -package discover - -import ( - "errors" - "log/slog" - "os" - "path/filepath" - "runtime" - "strings" -) - -// Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns -func rocmLibUsable(libDir string) bool { - slog.Debug("evaluating potential rocm lib dir " + libDir) - for _, g := range ROCmLibGlobs { - res, _ := filepath.Glob(filepath.Join(libDir, g)) - if len(res) == 0 { - return false - } - } - return true -} - -func GetSupportedGFX(libDir string) ([]string, error) { - var ret []string - files, err := filepath.Glob(filepath.Join(libDir, "rocblas", "library", "TensileLibrary_lazy_gfx*.dat")) - if err != nil { - return nil, err - } - for _, file := range files { - ret = append(ret, strings.TrimSuffix(strings.TrimPrefix(filepath.Base(file), "TensileLibrary_lazy_"), ".dat")) - } - return ret, nil -} - -func commonAMDValidateLibDir() (string, error) { - // Favor our bundled version - - // Installer payload location if we're running the installed binary - rocmTargetDir := filepath.Join(LibOllamaPath, "rocm") - if rocmLibUsable(rocmTargetDir) { - slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir) - return rocmTargetDir, nil - } - - // Prefer explicit HIP env var - hipPath := os.Getenv("HIP_PATH") - if hipPath != "" { - hipLibDir := filepath.Join(hipPath, "bin") - if rocmLibUsable(hipLibDir) { - slog.Debug("detected ROCM via HIP_PATH=" + hipPath) - return hipLibDir, nil - } - } - - // Scan the LD_LIBRARY_PATH or PATH - pathEnv := "LD_LIBRARY_PATH" - if runtime.GOOS == "windows" { - pathEnv = "PATH" - } - - paths := os.Getenv(pathEnv) - for _, path := range filepath.SplitList(paths) { - d, err := filepath.Abs(path) - if err != nil { - continue - } - if rocmLibUsable(d) { - return d, nil - } - } - - // Well known location(s) - for _, path := range RocmStandardLocations { - if rocmLibUsable(path) { - return path, nil - } - } - - return "", errors.New("no suitable rocm found, falling back to CPU") -} diff --git a/discover/amd_hip_windows.go b/discover/amd_hip_windows.go deleted file mode 100644 index bf19ef064..000000000 --- a/discover/amd_hip_windows.go +++ /dev/null @@ -1,147 +0,0 @@ -package discover - -import ( - "errors" - "fmt" - "log/slog" - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - hipSuccess = 0 - hipErrorNoDevice = 100 -) - -type hipDevicePropMinimal struct { - Name [256]byte - unused1 [140]byte - GcnArchName [256]byte // gfx#### - iGPU int // Doesn't seem to actually report correctly - unused2 [128]byte -} - -// Wrap the amdhip64.dll library for GPU discovery -type HipLib struct { - dll windows.Handle - hipGetDeviceCount uintptr - hipGetDeviceProperties uintptr - hipMemGetInfo uintptr - hipSetDevice uintptr - hipDriverGetVersion uintptr -} - -func NewHipLib() (*HipLib, error) { - // At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs - h, err := windows.LoadLibrary("amdhip64_6.dll") - if err != nil { - return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err) - } - hl := &HipLib{} - hl.dll = h - hl.hipGetDeviceCount, err 
= windows.GetProcAddress(hl.dll, "hipGetDeviceCount") - if err != nil { - return nil, err - } - hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties") - if err != nil { - return nil, err - } - hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo") - if err != nil { - return nil, err - } - hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice") - if err != nil { - return nil, err - } - hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion") - if err != nil { - return nil, err - } - return hl, nil -} - -// The hip library only evaluates the ROCR_VISIBLE_DEVICES variable at startup -// so we have to unload/reset the library after we do our initial discovery -// to make sure our updates to that variable are processed by llama.cpp -func (hl *HipLib) Release() { - err := windows.FreeLibrary(hl.dll) - if err != nil { - slog.Warn("failed to unload amdhip64.dll", "error", err) - } - hl.dll = 0 -} - -func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) { - if hl.dll == 0 { - return 0, 0, errors.New("dll has been unloaded") - } - var version int - status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version))) - if status != hipSuccess { - return 0, 0, fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err) - } - - slog.Debug("hipDriverGetVersion", "version", version) - driverMajor = version / 10000000 - driverMinor = (version - (driverMajor * 10000000)) / 100000 - - return driverMajor, driverMinor, nil -} - -func (hl *HipLib) HipGetDeviceCount() int { - if hl.dll == 0 { - slog.Error("dll has been unloaded") - return 0 - } - var count int - status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count))) - if status == hipErrorNoDevice { - slog.Info("AMD ROCm reports no devices found") - return 0 - } - if status != hipSuccess { - slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err) - } - return count -} - -func (hl *HipLib) HipSetDevice(device int) error { - if hl.dll == 0 { - return errors.New("dll has been unloaded") - } - status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device)) - if status != hipSuccess { - return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err) - } - return nil -} - -func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) { - if hl.dll == 0 { - return nil, errors.New("dll has been unloaded") - } - var props hipDevicePropMinimal - status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device)) - if status != hipSuccess { - return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err) - } - return &props, nil -} - -// free, total, err -func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) { - if hl.dll == 0 { - return 0, 0, errors.New("dll has been unloaded") - } - var totalMemory uint64 - var freeMemory uint64 - status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory))) - if status != hipSuccess { - return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err) - } - return freeMemory, totalMemory, nil -} diff --git a/discover/amd_linux.go b/discover/amd_linux.go deleted file mode 100644 index 0f2aa0673..000000000 --- a/discover/amd_linux.go +++ /dev/null @@ -1,549 +0,0 @@ -package discover - -import ( - "bufio" - "errors" - "fmt" - "io" - "io/fs" - "log/slog" - 
"os" - "path/filepath" - "regexp" - "slices" - "sort" - "strconv" - "strings" - - "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/format" -) - -// Discovery logic for AMD/ROCm GPUs - -const ( - DriverVersionFile = "/sys/module/amdgpu/version" - AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/" - GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties" - - // Prefix with the node dir - GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line - - // Direct Rendering Manager sysfs location - DRMDeviceDirGlob = "/sys/class/drm/card*/device" - DRMTotalMemoryFile = "mem_info_vram_total" - DRMUsedMemoryFile = "mem_info_vram_used" - - // In hex; properties file is in decimal - DRMUniqueIDFile = "unique_id" - DRMVendorFile = "vendor" - DRMDeviceFile = "device" -) - -var ( - // Used to validate if the given ROCm lib is usable - ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here... - RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"} -) - -// Gather GPU information from the amdgpu driver if any supported GPUs are detected -// Only called once during bootstrap -func AMDGetGPUInfo() ([]RocmGPUInfo, error) { - resp := []RocmGPUInfo{} - if !AMDDetected() { - return resp, fmt.Errorf("AMD GPUs not detected") - } - - // Opportunistic logging of driver version to aid in troubleshooting - driverMajor, driverMinor, err := AMDDriverVersion() - if err != nil { - // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU - slog.Warn("ollama recommends running the https://www.amd.com/en/support/download/linux-drivers.html", "error", err) - } - - // Determine if the user has already pre-selected which GPUs to look at, then ignore the others - var visibleDevices []string - hipVD := envconfig.HipVisibleDevices() // zero based index only - rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID - gpuDO := envconfig.GpuDeviceOrdinal() // zero based index - switch { - case rocrVD != "": - visibleDevices = strings.Split(rocrVD, ",") - case hipVD != "": - visibleDevices = strings.Split(hipVD, ",") - case gpuDO != "": - visibleDevices = strings.Split(gpuDO, ",") - } - - gfxOverride := envconfig.HsaOverrideGfxVersion() - var supported []string - var libDir string - - // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract - // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU) - matches, _ := filepath.Glob(GPUPropertiesFileGlob) - sort.Slice(matches, func(i, j int) bool { - // /sys/class/kfd/kfd/topology/nodes//properties - a, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[i])), 10, 64) - if err != nil { - slog.Debug("parse err", "error", err, "match", matches[i]) - return false - } - b, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[j])), 10, 64) - if err != nil { - slog.Debug("parse err", "error", err, "match", matches[i]) - return false - } - return a < b - }) - gpuCount := 0 - gpuOrdinalID := 0 - for _, match := range matches { - slog.Debug("evaluating amdgpu node " + match) - fp, err := os.Open(match) - if err != nil { - slog.Debug("failed to open sysfs node", "file", match, "error", err) - continue - } - defer fp.Close() - - scanner := bufio.NewScanner(fp) - isCPU := false - var major, minor, patch uint64 - var vendor, device, uniqueID uint64 - for scanner.Scan() { - line := 
strings.TrimSpace(scanner.Text()) - // Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs - if strings.HasPrefix(line, "gfx_target_version") { - ver := strings.Fields(line) - - // Detect CPUs - if len(ver) == 2 && ver[1] == "0" { - slog.Debug("detected CPU " + match) - isCPU = true - break - } - - if len(ver) != 2 || len(ver[1]) < 5 { - slog.Warn("malformed "+match, "gfx_target_version", line) - // If this winds up being a CPU, our offsets may be wrong - continue - } - l := len(ver[1]) - var err1, err2, err3 error - patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32) - minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32) - major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32) - if err1 != nil || err2 != nil || err3 != nil { - slog.Debug("malformed int " + line) - continue - } - } else if strings.HasPrefix(line, "vendor_id") { - ver := strings.Fields(line) - if len(ver) != 2 { - slog.Debug("malformed", "vendor_id", line) - continue - } - vendor, err = strconv.ParseUint(ver[1], 10, 64) - if err != nil { - slog.Debug("malformed", "vendor_id", line, "error", err) - } - } else if strings.HasPrefix(line, "device_id") { - ver := strings.Fields(line) - if len(ver) != 2 { - slog.Debug("malformed", "device_id", line) - continue - } - device, err = strconv.ParseUint(ver[1], 10, 64) - if err != nil { - slog.Debug("malformed", "device_id", line, "error", err) - } - } else if strings.HasPrefix(line, "unique_id") { - ver := strings.Fields(line) - if len(ver) != 2 { - slog.Debug("malformed", "unique_id", line) - continue - } - uniqueID, err = strconv.ParseUint(ver[1], 10, 64) - if err != nil { - slog.Debug("malformed", "unique_id", line, "error", err) - } - } - // TODO - any other properties we want to extract and record? - // vendor_id + device_id -> pci lookup for "Name" - // Other metrics that may help us understand relative performance between multiple GPUs - } - - // Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers - // into consideration, so we instead map the device over to the DRM driver sysfs nodes which - // do reliably report VRAM usage. 
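The DRM mapping described in the comment above reduces to reading two sysfs nodes per card; a minimal standalone sketch of that read path (card0 is hard-coded here for illustration, whereas the removed code matched vendor, device and unique_id to locate the right card):

package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"
)

// readSysfsUint reads a single decimal value from a sysfs node.
func readSysfsUint(path string) (uint64, error) {
	buf, err := os.ReadFile(path)
	if err != nil {
		return 0, err
	}
	return strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
}

func main() {
	// Illustrative card path only; real discovery has to match the amdgpu
	// node against the DRM device before trusting these numbers.
	devDir := "/sys/class/drm/card0/device"
	total, errT := readSysfsUint(filepath.Join(devDir, "mem_info_vram_total"))
	used, errU := readSysfsUint(filepath.Join(devDir, "mem_info_vram_used"))
	if errT != nil || errU != nil {
		fmt.Println("failed to read DRM VRAM nodes:", errT, errU)
		return
	}
	fmt.Printf("vram total=%d used=%d free=%d\n", total, used, total-used)
}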
- - if isCPU { - continue - } - - // Skip over any GPUs that are masked - if major == 0 && minor == 0 && patch == 0 { - slog.Debug("skipping gpu with gfx000") - continue - } - - // Look up the memory for the current node - totalMemory := uint64(0) - usedMemory := uint64(0) - var usedFile string - mapping := []struct { - id uint64 - filename string - }{ - {vendor, DRMVendorFile}, - {device, DRMDeviceFile}, - {uniqueID, DRMUniqueIDFile}, // Not all devices will report this - } - slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID) - // Map over to DRM location to find the total/free memory - drmMatches, _ := filepath.Glob(DRMDeviceDirGlob) - for _, devDir := range drmMatches { - matched := true - for _, m := range mapping { - if m.id == 0 { - // Null ID means it didn't populate, so we can't use it to match - continue - } - filename := filepath.Join(devDir, m.filename) - buf, err := os.ReadFile(filename) - if err != nil { - slog.Debug("failed to read sysfs node", "file", filename, "error", err) - matched = false - break - } - // values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu - cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64) - if err != nil { - slog.Debug("failed to parse sysfs node", "file", filename, "error", err) - matched = false - break - } - if cmp != m.id { - matched = false - break - } - } - if !matched { - continue - } - - // Found the matching DRM directory - slog.Debug("matched", "amdgpu", match, "drm", devDir) - totalFile := filepath.Join(devDir, DRMTotalMemoryFile) - buf, err := os.ReadFile(totalFile) - if err != nil { - slog.Debug("failed to read sysfs node", "file", totalFile, "error", err) - break - } - totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64) - if err != nil { - slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err) - break - } - - usedFile = filepath.Join(devDir, DRMUsedMemoryFile) - usedMemory, err = getFreeMemory(usedFile) - if err != nil { - slog.Debug("failed to update used memory", "error", err) - } - break - } - - var name string - // TODO - PCI ID lookup - if vendor > 0 && device > 0 { - name = fmt.Sprintf("%04x:%04x", vendor, device) - } - - // Favor UUIDs if available to reduce possibility of getting the numeric IDs wrong - var ID string - if uniqueID != 0 { - ID = fmt.Sprintf("GPU-%016x", uniqueID) - } else { - ID = strconv.Itoa(gpuOrdinalID) - } - - gpuInfo := RocmGPUInfo{ - GpuInfo: GpuInfo{ - Library: "rocm", - memInfo: memInfo{ - TotalMemory: totalMemory, - FreeMemory: (totalMemory - usedMemory), - }, - ID: ID, - filterID: gpuOrdinalID, - Name: name, - Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), - MinimumMemory: rocmMinimumMemory, - DriverMajor: driverMajor, - DriverMinor: driverMinor, - }, - usedFilepath: usedFile, - index: gpuCount, - } - - // Keep track of numeric IDs based on valid GPUs - gpuCount += 1 - - // If the user wants to filter to a subset of devices, filter out if we aren't a match - if len(visibleDevices) > 0 { - include := false - for _, visible := range visibleDevices { - if (uniqueID != 0 && visible == gpuInfo.ID) || visible == strconv.Itoa(gpuInfo.index) { - include = true - break - } - } - if !include { - reason := "filtering out device per user request" - slog.Info(reason, "id", gpuInfo.ID, "index", gpuInfo.index, "visible_devices", visibleDevices) - unsupportedGPUs = append(unsupportedGPUs, 
UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - - continue - } - } - - // Ordinal IDs are based on the visible GPUs - gpuOrdinalID += 1 - - // iGPU detection, remove this check once we can support an iGPU variant of the rocm library - if totalMemory < IGPUMemLimit { - reason := "unsupported Radeon iGPU detected skipping" - slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - continue - } - minVer, err := strconv.Atoi(RocmComputeMajorMin) - if err != nil { - slog.Error("invalid RocmComputeMajorMin setting", "value", RocmComputeMajorMin, "error", err) - } - if int(major) < minVer { - reason := fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch) - slog.Warn(reason, "gpu", gpuInfo.ID) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - - continue - } - - slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) - slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "available", format.HumanBytes2(totalMemory-usedMemory)) - - // Final validation is gfx compatibility - load the library if we haven't already loaded it - // even if the user overrides, we still need to validate the library - if libDir == "" { - libDir, err = AMDValidateLibDir() - if err != nil { - err = fmt.Errorf("unable to verify rocm library: %w", err) - slog.Warn(err.Error()) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: err.Error(), - }) - return nil, err - } - } - gpuInfo.DependencyPath = []string{libDir} - - if gfxOverride == "" { - // Only load supported list once - if len(supported) == 0 { - supported, err = GetSupportedGFX(libDir) - if err != nil { - err = fmt.Errorf("failed to lookup supported GFX types: %w", err) - slog.Warn(err.Error()) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: err.Error(), - }) - return nil, err - } - slog.Debug("rocm supported GPUs", "types", supported) - } - gfx := gpuInfo.Compute - if !slices.Contains[[]string, string](supported, gfx) { - reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported) - slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - - // TODO - consider discrete markdown just for ROCM troubleshooting? 
- slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage") - continue - } else { - slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx) - } - } else { - slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride) - } - - // Check for env var workarounds - if name == "1002:687f" { // Vega RX 56 - gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, "HSA_ENABLE_SDMA=0") - } - - // The GPU has passed all the verification steps and is supported - resp = append(resp, gpuInfo) - } - if len(resp) == 0 { - err := fmt.Errorf("no compatible amdgpu devices detected") - slog.Info(err.Error()) - return nil, err - } - if err := verifyKFDDriverAccess(); err != nil { - err = fmt.Errorf("amdgpu devices detected but permission problems block access: %w", err) - slog.Error(err.Error()) - return nil, err - } - return resp, nil -} - -// Quick check for AMD driver so we can skip amdgpu discovery if not present -func AMDDetected() bool { - // Some driver versions (older?) don't have a version file, so just lookup the parent dir - sysfsDir := filepath.Dir(DriverVersionFile) - _, err := os.Stat(sysfsDir) - if errors.Is(err, os.ErrNotExist) { - slog.Debug("amdgpu driver not detected " + sysfsDir) - return false - } else if err != nil { - slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err) - return false - } - return true -} - -// Prefer to use host installed ROCm, as long as it meets our minimum requirements -// failing that, tell the user how to download it on their own -func AMDValidateLibDir() (string, error) { - libDir, err := commonAMDValidateLibDir() - if err == nil { - return libDir, nil - } - - // Well known ollama installer path - installedRocmDir := "/usr/share/ollama/lib/rocm" - if rocmLibUsable(installedRocmDir) { - return installedRocmDir, nil - } - - // If we still haven't found a usable rocm, the user will have to install it on their own - slog.Warn("amdgpu detected, but no compatible rocm library found. 
Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install") - return "", errors.New("no suitable rocm found, falling back to CPU") -} - -func AMDDriverVersion() (driverMajor, driverMinor int, err error) { - _, err = os.Stat(DriverVersionFile) - if err != nil { - return 0, 0, fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err) - } - fp, err := os.Open(DriverVersionFile) - if err != nil { - return 0, 0, err - } - defer fp.Close() - verString, err := io.ReadAll(fp) - if err != nil { - return 0, 0, err - } - - pattern := `\A(\d+)\.(\d+).*` - regex := regexp.MustCompile(pattern) - match := regex.FindStringSubmatch(string(verString)) - if len(match) < 2 { - return 0, 0, fmt.Errorf("malformed version string %s", string(verString)) - } - driverMajor, err = strconv.Atoi(match[1]) - if err != nil { - return 0, 0, err - } - driverMinor, err = strconv.Atoi(match[2]) - if err != nil { - return 0, 0, err - } - return driverMajor, driverMinor, nil -} - -func (gpus RocmGPUInfoList) RefreshFreeMemory() error { - if len(gpus) == 0 { - return nil - } - for i := range gpus { - usedMemory, err := getFreeMemory(gpus[i].usedFilepath) - if err != nil { - return err - } - slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory)) - gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory - } - return nil -} - -func getFreeMemory(usedFile string) (uint64, error) { - buf, err := os.ReadFile(usedFile) - if err != nil { - return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err) - } - usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64) - if err != nil { - slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err) - return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err) - } - return usedMemory, nil -} - -func verifyKFDDriverAccess() error { - // Verify we have permissions - either running as root, or we have group access to the driver - fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0o666) - if err != nil { - if errors.Is(err, fs.ErrPermission) { - return fmt.Errorf("permissions not set up properly. Either run ollama as root, or add you user account to the render group. %w", err) - } else if errors.Is(err, fs.ErrNotExist) { - // Container runtime failure? - return fmt.Errorf("kfd driver not loaded. If running in a container, remember to include '--device /dev/kfd --device /dev/dri'") - } - return fmt.Errorf("failed to check permission on /dev/kfd: %w", err) - } - fd.Close() - return nil -} - -func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { - ids := []string{} - for _, info := range gpuInfo { - if info.Library != "rocm" { - continue - } - // If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number - if _, err := strconv.Atoi(info.ID); err == nil { - ids = append(ids, fmt.Sprintf("%d", info.filterID)) - } else { - ids = append(ids, info.ID) - } - } - if len(ids) == 0 { - return "" - } - - // There are 3 potential env vars to use to select GPUs. 
- // ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux - // GPU_DEVICE_ORDINAL supports numeric IDs only - // HIP_VISIBLE_DEVICES supports numeric IDs only - return "ROCR_VISIBLE_DEVICES=" + strings.Join(ids, ",") -} diff --git a/discover/amd_windows.go b/discover/amd_windows.go deleted file mode 100644 index 08608ad1c..000000000 --- a/discover/amd_windows.go +++ /dev/null @@ -1,226 +0,0 @@ -package discover - -import ( - "bytes" - "errors" - "fmt" - "log/slog" - "path/filepath" - "slices" - "strconv" - "strings" - - "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/format" -) - -const ( - - // TODO We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true - iGPUName = "AMD Radeon(TM) Graphics" -) - -var ( - // Used to validate if the given ROCm lib is usable - ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6 - RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob? -) - -// Only called once during bootstrap -func AMDGetGPUInfo() ([]RocmGPUInfo, error) { - resp := []RocmGPUInfo{} - hl, err := NewHipLib() - if err != nil { - slog.Debug(err.Error()) - return nil, err - } - defer hl.Release() - - driverMajor, driverMinor, err := hl.AMDDriverVersion() - if err != nil { - // For now this is benign, but we may eventually need to fail compatibility checks - slog.Debug("error looking up amd driver version", "error", err) - } - - // Note: the HIP library automatically handles subsetting to any *_VISIBLE_DEVICES the user specified - count := hl.HipGetDeviceCount() - if count == 0 { - err := fmt.Errorf("no compatible amdgpu devices detected") - slog.Info(err.Error()) - return nil, err - } - - libDir, err := AMDValidateLibDir() - if err != nil { - err = fmt.Errorf("unable to verify rocm library: %w", err) - slog.Warn(err.Error()) - return nil, err - } - - var supported []string - gfxOverride := envconfig.HsaOverrideGfxVersion() - if gfxOverride == "" { - supported, err = GetSupportedGFX(libDir) - if err != nil { - err = fmt.Errorf("failed to lookup supported GFX types: %w", err) - slog.Warn(err.Error()) - return nil, err - } - } else { - slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride) - } - - slog.Debug("detected hip devices", "count", count) - // TODO how to determine the underlying device ID when visible devices is causing this to subset? - for i := range count { - err = hl.HipSetDevice(i) - if err != nil { - slog.Warn("set device", "id", i, "error", err) - continue - } - - props, err := hl.HipGetDeviceProperties(i) - if err != nil { - slog.Warn("get properties", "id", i, "error", err) - continue - } - n := bytes.IndexByte(props.Name[:], 0) - name := string(props.Name[:n]) - // TODO is UUID actually populated on windows? - // Can luid be used on windows for setting visible devices (and is it actually set?) - n = bytes.IndexByte(props.GcnArchName[:], 0) - gfx := string(props.GcnArchName[:n]) - slog.Debug("hip device", "id", i, "name", name, "gfx", gfx) - // slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0 - // TODO Why isn't props.iGPU accurate!? 
- - freeMemory, totalMemory, err := hl.HipMemGetInfo() - if err != nil { - slog.Warn("get mem info", "id", i, "error", err) - continue - } - - gpuInfo := RocmGPUInfo{ - GpuInfo: GpuInfo{ - Library: "rocm", - memInfo: memInfo{ - TotalMemory: totalMemory, - FreeMemory: freeMemory, - }, - // Free memory reporting on Windows is not reliable until we bump to ROCm v6.2 - UnreliableFreeMemory: true, - - ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices - filterID: i, - DependencyPath: []string{libDir}, - MinimumMemory: rocmMinimumMemory, - Name: name, - Compute: gfx, - DriverMajor: driverMajor, - DriverMinor: driverMinor, - }, - index: i, - } - - // iGPU detection, remove this check once we can support an iGPU variant of the rocm library - if strings.EqualFold(name, iGPUName) || totalMemory < IGPUMemLimit { - reason := "unsupported Radeon iGPU detected skipping" - slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - continue - } - - // Strip off Target Features when comparing - if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) { - reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported) - slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - // HSA_OVERRIDE_GFX_VERSION not supported on windows - continue - } else { - slog.Debug("amdgpu is supported", "gpu", i, "gpu_type", gfx) - } - - slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory)) - slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory)) - - resp = append(resp, gpuInfo) - } - - return resp, nil -} - -func AMDValidateLibDir() (string, error) { - libDir, err := commonAMDValidateLibDir() - if err == nil { - return libDir, nil - } - - // Installer payload (if we're running from some other location) - rocmTargetDir := filepath.Join(LibOllamaPath, "rocm") - if rocmLibUsable(rocmTargetDir) { - slog.Debug("detected ollama installed ROCm at " + rocmTargetDir) - return rocmTargetDir, nil - } - - // Should not happen on windows since we include it in the installer, but stand-alone binary might hit this - slog.Warn("amdgpu detected, but no compatible rocm library found. 
Please install ROCm") - return "", errors.New("no suitable rocm found, falling back to CPU") -} - -func (gpus RocmGPUInfoList) RefreshFreeMemory() error { - if len(gpus) == 0 { - return nil - } - hl, err := NewHipLib() - if err != nil { - slog.Debug(err.Error()) - return err - } - defer hl.Release() - - for i := range gpus { - err := hl.HipSetDevice(gpus[i].index) - if err != nil { - return err - } - freeMemory, _, err := hl.HipMemGetInfo() - if err != nil { - slog.Warn("get mem info", "id", i, "error", err) - continue - } - slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(freeMemory)) - gpus[i].FreeMemory = freeMemory - } - return nil -} - -func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { - ids := []string{} - for _, info := range gpuInfo { - if info.Library != "rocm" { - continue - } - // If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number - if _, err := strconv.Atoi(info.ID); err == nil { - ids = append(ids, fmt.Sprintf("%d", info.filterID)) - } else { - ids = append(ids, info.ID) - } - } - if len(ids) == 0 { - return "" - } - - // There are 3 potential env vars to use to select GPUs. - // ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows - // HIP_VISIBLE_DEVICES supports numeric IDs only - // GPU_DEVICE_ORDINAL supports numeric IDs only - return "HIP_VISIBLE_DEVICES=" + strings.Join(ids, ",") -} diff --git a/discover/cpu_common.go b/discover/cpu_common.go deleted file mode 100644 index 2b9f72927..000000000 --- a/discover/cpu_common.go +++ /dev/null @@ -1,24 +0,0 @@ -package discover - -import ( - "os" - "path/filepath" - "runtime" - "strings" -) - -func IsNUMA() bool { - if runtime.GOOS != "linux" { - // numa support in llama.cpp is linux only - return false - } - ids := map[string]any{} - packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id") - for _, packageId := range packageIds { - id, err := os.ReadFile(packageId) - if err == nil { - ids[strings.TrimSpace(string(id))] = struct{}{} - } - } - return len(ids) > 1 -} diff --git a/discover/gpu_linux.go b/discover/cpu_linux.go similarity index 75% rename from discover/gpu_linux.go rename to discover/cpu_linux.go index 44c53b440..c3a0ef7fa 100644 --- a/discover/gpu_linux.go +++ b/discover/cpu_linux.go @@ -4,7 +4,9 @@ import ( "bufio" "fmt" "io" + "log/slog" "os" + "path/filepath" "reflect" "regexp" "sort" @@ -13,47 +15,6 @@ import ( "github.com/ollama/ollama/format" ) -var CudartGlobs = []string{ - "/usr/local/cuda/lib64/libcudart.so*", - "/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*", - "/usr/lib/x86_64-linux-gnu/libcudart.so*", - "/usr/lib/wsl/lib/libcudart.so*", - "/usr/lib/wsl/drivers/*/libcudart.so*", - "/opt/cuda/lib64/libcudart.so*", - "/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*", - "/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*", - "/usr/lib/aarch64-linux-gnu/libcudart.so*", - "/usr/local/cuda/lib*/libcudart.so*", - "/usr/lib*/libcudart.so*", - "/usr/local/lib*/libcudart.so*", -} - -var NvmlGlobs = []string{} - -var NvcudaGlobs = []string{ - "/usr/local/cuda*/targets/*/lib/libcuda.so*", - "/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*", - "/usr/lib/*-linux-gnu/libcuda.so*", - "/usr/lib/wsl/lib/libcuda.so*", - "/usr/lib/wsl/drivers/*/libcuda.so*", - "/opt/cuda/lib*/libcuda.so*", - "/usr/local/cuda/lib*/libcuda.so*", - "/usr/lib*/libcuda.so*", - "/usr/local/lib*/libcuda.so*", -} 
- -var OneapiGlobs = []string{ - "/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*", - "/usr/lib*/libze_intel_gpu.so*", -} - -var ( - CudartMgmtName = "libcudart.so*" - NvcudaMgmtName = "libcuda.so*" - NvmlMgmtName = "" // not currently wired on linux - OneapiMgmtName = "libze_intel_gpu.so*" -) - func GetCPUMem() (memInfo, error) { var mem memInfo var total, available, free, buffers, cached, freeSwap uint64 @@ -106,16 +67,17 @@ type linuxCpuInfo struct { CoreID string `cpuinfo:"core id"` } -func GetCPUDetails() ([]CPU, error) { +func GetCPUDetails() []CPU { file, err := os.Open(CpuInfoFilename) if err != nil { - return nil, err + slog.Warn("failed to get CPU details", "error", err) + return nil } defer file.Close() return linuxCPUDetails(file) } -func linuxCPUDetails(file io.Reader) ([]CPU, error) { +func linuxCPUDetails(file io.Reader) []CPU { reColumns := regexp.MustCompile("\t+: ") scanner := bufio.NewScanner(file) cpuInfos := []linuxCpuInfo{} @@ -194,5 +156,17 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) { for _, k := range keys { result = append(result, *socketByID[k]) } - return result, nil + return result +} + +func IsNUMA() bool { + ids := map[string]any{} + packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id") + for _, packageId := range packageIds { + id, err := os.ReadFile(packageId) + if err == nil { + ids[strings.TrimSpace(string(id))] = struct{}{} + } + } + return len(ids) > 1 } diff --git a/discover/gpu_linux_test.go b/discover/cpu_linux_test.go similarity index 99% rename from discover/gpu_linux_test.go rename to discover/cpu_linux_test.go index c4d64e389..3a5144780 100644 --- a/discover/gpu_linux_test.go +++ b/discover/cpu_linux_test.go @@ -2062,10 +2062,7 @@ power management: for k, v := range testCases { t.Run(k, func(t *testing.T) { buf := bytes.NewBufferString(v.input) - cpus, err := linuxCPUDetails(buf) - if err != nil { - t.Fatal(err) - } + cpus := linuxCPUDetails(buf) slog.Info("example", "scenario", k, "cpus", cpus) si := SystemInfo{ diff --git a/discover/gpu_windows.go b/discover/cpu_windows.go similarity index 91% rename from discover/gpu_windows.go rename to discover/cpu_windows.go index 2dc2f0746..ee308805e 100644 --- a/discover/gpu_windows.go +++ b/discover/cpu_windows.go @@ -26,29 +26,6 @@ var ( GetLogicalProcessorInformationEx = k32.NewProc("GetLogicalProcessorInformationEx") ) -var CudartGlobs = []string{ - "c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll", -} - -var NvmlGlobs = []string{ - "c:\\Windows\\System32\\nvml.dll", -} - -var NvcudaGlobs = []string{ - "c:\\windows\\system*\\nvcuda.dll", -} - -var OneapiGlobs = []string{ - "c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll", -} - -var ( - CudartMgmtName = "cudart64_*.dll" - NvcudaMgmtName = "nvcuda.dll" - NvmlMgmtName = "nvml.dll" - OneapiMgmtName = "ze_intel_gpu64.dll" -) - func GetCPUMem() (memInfo, error) { memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx} r1, _, err := globalMemoryStatusExProc.Call(uintptr(unsafe.Pointer(&memStatus))) @@ -217,10 +194,11 @@ func processSystemLogicalProcessorInforationList(buf []byte) []*winPackage { return packages } -func GetCPUDetails() ([]CPU, error) { +func GetCPUDetails() []CPU { buf, err := getLogicalProcessorInformationEx() if err != nil { - return nil, err + slog.Warn("failed to get CPU details", "error", err) + return nil } packages := processSystemLogicalProcessorInforationList(buf) cpus := make([]CPU, len(packages)) @@ -230,5 +208,10 @@ func 
GetCPUDetails() ([]CPU, error) { cpus[i].EfficiencyCoreCount = pkg.efficiencyCoreCount cpus[i].ThreadCount = pkg.threadCount } - return cpus, nil + return cpus +} + +func IsNUMA() bool { + // numa support in ggml is linux only + return false } diff --git a/discover/gpu_windows_test.go b/discover/cpu_windows_test.go similarity index 100% rename from discover/gpu_windows_test.go rename to discover/cpu_windows_test.go diff --git a/discover/cuda_common.go b/discover/cuda_common.go deleted file mode 100644 index a2c43420e..000000000 --- a/discover/cuda_common.go +++ /dev/null @@ -1,64 +0,0 @@ -//go:build linux || windows - -package discover - -import ( - "fmt" - "log/slog" - "os" - "regexp" - "runtime" - "strconv" - "strings" -) - -// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed. -// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. -var CudaTegra string = os.Getenv("JETSON_JETPACK") - -func cudaVariant(gpuInfos []CudaGPUInfo) string { - if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" { - if CudaTegra != "" { - ver := strings.Split(CudaTegra, ".") - if len(ver) > 0 { - return "jetpack" + ver[0] - } - } else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil { - r := regexp.MustCompile(` R(\d+) `) - m := r.FindSubmatch(data) - if len(m) != 2 { - slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version") - } else { - if l4t, err := strconv.Atoi(string(m[1])); err == nil { - // Note: mapping from L4t -> JP is inconsistent (can't just subtract 30) - // https://developer.nvidia.com/embedded/jetpack-archive - switch l4t { - case 35: - return "jetpack5" - case 36: - return "jetpack6" - default: - slog.Info("unsupported L4T version", "nv_tegra_release", string(data)) - } - } - } - } - } - - // Check GPU compute capability FIRST, lowest common denominator if multi-gpu - for _, gpuInfo := range gpuInfos { - if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) { - // GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1) - return "v12" - } - } - - // GPU is Turing or newer (CC >= 7.5) - can use newer CUDA - if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 { - // The detected driver is older than 580 (Aug 2025) - // Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance - slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor)) - return "v12" - } - return "v13" -} diff --git a/discover/gpu.go b/discover/gpu.go index a39bc7c3d..a61cfe513 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -1,730 +1,150 @@ -//go:build linux || windows - package discover -/* -#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm -#cgo windows LDFLAGS: -lpthread - -#include "gpu_info.h" -*/ -import "C" - import ( + "context" "fmt" "log/slog" "os" "path/filepath" "runtime" - "strconv" "strings" - "sync" - "unsafe" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/ml" ) -type cudaHandles struct { - deviceCount int - cudart *C.cudart_handle_t - nvcuda *C.nvcuda_handle_t - nvml *C.nvml_handle_t +// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed. +// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. 
+var CudaTegra string = os.Getenv("JETSON_JETPACK") + +func GetCPUInfo() GpuInfo { + mem, err := GetCPUMem() + if err != nil { + slog.Warn("error looking up system memory", "error", err) + } + + return GpuInfo{ + memInfo: mem, + Library: "cpu", + ID: "0", + } } -type oneapiHandles struct { - oneapi *C.oneapi_handle_t - deviceCount int +func GetGPUInfo(ctx context.Context, runners []FilteredRunnerDiscovery) GpuInfoList { + devs := GPUDevices(ctx, runners) + return devInfoToInfoList(devs) } -const ( - cudaMinimumMemory = 457 * format.MebiByte - rocmMinimumMemory = 457 * format.MebiByte - // TODO OneAPI minimum memory -) - -var ( - gpuMutex sync.Mutex - bootstrapped bool - cpus []CPUInfo - cudaGPUs []CudaGPUInfo - nvcudaLibPath string - cudartLibPath string - oneapiLibPath string - nvmlLibPath string - rocmGPUs []RocmGPUInfo - oneapiGPUs []OneapiGPUInfo - - // If any discovered GPUs are incompatible, report why - unsupportedGPUs []UnsupportedGPUInfo - - // Keep track of errors during bootstrapping so that if GPUs are missing - // they expected to be present this may explain why - bootstrapErrors []error -) - -// With our current CUDA compile flags, older than 5.0 will not work properly -// (string values used to allow ldflags overrides at build time) -var ( - CudaComputeMajorMin = "5" - CudaComputeMinorMin = "0" -) - -var RocmComputeMajorMin = "9" - -// TODO find a better way to detect iGPU instead of minimum memory -const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU - -// Note: gpuMutex must already be held -func initCudaHandles() *cudaHandles { - // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing - - cHandles := &cudaHandles{} - // Short Circuit if we already know which library to use - // ignore bootstrap errors in this case since we already recorded them - if nvmlLibPath != "" { - cHandles.nvml, _, _ = loadNVMLMgmt([]string{nvmlLibPath}) - return cHandles - } - if nvcudaLibPath != "" { - cHandles.deviceCount, cHandles.nvcuda, _, _ = loadNVCUDAMgmt([]string{nvcudaLibPath}) - return cHandles - } - if cudartLibPath != "" { - cHandles.deviceCount, cHandles.cudart, _, _ = loadCUDARTMgmt([]string{cudartLibPath}) - return cHandles - } - - slog.Debug("searching for GPU discovery libraries for NVIDIA") - var cudartMgmtPatterns []string - - // Aligned with driver, we can't carry as payloads - nvcudaMgmtPatterns := NvcudaGlobs - cudartMgmtPatterns = append(cudartMgmtPatterns, filepath.Join(LibOllamaPath, "cuda_v*", CudartMgmtName)) - cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...) 
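For orientation, the discovery surface added above shrinks to these two calls; a rough usage sketch, assuming the usual module import path and that a nil runner list is acceptable at bootstrap:

package main

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/discover"
)

func main() {
	ctx := context.Background()
	// Passing no active runners is assumed to be valid before any model is loaded.
	for _, g := range discover.GetGPUInfo(ctx, nil) {
		fmt.Println(g.Library, g.ID, g.Name, g.FreeMemory)
	}
	cpu := discover.GetCPUInfo()
	fmt.Println("system memory:", cpu.TotalMemory)
}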
- - if len(NvmlGlobs) > 0 { - nvmlLibPaths := FindGPULibs(NvmlMgmtName, NvmlGlobs) - if len(nvmlLibPaths) > 0 { - nvml, libPath, err := loadNVMLMgmt(nvmlLibPaths) - if nvml != nil { - slog.Debug("nvidia-ml loaded", "library", libPath) - cHandles.nvml = nvml - nvmlLibPath = libPath - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - } - - nvcudaLibPaths := FindGPULibs(NvcudaMgmtName, nvcudaMgmtPatterns) - if len(nvcudaLibPaths) > 0 { - deviceCount, nvcuda, libPath, err := loadNVCUDAMgmt(nvcudaLibPaths) - if nvcuda != nil { - slog.Debug("detected GPUs", "count", deviceCount, "library", libPath) - cHandles.nvcuda = nvcuda - cHandles.deviceCount = deviceCount - nvcudaLibPath = libPath - return cHandles - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - - cudartLibPaths := FindGPULibs(CudartMgmtName, cudartMgmtPatterns) - if len(cudartLibPaths) > 0 { - deviceCount, cudart, libPath, err := loadCUDARTMgmt(cudartLibPaths) - if cudart != nil { - slog.Debug("detected GPUs", "library", libPath, "count", deviceCount) - cHandles.cudart = cudart - cHandles.deviceCount = deviceCount - cudartLibPath = libPath - return cHandles - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - - return cHandles -} - -// Note: gpuMutex must already be held -func initOneAPIHandles() *oneapiHandles { - oHandles := &oneapiHandles{} - - // Short Circuit if we already know which library to use - // ignore bootstrap errors in this case since we already recorded them - if oneapiLibPath != "" { - oHandles.deviceCount, oHandles.oneapi, _, _ = loadOneapiMgmt([]string{oneapiLibPath}) - return oHandles - } - - oneapiLibPaths := FindGPULibs(OneapiMgmtName, OneapiGlobs) - if len(oneapiLibPaths) > 0 { - var err error - oHandles.deviceCount, oHandles.oneapi, oneapiLibPath, err = loadOneapiMgmt(oneapiLibPaths) - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - - return oHandles -} - -func GetCPUInfo() GpuInfoList { - gpuMutex.Lock() - if !bootstrapped { - gpuMutex.Unlock() - GetGPUInfo() - } else { - gpuMutex.Unlock() - } - return GpuInfoList{cpus[0].GpuInfo} -} - -func GetGPUInfo() GpuInfoList { - // TODO - consider exploring lspci (and equivalent on windows) to check for - // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries - gpuMutex.Lock() - defer gpuMutex.Unlock() - needRefresh := true - var cHandles *cudaHandles - var oHandles *oneapiHandles - defer func() { - if cHandles != nil { - if cHandles.cudart != nil { - C.cudart_release(*cHandles.cudart) - } - if cHandles.nvcuda != nil { - C.nvcuda_release(*cHandles.nvcuda) - } - if cHandles.nvml != nil { - C.nvml_release(*cHandles.nvml) - } - } - if oHandles != nil { - if oHandles.oneapi != nil { - // TODO - is this needed? 
- C.oneapi_release(*oHandles.oneapi) - } - } - }() - - if !bootstrapped { - slog.Info("looking for compatible GPUs") - cudaComputeMajorMin, err := strconv.Atoi(CudaComputeMajorMin) - if err != nil { - slog.Error("invalid CudaComputeMajorMin setting", "value", CudaComputeMajorMin, "error", err) - } - cudaComputeMinorMin, err := strconv.Atoi(CudaComputeMinorMin) - if err != nil { - slog.Error("invalid CudaComputeMinorMin setting", "value", CudaComputeMinorMin, "error", err) - } - bootstrapErrors = []error{} - needRefresh = false - var memInfo C.mem_info_t - - mem, err := GetCPUMem() - if err != nil { - slog.Warn("error looking up system memory", "error", err) - } - - details, err := GetCPUDetails() - if err != nil { - slog.Warn("failed to lookup CPU details", "error", err) - } - cpus = []CPUInfo{ - { - GpuInfo: GpuInfo{ - memInfo: mem, - Library: "cpu", - ID: "0", - }, - CPUs: details, - }, - } - - // Load ALL libraries - cHandles = initCudaHandles() - - // NVIDIA - for i := range cHandles.deviceCount { - if cHandles.cudart != nil || cHandles.nvcuda != nil { - gpuInfo := CudaGPUInfo{ - GpuInfo: GpuInfo{ - Library: "cuda", - }, - index: i, - } - var driverMajor int - var driverMinor int - if cHandles.cudart != nil { - C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo) - driverMajor = int(cHandles.cudart.driver_major) - driverMinor = int(cHandles.cudart.driver_minor) - } else { - C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo) - driverMajor = int(cHandles.nvcuda.driver_major) - driverMinor = int(cHandles.nvcuda.driver_minor) - } - if memInfo.err != nil { - slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - continue - } - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) - gpuInfo.computeMajor = int(memInfo.major) - gpuInfo.computeMinor = int(memInfo.minor) - gpuInfo.MinimumMemory = cudaMinimumMemory - gpuInfo.DriverMajor = driverMajor - gpuInfo.DriverMinor = driverMinor - - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - - if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) { - unsupportedGPUs = append(unsupportedGPUs, - UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - }) - slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. 
Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor)) - continue - } - - // query the management library as well so we can record any skew between the two - // which represents overhead on the GPU we must set aside on subsequent updates - if cHandles.nvml != nil { - uuid := C.CString(gpuInfo.ID) - defer C.free(unsafe.Pointer(uuid)) - C.nvml_get_free(*cHandles.nvml, uuid, &memInfo.free, &memInfo.total, &memInfo.used) - if memInfo.err != nil { - slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - } else { - if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory { - gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory - slog.Info("detected OS VRAM overhead", - "id", gpuInfo.ID, - "library", gpuInfo.Library, - "compute", gpuInfo.Compute, - "driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor), - "name", gpuInfo.Name, - "overhead", format.HumanBytes2(gpuInfo.OSOverhead), - ) - } - } - } - - // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... - cudaGPUs = append(cudaGPUs, gpuInfo) - } - // Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths - variant := cudaVariant(cudaGPUs) - var variantPath string - // Start with our bundled libraries - if variant != "" { - variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant) - if _, err := os.Stat(variantPath); err != nil { - variantPath = "" - } - } - - for i := range cudaGPUs { - cudaGPUs[i].Variant = variant - if variantPath != "" { - // Put the variant directory first in the search path to avoid runtime linking to the wrong library - cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...) - } - } - } - - // Intel - if envconfig.IntelGPU() { - oHandles = initOneAPIHandles() - if oHandles != nil && oHandles.oneapi != nil { - for d := range oHandles.oneapi.num_drivers { - if oHandles.oneapi == nil { - // shouldn't happen - slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) - continue - } - devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) - for i := range devCount { - gpuInfo := OneapiGPUInfo{ - GpuInfo: GpuInfo{ - Library: "oneapi", - }, - driverIndex: int(d), - gpuIndex: int(i), - } - // TODO - split bootstrapping from updating free memory - C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) - // TODO - convert this to MinimumMemory based on testing... - var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. 
- memInfo.free = C.uint64_t(totalFreeMem) - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.DependencyPath = []string{LibOllamaPath} - oneapiGPUs = append(oneapiGPUs, gpuInfo) - } - } - } - } - - rocmGPUs, err = AMDGetGPUInfo() - - // The ID field is used in context of the filtered set of GPUS - // so we have to replace any of these numeric IDs with their - // placement in this set of GPUs - for i := range rocmGPUs { - if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil { - rocmGPUs[i].ID = strconv.Itoa(i) - } - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - bootstrapped = true - if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 { - slog.Info("no compatible GPUs were discovered") - } - - // TODO verify we have runners for the discovered GPUs, filter out any that aren't supported with good error messages - } - - // For detected GPUs, load library if not loaded - - // Refresh free memory usage - if needRefresh { - mem, err := GetCPUMem() - if err != nil { - slog.Warn("error looking up system memory", "error", err) - } else { - slog.Debug("updating system memory data", - slog.Group( - "before", - "total", format.HumanBytes2(cpus[0].TotalMemory), - "free", format.HumanBytes2(cpus[0].FreeMemory), - "free_swap", format.HumanBytes2(cpus[0].FreeSwap), - ), - slog.Group( - "now", - "total", format.HumanBytes2(mem.TotalMemory), - "free", format.HumanBytes2(mem.FreeMemory), - "free_swap", format.HumanBytes2(mem.FreeSwap), - ), - ) - cpus[0].FreeMemory = mem.FreeMemory - cpus[0].FreeSwap = mem.FreeSwap - } - - var memInfo C.mem_info_t - if cHandles == nil && len(cudaGPUs) > 0 { - cHandles = initCudaHandles() - } - for i, gpu := range cudaGPUs { - if cHandles.nvml != nil { - uuid := C.CString(gpu.ID) - defer C.free(unsafe.Pointer(uuid)) - C.nvml_get_free(*cHandles.nvml, uuid, &memInfo.free, &memInfo.total, &memInfo.used) - } else if cHandles.cudart != nil { - C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo) - } else if cHandles.nvcuda != nil { - C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free, &memInfo.total) - memInfo.used = memInfo.total - memInfo.free - } else { - // shouldn't happen - slog.Warn("no valid cuda library loaded to refresh vram usage") - break - } - if memInfo.err != nil { - slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - continue - } - if memInfo.free == 0 { - slog.Warn("error looking up nvidia GPU memory") - continue - } - if cHandles.nvml != nil && gpu.OSOverhead > 0 { - // When using the management library update based on recorded overhead - memInfo.free -= C.uint64_t(gpu.OSOverhead) - } - slog.Debug("updating cuda memory data", - "gpu", gpu.ID, - "name", gpu.Name, - "overhead", format.HumanBytes2(gpu.OSOverhead), - slog.Group( - "before", - "total", format.HumanBytes2(gpu.TotalMemory), - "free", format.HumanBytes2(gpu.FreeMemory), - ), - slog.Group( - "now", - "total", format.HumanBytes2(uint64(memInfo.total)), - "free", format.HumanBytes2(uint64(memInfo.free)), - "used", format.HumanBytes2(uint64(memInfo.used)), - ), - ) - cudaGPUs[i].FreeMemory = uint64(memInfo.free) - } - - if oHandles == nil && len(oneapiGPUs) > 0 { - oHandles = initOneAPIHandles() - } - for i, gpu := range oneapiGPUs { - if oHandles.oneapi == nil { - // shouldn't happen - slog.Warn("nil oneapi handle with device 
count", "count", oHandles.deviceCount) - continue - } - C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo) - // TODO - convert this to MinimumMemory based on testing... - var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. - memInfo.free = C.uint64_t(totalFreeMem) - oneapiGPUs[i].FreeMemory = uint64(memInfo.free) - } - - err = RocmGPUInfoList(rocmGPUs).RefreshFreeMemory() - if err != nil { - slog.Debug("problem refreshing ROCm free memory", "error", err) - } - } - +func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList { resp := []GpuInfo{} - for _, gpu := range cudaGPUs { - resp = append(resp, gpu.GpuInfo) + // Our current packaging model places ggml-hip in the main directory + // but keeps rocm in an isolated directory. We have to add it to + // the [LD_LIBRARY_]PATH so ggml-hip will load properly + rocmDir := filepath.Join(LibOllamaPath, "rocm") + if _, err := os.Stat(rocmDir); err != nil { + rocmDir = "" } - for _, gpu := range rocmGPUs { - resp = append(resp, gpu.GpuInfo) - } - for _, gpu := range oneapiGPUs { - resp = append(resp, gpu.GpuInfo) + + for _, dev := range devs { + info := GpuInfo{ + ID: dev.ID, + filterID: dev.FilteredID, + Name: dev.Description, + memInfo: memInfo{ + TotalMemory: dev.TotalMemory, + FreeMemory: dev.FreeMemory, + }, + Library: dev.Library, + // TODO can we avoid variant + DependencyPath: dev.LibraryPath, + DriverMajor: dev.DriverMajor, + DriverMinor: dev.DriverMinor, + } + if dev.Library == "CUDA" || dev.Library == "HIP" { + info.MinimumMemory = 457 * format.MebiByte + } + if dev.Library == "HIP" { + info.Compute = fmt.Sprintf("gfx%x%02x", dev.ComputeMajor, dev.ComputeMinor) + if rocmDir != "" { + info.DependencyPath = append(info.DependencyPath, rocmDir) + } + } else { + info.Compute = fmt.Sprintf("%d.%d", dev.ComputeMajor, dev.ComputeMinor) + } + resp = append(resp, info) } if len(resp) == 0 { - resp = append(resp, cpus[0].GpuInfo) + mem, err := GetCPUMem() + if err != nil { + slog.Warn("error looking up system memory", "error", err) + } + + resp = append(resp, GpuInfo{ + memInfo: mem, + Library: "cpu", + ID: "0", + }) } return resp } -func FindGPULibs(baseLibName string, defaultPatterns []string) []string { - // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them - gpuLibPaths := []string{} - slog.Debug("Searching for GPU library", "name", baseLibName) - - // search our bundled libraries first - patterns := []string{filepath.Join(LibOllamaPath, baseLibName)} - - var ldPaths []string - switch runtime.GOOS { - case "windows": - ldPaths = strings.Split(os.Getenv("PATH"), string(os.PathListSeparator)) - case "linux": - ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), string(os.PathListSeparator)) - } - - // then search the system's LD_LIBRARY_PATH - for _, p := range ldPaths { - p, err := filepath.Abs(p) - if err != nil { - continue - } - patterns = append(patterns, filepath.Join(p, baseLibName)) - } - - // finally, search the default patterns provided by the caller - patterns = append(patterns, defaultPatterns...) 
- slog.Debug("gpu library search", "globs", patterns) - for _, pattern := range patterns { - // Nvidia PhysX known to return bogus results - if strings.Contains(pattern, "PhysX") { - slog.Debug("skipping PhysX cuda library path", "path", pattern) - continue - } - // Ignore glob discovery errors - matches, _ := filepath.Glob(pattern) - for _, match := range matches { - // Resolve any links so we don't try the same lib multiple times - // and weed out any dups across globs - libPath := match - tmp := match - var err error - for ; err == nil; tmp, err = os.Readlink(libPath) { - if !filepath.IsAbs(tmp) { - tmp = filepath.Join(filepath.Dir(libPath), tmp) - } - libPath = tmp - } - new := true - for _, cmp := range gpuLibPaths { - if cmp == libPath { - new = false - break - } - } - if new { - gpuLibPaths = append(gpuLibPaths, libPath) - } - } - } - slog.Debug("discovered GPU libraries", "paths", gpuLibPaths) - return gpuLibPaths -} - -// Bootstrap the runtime library -// Returns: num devices, handle, libPath, error -func loadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string, error) { - var resp C.cudart_init_resp_t - resp.ch.verbose = getVerboseState() - var err error - for _, libPath := range cudartLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.cudart_init(lib, &resp) - if resp.err != nil { - err = fmt.Errorf("Unable to load cudart library %s: %s", libPath, C.GoString(resp.err)) - slog.Debug(err.Error()) - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - return int(resp.num_devices), &resp.ch, libPath, err - } - } - return 0, nil, "", err -} - -// Bootstrap the driver library -// Returns: num devices, handle, libPath, error -func loadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string, error) { - var resp C.nvcuda_init_resp_t - resp.ch.verbose = getVerboseState() - var err error - for _, libPath := range nvcudaLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.nvcuda_init(lib, &resp) - if resp.err != nil { - // Decide what log level based on the type of error message to help users understand why - switch resp.cudaErr { - case C.CUDA_ERROR_INSUFFICIENT_DRIVER, C.CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: - err = fmt.Errorf("version mismatch between driver and cuda driver library - reboot or upgrade may be required: library %s", libPath) - slog.Warn(err.Error()) - case C.CUDA_ERROR_NO_DEVICE: - err = fmt.Errorf("no nvidia devices detected by library %s", libPath) - slog.Info(err.Error()) - case C.CUDA_ERROR_UNKNOWN: - err = fmt.Errorf("unknown error initializing cuda driver library %s: %s. 
see https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for more information", libPath, C.GoString(resp.err)) - slog.Warn(err.Error()) - default: - msg := C.GoString(resp.err) - if strings.Contains(msg, "wrong ELF class") { - slog.Debug("skipping 32bit library", "library", libPath) - } else { - err = fmt.Errorf("Unable to load cudart library %s: %s", libPath, C.GoString(resp.err)) - slog.Info(err.Error()) - } - } - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - return int(resp.num_devices), &resp.ch, libPath, err - } - } - return 0, nil, "", err -} - -// Bootstrap the management library -// Returns: handle, libPath, error -func loadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string, error) { - var resp C.nvml_init_resp_t - resp.ch.verbose = getVerboseState() - var err error - for _, libPath := range nvmlLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.nvml_init(lib, &resp) - if resp.err != nil { - err = fmt.Errorf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)) - slog.Info(err.Error()) - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - return &resp.ch, libPath, err - } - } - return nil, "", err -} - -// bootstrap the Intel GPU library -// Returns: num devices, handle, libPath, error -func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, error) { - var resp C.oneapi_init_resp_t - num_devices := 0 - resp.oh.verbose = getVerboseState() - var err error - for _, libPath := range oneapiLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.oneapi_init(lib, &resp) - if resp.err != nil { - err = fmt.Errorf("Unable to load oneAPI management library %s: %s", libPath, C.GoString(resp.err)) - slog.Debug(err.Error()) - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - for i := range resp.oh.num_drivers { - num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i))) - } - return num_devices, &resp.oh, libPath, err - } - } - return 0, nil, "", err -} - -func getVerboseState() C.uint16_t { - if envconfig.LogLevel() < slog.LevelInfo { - return C.uint16_t(1) - } - return C.uint16_t(0) -} - // Given the list of GPUs this instantiation is targeted for, // figure out the visible devices environment variable +// +// # If different libraries are detected, the first one is what we use +// +// TODO once we're purely running on the new runner, this level of device +// filtering will no longer be necessary. 
Instead the runner can be told which +// of the set of GPUs to utilize and handle filtering itself, instead of relying +// on the env var to hide devices from the underlying GPU libraries func (l GpuInfoList) GetVisibleDevicesEnv() []string { if len(l) == 0 { return nil } - vd := []string{} - // Only filter the AMD GPUs at this level, let all NVIDIA devices through - if tmp := rocmGetVisibleDevicesEnv(l); tmp != "" { - vd = append(vd, tmp) - } - return vd + return []string{rocmGetVisibleDevicesEnv(l)} } -func GetSystemInfo() SystemInfo { - gpus := GetGPUInfo() - gpuMutex.Lock() - defer gpuMutex.Unlock() - discoveryErrors := []string{} - for _, err := range bootstrapErrors { - discoveryErrors = append(discoveryErrors, err.Error()) +func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "HIP" { + continue + } + // If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number + if info.filterID != "" { + ids = append(ids, info.filterID) + } else { + ids = append(ids, info.ID) + } } + if len(ids) == 0 { + return "" + } + envVar := "ROCR_VISIBLE_DEVICES=" + if runtime.GOOS != "linux" { + envVar = "HIP_VISIBLE_DEVICES=" + } + // There are 3 potential env vars to use to select GPUs. + // ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows + // HIP_VISIBLE_DEVICES supports numeric IDs only + // GPU_DEVICE_ORDINAL supports numeric IDs only + return envVar + strings.Join(ids, ",") +} + +// GetSystemInfo returns the last cached state of the GPUs on the system +func GetSystemInfo() SystemInfo { + deviceMu.Lock() + defer deviceMu.Unlock() + gpus := devInfoToInfoList(devices) if len(gpus) == 1 && gpus[0].Library == "cpu" { gpus = []GpuInfo{} } return SystemInfo{ - System: cpus[0], - GPUs: gpus, - UnsupportedGPUs: unsupportedGPUs, - DiscoveryErrors: discoveryErrors, + System: CPUInfo{ + CPUs: GetCPUDetails(), + GpuInfo: GetCPUInfo(), + }, + GPUs: gpus, } } diff --git a/discover/gpu_darwin.go b/discover/gpu_darwin.go index 29b44ff50..6f55b4c57 100644 --- a/discover/gpu_darwin.go +++ b/discover/gpu_darwin.go @@ -1,5 +1,3 @@ -//go:build darwin - package discover /* @@ -11,7 +9,6 @@ import "C" import ( "log/slog" - "runtime" "syscall" "github.com/ollama/ollama/format" @@ -21,39 +18,6 @@ const ( metalMinimumMemory = 512 * format.MebiByte ) -func GetGPUInfo() GpuInfoList { - mem, _ := GetCPUMem() - if runtime.GOARCH == "amd64" { - return []GpuInfo{ - { - Library: "cpu", - memInfo: mem, - }, - } - } - info := GpuInfo{ - Library: "metal", - ID: "0", - } - info.TotalMemory = uint64(C.getRecommendedMaxVRAM()) - - // TODO is there a way to gather actual allocated video memory? 
(currentAllocatedSize doesn't work) - info.FreeMemory = info.TotalMemory - - info.MinimumMemory = metalMinimumMemory - return []GpuInfo{info} -} - -func GetCPUInfo() GpuInfoList { - mem, _ := GetCPUMem() - return []GpuInfo{ - { - Library: "cpu", - memInfo: mem, - }, - } -} - func GetCPUMem() (memInfo, error) { return memInfo{ TotalMemory: uint64(C.getPhysicalMemory()), @@ -62,13 +26,7 @@ func GetCPUMem() (memInfo, error) { }, nil } -func (l GpuInfoList) GetVisibleDevicesEnv() []string { - // No-op on darwin - return nil -} - -func GetSystemInfo() SystemInfo { - mem, _ := GetCPUMem() +func GetCPUDetails() []CPU { query := "hw.perflevel0.physicalcpu" perfCores, err := syscall.SysctlUint32(query) if err != nil { @@ -81,19 +39,16 @@ func GetSystemInfo() SystemInfo { query = "hw.logicalcpu" logicalCores, _ := syscall.SysctlUint32(query) - return SystemInfo{ - System: CPUInfo{ - GpuInfo: GpuInfo{ - memInfo: mem, - }, - CPUs: []CPU{ - { - CoreCount: int(perfCores + efficiencyCores), - EfficiencyCoreCount: int(efficiencyCores), - ThreadCount: int(logicalCores), - }, - }, + return []CPU{ + { + CoreCount: int(perfCores + efficiencyCores), + EfficiencyCoreCount: int(efficiencyCores), + ThreadCount: int(logicalCores), }, - GPUs: GetGPUInfo(), } } + +func IsNUMA() bool { + // numa support in ggml is linux only + return false +} diff --git a/discover/gpu_info.h b/discover/gpu_info.h deleted file mode 100644 index ee7ff4c33..000000000 --- a/discover/gpu_info.h +++ /dev/null @@ -1,72 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_H__ -#define __GPU_INFO_H__ -#include -#include -#include - -#ifndef _WIN32 -#include -#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags) -#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym) -#define LOAD_ERR() strdup(dlerror()) -#define UNLOAD_LIBRARY(handle) dlclose(handle) -#else -#include -#define LOAD_LIBRARY(lib, flags) LoadLibrary(lib) -#define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym) -#define UNLOAD_LIBRARY(handle) FreeLibrary(handle) -#define LOAD_ERR() ({\ - LPSTR messageBuffer = NULL; \ - size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, \ - NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); \ - char *resp = strdup(messageBuffer); \ - LocalFree(messageBuffer); \ - resp; \ -}) - -#endif - -#ifndef LOG -#define LOG(verbose, ...) \ - do { \ - if (verbose) { \ - fprintf(stderr, __VA_ARGS__); \ - } \ - } while (0) -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -#define GPU_ID_LEN 64 -#define GPU_NAME_LEN 96 - -typedef struct mem_info { - char *err; // If non-nill, caller responsible for freeing - char gpu_id[GPU_ID_LEN]; - char gpu_name[GPU_NAME_LEN]; - uint64_t total; - uint64_t free; - uint64_t used; - - // Compute Capability - int major; - int minor; - int patch; -} mem_info_t; - -void cpu_check_ram(mem_info_t *resp); - -#ifdef __cplusplus -} -#endif - -#include "gpu_info_cudart.h" -#include "gpu_info_nvcuda.h" -#include "gpu_info_nvml.h" -#include "gpu_info_oneapi.h" - -#endif // __GPU_INFO_H__ -#endif // __APPLE__ diff --git a/discover/gpu_info_cudart.c b/discover/gpu_info_cudart.c deleted file mode 100644 index 76c17b9d8..000000000 --- a/discover/gpu_info_cudart.c +++ /dev/null @@ -1,181 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? 
- -#include -#include -#include "gpu_info_cudart.h" - -void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { - cudartReturn_t ret; - resp->err = NULL; - resp->num_devices = 0; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - {"cudaSetDevice", (void *)&resp->ch.cudaSetDevice}, - {"cudaDeviceSynchronize", (void *)&resp->ch.cudaDeviceSynchronize}, - {"cudaDeviceReset", (void *)&resp->ch.cudaDeviceReset}, - {"cudaMemGetInfo", (void *)&resp->ch.cudaMemGetInfo}, - {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount}, - {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute}, - {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion}, - {"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(cudart_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", cudart_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - cudart_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - for (i = 0; l[i].s != NULL; i++) { - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!*(l[i].p)) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - ret = (*resp->ch.cudaSetDevice)(0); - if (ret != CUDART_SUCCESS) { - LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) { - resp->err = strdup("your nvidia driver is too old or missing. 
If you have a CUDA GPU please upgrade to run ollama"); - return; - } - snprintf(buf, buflen, "cudart init failure: %d", ret); - resp->err = strdup(buf); - return; - } - - int version = 0; - - // Report driver version if we're in verbose mode, ignore errors - ret = (*resp->ch.cudaDriverGetVersion)(&version); - if (ret != CUDART_SUCCESS) { - LOG(resp->ch.verbose, "cudaDriverGetVersion failed: %d\n", ret); - } else { - resp->ch.driver_major = version / 1000; - resp->ch.driver_minor = (version - (resp->ch.driver_major * 1000)) / 10; - LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", resp->ch.driver_major, resp->ch.driver_minor); - } - - ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices); - if (ret != CUDART_SUCCESS) { - LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - return; - } -} - - -void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) { - resp->err = NULL; - cudartMemory_t memInfo = {0,0,0}; - cudartReturn_t ret; - const int buflen = 256; - char buf[buflen + 1]; - - if (h.handle == NULL) { - resp->err = strdup("cudart handle isn't initialized"); - return; - } - - ret = (*h.cudaSetDevice)(i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device failed to initialize"); - resp->err = strdup(buf); - return; - } - - cudaDeviceProp_t props; - ret = (*h.cudaGetDeviceProperties)(&props, i); - if (ret != CUDART_SUCCESS) { - LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret); - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - resp->major = 0; - resp->minor = 0; - } else { - int allNull = 1; - for (int j = 0; j < 16; j++) { - if (props.uuid.bytes[j] != 0) { - allNull = 0; - break; - } - } - if (allNull != 0) { - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - } else { - // GPU-d110a105-ac29-1d54-7b49-9c90440f215b - snprintf(&resp->gpu_id[0], GPU_ID_LEN, - "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", - props.uuid.bytes[0], - props.uuid.bytes[1], - props.uuid.bytes[2], - props.uuid.bytes[3], - props.uuid.bytes[4], - props.uuid.bytes[5], - props.uuid.bytes[6], - props.uuid.bytes[7], - props.uuid.bytes[8], - props.uuid.bytes[9], - props.uuid.bytes[10], - props.uuid.bytes[11], - props.uuid.bytes[12], - props.uuid.bytes[13], - props.uuid.bytes[14], - props.uuid.bytes[15] - ); - } - resp->major = props.major; - resp->minor = props.minor; - - // TODO add other useful properties from props - } - ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret); - resp->err = strdup(buf); - return; - } - - resp->total = memInfo.total; - resp->free = memInfo.free; - resp->used = memInfo.used; - - LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "\n", resp->gpu_id, resp->total); - LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "\n", resp->gpu_id, resp->free); - LOG(h.verbose, "[%s] CUDA usedMem %" PRId64 "\n", resp->gpu_id, resp->used); - LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); -} - -void cudart_release(cudart_handle_t h) { - LOG(h.verbose, "releasing cudart library\n"); - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -#endif // __APPLE__ diff --git a/discover/gpu_info_cudart.h b/discover/gpu_info_cudart.h deleted file mode 100644 index 893f3f7bd..000000000 --- a/discover/gpu_info_cudart.h +++ /dev/null @@ 
-1,145 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_CUDART_H__ -#define __GPU_INFO_CUDART_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum cudartReturn_enum { - CUDART_SUCCESS = 0, - CUDART_ERROR_INVALID_VALUE = 1, - CUDART_ERROR_MEMORY_ALLOCATION = 2, - CUDART_ERROR_INSUFFICIENT_DRIVER = 35, - // Other values omitted for now... -} cudartReturn_t; - -typedef enum cudartDeviceAttr_enum { - cudartDevAttrComputeCapabilityMajor = 75, - cudartDevAttrComputeCapabilityMinor = 76, - - // TODO - not yet wired up but may be useful for Jetson or other - // integrated GPU scenarios with shared memory - cudaDevAttrIntegrated = 18 - -} cudartDeviceAttr_t; - -typedef void *cudartDevice_t; // Opaque is sufficient -typedef struct cudartMemory_st { - size_t total; - size_t free; - size_t used; -} cudartMemory_t; - -typedef struct cudaUUID { - unsigned char bytes[16]; -} cudaUUID_t; -typedef struct cudaDeviceProp { - char name[256]; /**< ASCII string identifying device */ - cudaUUID_t uuid; /**< 16-byte unique identifier */ - char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */ - unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */ - size_t totalGlobalMem; /**< Global memory available on device in bytes */ - size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */ - int regsPerBlock; /**< 32-bit registers available per block */ - int warpSize; /**< Warp size in threads */ - size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */ - int maxThreadsPerBlock; /**< Maximum number of threads per block */ - int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */ - int maxGridSize[3]; /**< Maximum size of each dimension of a grid */ - int clockRate; /**< Clock frequency in kilohertz */ - size_t totalConstMem; /**< Constant memory available on device in bytes */ - int major; /**< Major compute capability */ - int minor; /**< Minor compute capability */ - size_t textureAlignment; /**< Alignment requirement for textures */ - size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */ - int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */ - int multiProcessorCount; /**< Number of multiprocessors on device */ - int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */ - int integrated; /**< Device is integrated as opposed to discrete */ - int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */ - int computeMode; /**< Compute mode (See ::cudaComputeMode) */ - int maxTexture1D; /**< Maximum 1D texture size */ - int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */ - int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. 
*/ - int maxTexture2D[2]; /**< Maximum 2D texture dimensions */ - int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */ - int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */ - int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */ - int maxTexture3D[3]; /**< Maximum 3D texture dimensions */ - int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */ - int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */ - int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */ - int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */ - int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */ - int maxSurface1D; /**< Maximum 1D surface size */ - int maxSurface2D[2]; /**< Maximum 2D surface dimensions */ - int maxSurface3D[3]; /**< Maximum 3D surface dimensions */ - int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */ - int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */ - int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */ - int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */ - size_t surfaceAlignment; /**< Alignment requirements for surfaces */ - int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */ - int ECCEnabled; /**< Device has ECC support enabled */ - int pciBusID; /**< PCI bus ID of the device */ - int pciDeviceID; /**< PCI device ID of the device */ - int pciDomainID; /**< PCI domain ID of the device */ - int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */ - int asyncEngineCount; /**< Number of asynchronous engines */ - int unifiedAddressing; /**< Device shares a unified address space with the host */ - int memoryClockRate; /**< Peak memory clock frequency in kilohertz */ - int memoryBusWidth; /**< Global memory bus width in bits */ - int l2CacheSize; /**< Size of L2 cache in bytes */ - int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */ - int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */ - int streamPrioritiesSupported; /**< Device supports stream priorities */ - int globalL1CacheSupported; /**< Device supports caching globals in L1 */ - int localL1CacheSupported; /**< Device supports caching locals in L1 */ - size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */ - int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */ - int managedMemory; /**< Device supports allocating managed memory on this system */ - int isMultiGpuBoard; /**< Device is on a multi-GPU board */ - int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */ - int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */ - int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */ - int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */ - int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */ - int computePreemptionSupported; /**< Device supports Compute Preemption */ - int 
canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */ - int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */ - int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */ - size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */ - int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */ - int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */ - int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */ - int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */ - size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */ - } cudaDeviceProp_t; - -typedef struct cudart_handle { - void *handle; - uint16_t verbose; - int driver_major; - int driver_minor; - cudartReturn_t (*cudaSetDevice)(int device); - cudartReturn_t (*cudaDeviceSynchronize)(void); - cudartReturn_t (*cudaDeviceReset)(void); - cudartReturn_t (*cudaMemGetInfo)(size_t *, size_t *); - cudartReturn_t (*cudaGetDeviceCount)(int *); - cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device); - cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion); - cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device); -} cudart_handle_t; - -typedef struct cudart_init_resp { - char *err; // If err is non-null handle is invalid - cudart_handle_t ch; - int num_devices; -} cudart_init_resp_t; - -void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp); -void cudart_bootstrap(cudart_handle_t ch, int device_id, mem_info_t *resp); -// TODO - if we keep this library longer term, add cudart_get_free -void cudart_release(cudart_handle_t ch); - -#endif // __GPU_INFO_CUDART_H__ -#endif // __APPLE__ diff --git a/discover/gpu_info_nvcuda.c b/discover/gpu_info_nvcuda.c deleted file mode 100644 index d2d0b683b..000000000 --- a/discover/gpu_info_nvcuda.c +++ /dev/null @@ -1,251 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? 
- -#include -#include -#include "gpu_info_nvcuda.h" - -void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) { - LOG(resp->ch.verbose, "initializing %s\n", nvcuda_lib_path); - CUresult ret; - resp->err = NULL; - resp->num_devices = 0; - resp->cudaErr = CUDA_SUCCESS; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - - {"cuInit", (void *)&resp->ch.cuInit}, - {"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion}, - {"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount}, - {"cuDeviceGet", (void *)&resp->ch.cuDeviceGet}, - {"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute}, - {"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid}, - {"cuDeviceGetName", (void *)&resp->ch.cuDeviceGetName}, - {"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3}, - {"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2}, - {"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - nvcuda_lib_path, msg); - free(msg); - resp->err = strdup(buf); - resp->cudaErr = -1; - return; - } - - for (i = 0; l[i].s != NULL; i++) { - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!*(l[i].p)) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - resp->cudaErr = -1; - return; - } - LOG(resp->ch.verbose, "dlsym: %s - %p\n", l[i].s, *l[i].p); - } - - LOG(resp->ch.verbose, "calling cuInit\n"); - ret = (*resp->ch.cuInit)(0); - if (ret != CUDA_SUCCESS) { - LOG(resp->ch.verbose, "cuInit err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "cuda driver library init failure: %d", ret); - resp->err = strdup(buf); - resp->cudaErr = ret; - return; - } - - int version = 0; - resp->ch.driver_major = 0; - resp->ch.driver_minor = 0; - - // Report driver version if we're in verbose mode, ignore errors - LOG(resp->ch.verbose, "calling cuDriverGetVersion\n"); - ret = (*resp->ch.cuDriverGetVersion)(&version); - if (ret != CUDA_SUCCESS) { - LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret); - } else { - LOG(resp->ch.verbose, "raw version 0x%x\n", version); - resp->ch.driver_major = version / 1000; - resp->ch.driver_minor = (version - (resp->ch.driver_major * 1000)) / 10; - LOG(resp->ch.verbose, "CUDA driver version: %d.%d\n", resp->ch.driver_major, resp->ch.driver_minor); - } - - LOG(resp->ch.verbose, "calling cuDeviceGetCount\n"); - ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices); - if (ret != CUDA_SUCCESS) { - LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - resp->cudaErr = ret; - return; - } - LOG(resp->ch.verbose, "device count %d\n", resp->num_devices); -} - -const int buflen = 256; -void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) { - resp->err = NULL; - nvcudaMemory_t memInfo = {0,0}; - CUresult ret; - CUdevice device = -1; - CUcontext ctx = NULL; - char buf[buflen + 1]; - CUuuid uuid = 
{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; - - if (h.handle == NULL) { - resp->err = strdup("cuda driver library handle isn't initialized"); - return; - } - - ret = (*h.cuDeviceGet)(&device, i); - if (ret != CUDA_SUCCESS) { - snprintf(buf, buflen, "cuda driver library device failed to initialize"); - resp->err = strdup(buf); - return; - } - - int major = 0; - int minor = 0; - ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret); - } else { - ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret); - } else { - resp->minor = minor; - resp->major = major; - } - } - - ret = (*h.cuDeviceGetUuid)(&uuid, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret); - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - } else { - // GPU-d110a105-ac29-1d54-7b49-9c90440f215b - snprintf(&resp->gpu_id[0], GPU_ID_LEN, - "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", - uuid.bytes[0], - uuid.bytes[1], - uuid.bytes[2], - uuid.bytes[3], - uuid.bytes[4], - uuid.bytes[5], - uuid.bytes[6], - uuid.bytes[7], - uuid.bytes[8], - uuid.bytes[9], - uuid.bytes[10], - uuid.bytes[11], - uuid.bytes[12], - uuid.bytes[13], - uuid.bytes[14], - uuid.bytes[15] - ); - } - - ret = (*h.cuDeviceGetName)(&resp->gpu_name[0], GPU_NAME_LEN, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device name lookup failure: %d\n", i, ret); - resp->gpu_name[0] = '\0'; - } - - // To get memory we have to set (and release) a context - ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device); - if (ret != CUDA_SUCCESS) { - snprintf(buf, buflen, "cuda driver library failed to get device context %d", ret); - resp->err = strdup(buf); - return; - } - - ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total); - if (ret != CUDA_SUCCESS) { - snprintf(buf, buflen, "cuda driver library device memory info lookup failure %d", ret); - resp->err = strdup(buf); - // Best effort on failure... - (*h.cuCtxDestroy)(ctx); - return; - } - - resp->total = memInfo.total; - resp->free = memInfo.free; - - LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "mb\n", resp->gpu_id, resp->total / 1024 / 1024); - LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "mb\n", resp->gpu_id, resp->free / 1024 / 1024); - LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); - - - - ret = (*h.cuCtxDestroy)(ctx); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library failed to release device context %d", ret); - } -} - -void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total) { - CUresult ret; - CUcontext ctx = NULL; - CUdevice device = -1; - *free = 0; - *total = 0; - - ret = (*h.cuDeviceGet)(&device, i); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library device failed to initialize"); - return; - } - - - // To get memory we have to set (and release) a context - ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library failed to get device context %d", ret); - return; - } - - ret = (*h.cuMemGetInfo_v2)(free, total); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library device memory info lookup failure %d", ret); - // Best effort on failure... 
- (*h.cuCtxDestroy)(ctx); - return; - } - - ret = (*h.cuCtxDestroy)(ctx); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library failed to release device context %d", ret); - } -} - -void nvcuda_release(nvcuda_handle_t h) { - LOG(h.verbose, "releasing cuda driver library\n"); - UNLOAD_LIBRARY(h.handle); - // TODO and other context release logic? - h.handle = NULL; -} - -#endif // __APPLE__ diff --git a/discover/gpu_info_nvcuda.h b/discover/gpu_info_nvcuda.h deleted file mode 100644 index ef2fe8a30..000000000 --- a/discover/gpu_info_nvcuda.h +++ /dev/null @@ -1,79 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_NVCUDA_H__ -#define __GPU_INFO_NVCUDA_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum cudaError_enum { - CUDA_SUCCESS = 0, - CUDA_ERROR_INVALID_VALUE = 1, - CUDA_ERROR_OUT_OF_MEMORY = 2, - CUDA_ERROR_NOT_INITIALIZED = 3, - CUDA_ERROR_INSUFFICIENT_DRIVER = 35, - CUDA_ERROR_NO_DEVICE = 100, - CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803, - CUDA_ERROR_UNKNOWN = 999, - // Other values omitted for now... -} CUresult; - -typedef enum CUdevice_attribute_enum { - CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75, - CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76, - - // TODO - not yet wired up but may be useful for Jetson or other - // integrated GPU scenarios with shared memory - CU_DEVICE_ATTRIBUTE_INTEGRATED = 18 - -} CUdevice_attribute; - -typedef void *nvcudaDevice_t; // Opaque is sufficient -typedef struct nvcudaMemory_st { - uint64_t total; - uint64_t free; -} nvcudaMemory_t; - -typedef struct nvcudaDriverVersion { - int major; - int minor; -} nvcudaDriverVersion_t; - -typedef struct CUuuid_st { - unsigned char bytes[16]; -} CUuuid; - -typedef int CUdevice; -typedef void* CUcontext; - -typedef struct nvcuda_handle { - void *handle; - uint16_t verbose; - int driver_major; - int driver_minor; - CUresult (*cuInit)(unsigned int Flags); - CUresult (*cuDriverGetVersion)(int *driverVersion); - CUresult (*cuDeviceGetCount)(int *); - CUresult (*cuDeviceGet)(CUdevice* device, int ordinal); - CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev); - CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2 - CUresult (*cuDeviceGetName)(char *name, int len, CUdevice dev); - - // Context specific aspects - CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev); - CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total); - CUresult (*cuCtxDestroy)(CUcontext ctx); -} nvcuda_handle_t; - -typedef struct nvcuda_init_resp { - char *err; // If err is non-null handle is invalid - nvcuda_handle_t ch; - int num_devices; - CUresult cudaErr; -} nvcuda_init_resp_t; - -void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp); -void nvcuda_bootstrap(nvcuda_handle_t ch, int device_id, mem_info_t *resp); -void nvcuda_get_free(nvcuda_handle_t ch, int device_id, uint64_t *free, uint64_t *total); -void nvcuda_release(nvcuda_handle_t ch); - -#endif // __GPU_INFO_NVCUDA_H__ -#endif // __APPLE__ diff --git a/discover/gpu_info_nvml.c b/discover/gpu_info_nvml.c deleted file mode 100644 index 342a3aa4b..000000000 --- a/discover/gpu_info_nvml.c +++ /dev/null @@ -1,104 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? 
- -#include - -#include "gpu_info_nvml.h" - -void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) { - nvmlReturn_t ret; - resp->err = NULL; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2}, - {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown}, - {"nvmlDeviceGetHandleByUUID", (void *)&resp->ch.nvmlDeviceGetHandleByUUID}, - {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - nvml_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - // TODO once we've squashed the remaining corner cases remove this log - // LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path); - - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - // LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s); - - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!*(l[i].p)) { - resp->ch.handle = NULL; - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.handle); - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - ret = (*resp->ch.nvmlInit_v2)(); - if (ret != NVML_SUCCESS) { - LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "nvml vram init failure: %d", ret); - resp->err = strdup(buf); - return; - } -} - - -void nvml_get_free(nvml_handle_t h, char *uuid, uint64_t *free, uint64_t *total, uint64_t *used) { - nvmlDevice_t device; - nvmlMemory_t memInfo = {0}; - nvmlReturn_t ret; - ret = (*h.nvmlDeviceGetHandleByUUID)((const char *)(uuid), &device); - if (ret != NVML_SUCCESS) { - LOG(1, "unable to get device handle %s: %d", uuid, ret); - *free = 0; - return; - } - - ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo); - if (ret != NVML_SUCCESS) { - LOG(1, "device memory info lookup failure %s: %d", uuid, ret); - *free = 0; - return; - } - *free = memInfo.free; - *total = memInfo.total; - *used = memInfo.used; -} - - -void nvml_release(nvml_handle_t h) { - LOG(h.verbose, "releasing nvml library\n"); - nvmlReturn_t ret; - ret = (*h.nvmlShutdown)(); - if (ret != NVML_SUCCESS) { - LOG(1, "error during nvmlShutdown %d", ret); - } - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -#endif // __APPLE__ \ No newline at end of file diff --git a/discover/gpu_info_nvml.h b/discover/gpu_info_nvml.h deleted file mode 100644 index 908802337..000000000 --- a/discover/gpu_info_nvml.h +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_NVML_H__ -#define __GPU_INFO_NVML_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum nvmlReturn_enum { - NVML_SUCCESS = 0, - // Other values omitted for now... 
-} nvmlReturn_t; -typedef void *nvmlDevice_t; // Opaque is sufficient -typedef struct nvmlMemory_st { - unsigned long long total; - unsigned long long free; - unsigned long long used; -} nvmlMemory_t; - -typedef enum nvmlBrandType_enum -{ - NVML_BRAND_UNKNOWN = 0, -} nvmlBrandType_t; - -typedef struct nvml_handle { - void *handle; - uint16_t verbose; - nvmlReturn_t (*nvmlInit_v2)(void); - nvmlReturn_t (*nvmlShutdown)(void); - nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); - nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); -} nvml_handle_t; - -typedef struct nvml_init_resp { - char *err; // If err is non-null handle is invalid - nvml_handle_t ch; -} nvml_init_resp_t; - -typedef struct nvml_compute_capability { - char *err; - int major; - int minor; -} nvml_compute_capability_t; - -void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp); -void nvml_get_free(nvml_handle_t ch, char *uuid, uint64_t *free, uint64_t *total, uint64_t *used); -void nvml_release(nvml_handle_t ch); - -#endif // __GPU_INFO_NVML_H__ -#endif // __APPLE__ \ No newline at end of file diff --git a/discover/gpu_info_oneapi.c b/discover/gpu_info_oneapi.c deleted file mode 100644 index 3ff708ea2..000000000 --- a/discover/gpu_info_oneapi.c +++ /dev/null @@ -1,259 +0,0 @@ -#ifndef __APPLE__ - -#include "gpu_info_oneapi.h" - -#include - -void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) { - ze_result_t ret; - resp->err = NULL; - resp->oh.devices = NULL; - resp->oh.num_devices = NULL; - resp->oh.drivers = NULL; - resp->oh.num_drivers = 0; - const int buflen = 256; - char buf[buflen + 1]; - int i, d; - struct lookup { - char *s; - void **p; - } l[] = { - {"zesInit", (void *)&resp->oh.zesInit}, - {"zesDriverGet", (void *)&resp->oh.zesDriverGet}, - {"zesDeviceGet", (void *)&resp->oh.zesDeviceGet}, - {"zesDeviceGetProperties", (void *)&resp->oh.zesDeviceGetProperties}, - {"zesDeviceEnumMemoryModules", - (void *)&resp->oh.zesDeviceEnumMemoryModules}, - {"zesMemoryGetProperties", (void *)&resp->oh.zesMemoryGetProperties}, - {"zesMemoryGetState", (void *)&resp->oh.zesMemoryGetState}, - {NULL, NULL}, - }; - - resp->oh.handle = LOAD_LIBRARY(oneapi_lib_path, RTLD_LAZY); - if (!resp->oh.handle) { - char *msg = LOAD_ERR(); - snprintf(buf, buflen, - "Unable to load %s library to query for Intel GPUs: %s\n", - oneapi_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->oh.verbose, - "wiring Level-Zero management library functions in %s\n", - oneapi_lib_path); - - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s); - - *l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s); - if (!*(l[i].p)) { - resp->oh.handle = NULL; - char *msg = LOAD_ERR(); - LOG(resp->oh.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->oh.handle); - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - LOG(resp->oh.verbose, "calling zesInit\n"); - - ret = (*resp->oh.zesInit)(0); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesInit err: %x\n", ret); - snprintf(buf, buflen, "oneapi vram init failure: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - - LOG(resp->oh.verbose, "calling zesDriverGet\n"); - ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL); - if (ret != 
ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get driver count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers); - resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t)); - resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t)); - memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t)); - resp->oh.devices = - malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t *)); - ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get driver count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - - for (d = 0; d < resp->oh.num_drivers; d++) { - LOG(resp->oh.verbose, "calling zesDeviceGet count %d: %p\n", d, resp->oh.drivers[d]); - ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], - &resp->oh.num_devices[d], NULL); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get device count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - resp->oh.devices[d] = - malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t)); - ret = (*resp->oh.zesDeviceGet)( - resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get device count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - } - - return; -} - -void oneapi_check_vram(oneapi_handle_t h, int driver, int device, - mem_info_t *resp) { - ze_result_t ret; - resp->err = NULL; - uint64_t totalMem = 0; - uint64_t usedMem = 0; - const int buflen = 256; - char buf[buflen + 1]; - int i, d, m; - - if (h.handle == NULL) { - resp->err = strdup("Level-Zero handle not initialized"); - return; - } - - if (driver > h.num_drivers || device > h.num_devices[driver]) { - resp->err = strdup("driver of device index out of bounds"); - return; - } - - resp->total = 0; - resp->free = 0; - - zes_device_ext_properties_t ext_props; - ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES; - ext_props.pNext = NULL; - - zes_device_properties_t props; - props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES; - props.pNext = &ext_props; - - ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props); - if (ret != ZE_RESULT_SUCCESS) { - snprintf(buf, buflen, "unable to get device properties: %d", ret); - resp->err = strdup(buf); - return; - } - - snprintf(&resp->gpu_name[0], GPU_NAME_LEN, "%s", props.modelName); - - // TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax - // (this is probably wrong...) - // TODO - the driver isn't included - what if there are multiple drivers? - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device); - - if (h.verbose) { - // When in verbose mode, report more information about - // the card we discover. 
- LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device, - props.modelName); - LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device, - props.brandName); - LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device, - props.vendorName); - LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device, - props.serialNumber); - LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device, - props.boardNumber); - } - - // TODO - // Compute Capability equivalent in resp->major, resp->minor, resp->patch - - uint32_t memCount = 0; - ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, - NULL); - if (ret != ZE_RESULT_SUCCESS) { - snprintf(buf, buflen, "unable to enumerate Level-Zero memory modules: %x", - ret); - resp->err = strdup(buf); - return; - } - - LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount); - - zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t)); - (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems); - - for (m = 0; m < memCount; m++) { - zes_mem_state_t state; - state.stype = ZES_STRUCTURE_TYPE_MEM_STATE; - state.pNext = NULL; - ret = (*h.zesMemoryGetState)(mems[m], &state); - if (ret != ZE_RESULT_SUCCESS) { - snprintf(buf, buflen, "unable to get memory state: %x", ret); - resp->err = strdup(buf); - free(mems); - return; - } - - resp->total += state.size; - resp->free += state.free; - } - - free(mems); -} - -void oneapi_release(oneapi_handle_t h) { - int d; - LOG(h.verbose, "releasing oneapi library\n"); - for (d = 0; d < h.num_drivers; d++) { - if (h.devices != NULL && h.devices[d] != NULL) { - free(h.devices[d]); - } - } - if (h.devices != NULL) { - free(h.devices); - h.devices = NULL; - } - if (h.num_devices != NULL) { - free(h.num_devices); - h.num_devices = NULL; - } - if (h.drivers != NULL) { - free(h.drivers); - h.drivers = NULL; - } - h.num_drivers = 0; - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -int oneapi_get_device_count(oneapi_handle_t h, int driver) { - if (h.handle == NULL || h.num_devices == NULL) { - return 0; - } - if (driver > h.num_drivers) { - return 0; - } - return (int)h.num_devices[driver]; -} - -#endif // __APPLE__ diff --git a/discover/gpu_info_oneapi.h b/discover/gpu_info_oneapi.h deleted file mode 100644 index 97fcecd9c..000000000 --- a/discover/gpu_info_oneapi.h +++ /dev/null @@ -1,203 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_ONEAPI_H__ -#define __GPU_INFO_ONEAPI_H__ -#include "gpu_info.h" - -#define ZE_MAX_DEVICE_NAME 256 -#define ZE_MAX_DEVICE_UUID_SIZE 16 -#define ZES_STRING_PROPERTY_SIZE 64 -#define ZE_BIT(_i) (1 << _i) - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum ze_result_t { - ZE_RESULT_SUCCESS = 0, - // Other values omitted for now... 
-} ze_result_t; - -typedef uint8_t ze_bool_t; -typedef struct _zes_driver_handle_t *zes_driver_handle_t; -typedef struct _zes_device_handle_t *zes_device_handle_t; -typedef struct _zes_mem_handle_t *zes_mem_handle_t; - -typedef enum _ze_structure_type_t { - ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff -} ze_structure_type_t; - -typedef enum _zes_structure_type_t { - ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1, - ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb, - ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e, - ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES = 0x2d, - ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff -} zes_structure_type_t; - -typedef enum _zes_mem_type_t { - ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff -} zes_mem_type_t; - -typedef enum _zes_mem_loc_t { - ZES_MEM_LOC_SYSTEM = 0, - ZES_MEM_LOC_DEVICE = 1, - ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff -} zes_mem_loc_t; - -typedef enum _zes_mem_health_t { - ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff -} zes_mem_health_t; - -typedef struct _ze_device_uuid_t { - uint8_t id[ZE_MAX_DEVICE_UUID_SIZE]; -} ze_device_uuid_t; - -typedef struct _zes_uuid_t { - uint8_t id[ZE_MAX_DEVICE_UUID_SIZE]; -} zes_uuid_t; - -typedef enum _ze_device_type_t { - ZE_DEVICE_TYPE_GPU = 1, - ZE_DEVICE_TYPE_CPU = 2, - ZE_DEVICE_TYPE_FPGA = 3, - ZE_DEVICE_TYPE_MCA = 4, - ZE_DEVICE_TYPE_VPU = 5, - ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff -} ze_device_type_t; - -typedef enum _zes_device_type_t { - ZES_DEVICE_TYPE_GPU = 1, - ZES_DEVICE_TYPE_CPU = 2, - ZES_DEVICE_TYPE_FPGA = 3, - ZES_DEVICE_TYPE_MCA = 4, - ZES_DEVICE_TYPE_VPU = 5, - ZES_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff -} zes_device_type_t; - -typedef uint32_t ze_device_property_flags_t; -typedef enum _ze_device_property_flag_t { - ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0), - ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1), - ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2), - ZE_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3), - ZE_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff -} ze_device_property_flag_t; - -typedef uint32_t zes_device_property_flags_t; -typedef enum _zes_device_property_flag_t { - ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0), - ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1), - ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2), - ZES_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3), - ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff -} zes_device_property_flag_t; - -typedef struct _ze_device_properties_t { - ze_structure_type_t stype; - void *pNext; - ze_device_type_t type; - uint32_t vendorId; - uint32_t deviceId; - ze_device_property_flags_t flags; - uint32_t subdeviceId; - uint32_t coreClockRate; - uint64_t maxMemAllocSize; - uint32_t maxHardwareContexts; - uint32_t maxCommandQueuePriority; - uint32_t numThreadsPerEU; - uint32_t physicalEUSimdWidth; - uint32_t numEUsPerSubslice; - uint32_t numSubslicesPerSlice; - uint32_t numSlices; - uint64_t timerResolution; - uint32_t timestampValidBits; - uint32_t kernelTimestampValidBits; - ze_device_uuid_t uuid; - char name[ZE_MAX_DEVICE_NAME]; -} ze_device_properties_t; - -typedef struct _zes_device_properties_t { - zes_structure_type_t stype; - void *pNext; - ze_device_properties_t core; - uint32_t numSubdevices; - char serialNumber[ZES_STRING_PROPERTY_SIZE]; - char boardNumber[ZES_STRING_PROPERTY_SIZE]; - char brandName[ZES_STRING_PROPERTY_SIZE]; - char modelName[ZES_STRING_PROPERTY_SIZE]; - char vendorName[ZES_STRING_PROPERTY_SIZE]; - char driverVersion[ZES_STRING_PROPERTY_SIZE]; -} zes_device_properties_t; - -typedef struct _zes_device_ext_properties_t { - zes_structure_type_t stype; - void 
*pNext; - zes_uuid_t uuid; - zes_device_type_t type; - zes_device_property_flags_t flags; -} zes_device_ext_properties_t; - -typedef struct _zes_mem_properties_t { - zes_structure_type_t stype; - void *pNext; - zes_mem_type_t type; - ze_bool_t onSubdevice; - uint32_t subdeviceId; - zes_mem_loc_t location; - uint64_t physicalSize; - int32_t busWidth; - int32_t numChannels; -} zes_mem_properties_t; - -typedef struct _zes_mem_state_t { - zes_structure_type_t stype; - const void *pNext; - zes_mem_health_t health; - uint64_t free; - uint64_t size; -} zes_mem_state_t; - -typedef struct oneapi_handle { - void *handle; - uint16_t verbose; - - uint32_t num_drivers; - zes_driver_handle_t *drivers; - uint32_t *num_devices; - zes_device_handle_t **devices; - - // TODO Driver major, minor information - // int driver_major; - // int driver_minor; - - ze_result_t (*zesInit)(int); - ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers); - ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount, - zes_device_handle_t *phDevices); - ze_result_t (*zesDeviceGetProperties)(zes_device_handle_t hDevice, - zes_device_properties_t *pProperties); - ze_result_t (*zesDeviceEnumMemoryModules)(zes_device_handle_t hDevice, - uint32_t *pCount, - zes_mem_handle_t *phMemory); - ze_result_t (*zesMemoryGetProperties)(zes_mem_handle_t hMemory, - zes_mem_properties_t *pProperties); - ze_result_t (*zesMemoryGetState)(zes_mem_handle_t hMemory, - zes_mem_state_t *pState); - -} oneapi_handle_t; - -typedef struct oneapi_init_resp { - char *err; // If err is non-null handle is invalid - oneapi_handle_t oh; -} oneapi_init_resp_t; - -typedef struct oneapi_version_resp { - ze_result_t status; - char *str; // Contains version or error string if status != 0 -} oneapi_version_resp_t; - -void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp); -void oneapi_check_vram(oneapi_handle_t h, int driver, int device, - mem_info_t *resp); -void oneapi_release(oneapi_handle_t h); -int oneapi_get_device_count(oneapi_handle_t h, int driver); - -#endif // __GPU_INFO_INTEL_H__ -#endif // __APPLE__ diff --git a/discover/gpu_test.go b/discover/gpu_test.go deleted file mode 100644 index 0c6ef7bad..000000000 --- a/discover/gpu_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package discover - -import ( - "runtime" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestBasicGetGPUInfo(t *testing.T) { - info := GetGPUInfo() - assert.NotEmpty(t, len(info)) - assert.Contains(t, "cuda rocm cpu metal", info[0].Library) - if info[0].Library != "cpu" { - assert.Greater(t, info[0].TotalMemory, uint64(0)) - assert.Greater(t, info[0].FreeMemory, uint64(0)) - } -} - -func TestCPUMemInfo(t *testing.T) { - info, err := GetCPUMem() - require.NoError(t, err) - switch runtime.GOOS { - case "darwin": - t.Skip("CPU memory not populated on darwin") - case "linux", "windows": - assert.Greater(t, info.TotalMemory, uint64(0)) - assert.Greater(t, info.FreeMemory, uint64(0)) - default: - return - } -} - -func TestByLibrary(t *testing.T) { - type testCase struct { - input []GpuInfo - expect int - } - - testCases := map[string]*testCase{ - "empty": {input: []GpuInfo{}, expect: 0}, - "cpu": {input: []GpuInfo{{Library: "cpu"}}, expect: 1}, - "cpu + GPU": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}}, expect: 2}, - "cpu + 2 GPU no variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}, {Library: "cuda"}}, expect: 2}, - "cpu + 2 GPU same variant": {input: []GpuInfo{{Library: 
"cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v11"}}, expect: 2}, - "cpu + 2 GPU diff variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v12"}}, expect: 3}, - } - - for k, v := range testCases { - t.Run(k, func(t *testing.T) { - resp := (GpuInfoList)(v.input).ByLibrary() - if len(resp) != v.expect { - t.Fatalf("expected length %d, got %d => %+v", v.expect, len(resp), resp) - } - }) - } -} - -// TODO - add some logic to figure out card type through other means and actually verify we got back what we expected diff --git a/discover/runner.go b/discover/runner.go new file mode 100644 index 000000000..4c0bce75b --- /dev/null +++ b/discover/runner.go @@ -0,0 +1,543 @@ +package discover + +// Runner based GPU discovery + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "math/rand" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/ml" +) + +var ( + deviceMu sync.Mutex + devices []ml.DeviceInfo + libDirs map[string]struct{} + rocmDir string + exe string + bootstrapped bool +) + +func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.DeviceInfo { + deviceMu.Lock() + defer deviceMu.Unlock() + startDiscovery := time.Now() + msg := "overall device VRAM discovery took" + defer func() { + slog.Debug(msg, "duration", time.Since(startDiscovery)) + }() + + if !bootstrapped { + msg = "GPU bootstrap discovery took" + libDirs = make(map[string]struct{}) + var err error + exe, err = os.Executable() + if err != nil { + slog.Error("unable to lookup executable path", "error", err) + return nil + } + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval + } + files, err := filepath.Glob(filepath.Join(LibOllamaPath, "*", "*ggml-*")) + if err != nil { + slog.Debug("unable to lookup runner library directories", "error", err) + } + for _, file := range files { + libDirs[filepath.Dir(file)] = struct{}{} + } + + // Our current packaging model places ggml-hip in the main directory + // but keeps rocm in an isolated directory. We have to add it to + // the [LD_LIBRARY_]PATH so ggml-hip will load properly + rocmDir = filepath.Join(LibOllamaPath, "rocm") + if _, err := os.Stat(rocmDir); err != nil { + rocmDir = "" + } + + if len(libDirs) == 0 { + libDirs[""] = struct{}{} + } + + // Typically bootstrapping takes < 1s, but on some systems, with devices + // in low power/idle mode, initialization can take multiple seconds. We + // set a long timeout just for bootstrap discovery to reduce the chance + // of giving up too quickly + ctx1stPass, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + slog.Info("discovering available GPUs...") + + // For our initial discovery pass, we gather all the known GPUs through + // all the libraries that were detected. This pass may include GPUs that + // are enumerated, but not actually supported. 
+ // We run this in serial to avoid potentially initializing a GPU multiple + // times concurrently leading to memory contention + for dir := range libDirs { + var dirs []string + if dir == "" { + dirs = []string{LibOllamaPath} + } else { + dirs = []string{LibOllamaPath, dir} + } + // For this pass, we retain duplicates in case any are incompatible with some libraries + devices = append(devices, bootstrapDevices(ctx1stPass, dirs, nil)...) + } + + // In the second pass, we more deeply initialize the GPUs to weed out devices that + // aren't supported by a given library. We run this phase in parallel to speed up discovery. + slog.Debug("filtering out unsupported or overlapping GPU library combinations", "count", len(devices)) + ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + var wg sync.WaitGroup + needsDelete := make([]bool, len(devices)) + supportedMu := sync.Mutex{} + supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index + for i := range devices { + libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1] + slog.Debug("verifying GPU is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "pci_id", devices[i].PCIID) + wg.Add(1) + go func(i int) { + defer wg.Done() + var envVar string + if devices[i].Library == "HIP" { + if runtime.GOOS != "linux" { + envVar = "HIP_VISIBLE_DEVICES" + } else { + envVar = "ROCR_VISIBLE_DEVICES" + } + } else { + envVar = "CUDA_VISIBLE_DEVICES" + } + + extraEnvs := []string{ + "GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs + envVar + "=" + devices[i].ID, // Filter to just this one GPU + } + if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 { + needsDelete[i] = true + } else { + supportedMu.Lock() + if _, ok := supported[devices[i].Library]; !ok { + supported[devices[i].Library] = make(map[string]map[string]int) + } + if _, ok := supported[devices[i].Library][libDir]; !ok { + supported[devices[i].Library][libDir] = make(map[string]int) + } + supported[devices[i].Library][libDir][devices[i].ID] = i + supportedMu.Unlock() + } + }(i) + } + wg.Wait() + logutil.Trace("supported GPU library combinations", "supported", supported) + + // Mark for deletion any overlaps - favoring the library version that can cover all GPUs if possible + filterOverlapByLibrary(supported, needsDelete) + + // TODO if we ever support multiple ROCm library versions this algorithm will need to be adjusted to keep the rocmID numeric value correct + rocmID := 0 + for i := 0; i < len(needsDelete); i++ { + if needsDelete[i] { + logutil.Trace("removing unsupported or overlapping GPU combination", "libDir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1], "description", devices[i].Description, "compute", devices[i].Compute(), "pci_id", devices[i].PCIID) + devices = append(devices[:i], devices[i+1:]...) + needsDelete = append(needsDelete[:i], needsDelete[i+1:]...) 
+ i-- + } else if devices[i].Library == "HIP" { + if _, err := strconv.Atoi(devices[i].ID); err == nil { + // Replace the numeric ID with the post-filtered IDs + devices[i].FilteredID = devices[i].ID + devices[i].ID = strconv.Itoa(rocmID) + } + rocmID++ + } + } + + // Now filter out any overlap with different libraries (favor CUDA/HIP over others) + for i := 0; i < len(devices); i++ { + for j := i + 1; j < len(devices); j++ { + // For this pass, we only drop exact duplicates + switch devices[i].Compare(devices[j]) { + case ml.SameBackendDevice: + // Same library and device, skip it + devices = append(devices[:j], devices[j+1:]...) + j-- + continue + case ml.DuplicateDevice: + // Different library, choose based on priority + var droppedDevice ml.DeviceInfo + if devices[i].Library == "CUDA" || devices[i].Library == "HIP" { + droppedDevice = devices[j] + } else { + droppedDevice = devices[i] + devices[i] = devices[j] + } + devices = append(devices[:j], devices[j+1:]...) + j-- + + typeStr := "discrete" + if droppedDevice.Integrated { + typeStr = "iGPU" + } + slog.Debug("dropping duplicate device", + "id", droppedDevice.ID, + "library", droppedDevice.Library, + "compute", droppedDevice.Compute(), + "name", droppedDevice.Name, + "description", droppedDevice.Description, + "libdirs", strings.Join(droppedDevice.LibraryPath, ","), + "driver", droppedDevice.Driver(), + "pci_id", droppedDevice.PCIID, + "type", typeStr, + "total", format.HumanBytes2(droppedDevice.TotalMemory), + "available", format.HumanBytes2(droppedDevice.FreeMemory), + ) + continue + } + } + } + + // Reset the libDirs to what we actually wind up using for future refreshes + libDirs = make(map[string]struct{}) + for _, dev := range devices { + dir := dev.LibraryPath[len(dev.LibraryPath)-1] + if dir != LibOllamaPath { + libDirs[dir] = struct{}{} + } + } + if len(libDirs) == 0 { + libDirs[""] = struct{}{} + } + + bootstrapped = true + } else { + if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { + // metal never updates free VRAM + return devices + } + + slog.Debug("refreshing free memory") + updated := make([]bool, len(devices)) + allDone := func() bool { + allDone := true + for _, done := range updated { + if !done { + allDone = false + break + } + } + return allDone + } + + // First try to use existing runners to refresh VRAM since they're already + // active on GPU(s) + for _, runner := range runners { + if runner == nil { + continue + } + deviceIDs := runner.GetActiveDeviceIDs() + if len(deviceIDs) == 0 { + // Skip this runner since it doesn't have active GPU devices + continue + } + + // Check to see if this runner is active on any devices that need a refresh + skip := true + devCheck: + for _, dev := range deviceIDs { + for i := range devices { + if dev.ID == devices[i].ID && dev.Library == devices[i].Library { + if !updated[i] { + skip = false + break devCheck + } + } + } + } + if skip { + continue + } + + // Typical refresh on existing runner is ~500ms but allow longer if the system + // is under stress before giving up and using stale data. 
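[Editor's note] The comment above captures the tradeoff the next few lines implement: query the already-running runner under a short deadline, and if it cannot answer in time, keep reporting the stale free-VRAM value rather than blocking scheduling. A small illustrative sketch of that pattern, where query is a hypothetical stand-in for runner.GetDeviceInfos (which returns richer per-device data):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// refreshFreeMemory asks an active runner for an updated free-VRAM reading,
// but falls back to the stale value if the runner is too slow or fails.
func refreshFreeMemory(parent context.Context, stale uint64,
	query func(context.Context) (uint64, error)) uint64 {
	ctx, cancel := context.WithTimeout(parent, 3*time.Second)
	defer cancel()
	if free, err := query(ctx); err == nil {
		return free // fresh value from the active runner
	}
	return stale // runner did not answer in time; keep the old reading
}

func main() {
	slow := func(ctx context.Context) (uint64, error) {
		select {
		case <-time.After(5 * time.Second): // simulates a runner under stress
			return 42, nil
		case <-ctx.Done():
			return 0, ctx.Err()
		}
	}
	fmt.Println(refreshFreeMemory(context.Background(), 1<<30, slow)) // prints the stale 1 GiB value
}
```

The 3-second deadline mirrors the one applied in the code that follows; the exact value is a heuristic, not a contract.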
+ ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + start := time.Now() + updatedDevices := runner.GetDeviceInfos(ctx) + slog.Debug("existing runner discovery took", "duration", time.Since(start)) + for _, u := range updatedDevices { + for i := range devices { + if u.Library == devices[i].Library && u.ID == devices[i].ID { + updated[i] = true + devices[i].FreeMemory = u.FreeMemory + break + } + } + } + // Short circuit if we've updated all the devices + if allDone() { + break + } + } + if !allDone() { + slog.Debug("unable to refresh all GPUs with existing runners, performing bootstrap discovery") + + // Bootstrapping may take longer in some cases (AMD windows), but we + // would rather use stale free data to get the model running sooner + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + for dir := range libDirs { + updatedDevices := bootstrapDevices(ctx, []string{LibOllamaPath, dir}, nil) + for _, u := range updatedDevices { + for i := range devices { + if u.Library == devices[i].Library && u.ID == devices[i].ID { + updated[i] = true + devices[i].FreeMemory = u.FreeMemory + break + } + } + // TODO - consider evaluating if new devices have appeared (e.g. hotplug) + } + if allDone() { + break + } + } + if !allDone() { + slog.Warn("unable to refresh free memory, using old values") + } + } + } + + return devices +} + +func filterOverlapByLibrary(supported map[string]map[string]map[string]int, needsDelete []bool) { + // For multi-GPU systems, use the newest version that supports all the GPUs + for _, byLibDirs := range supported { + libDirs := make([]string, 0, len(byLibDirs)) + for libDir := range byLibDirs { + libDirs = append(libDirs, libDir) + } + sort.Sort(sort.Reverse(sort.StringSlice(libDirs))) + anyMissing := false + var newest string + for _, newest = range libDirs { + for _, libDir := range libDirs { + if libDir == newest { + continue + } + if len(byLibDirs[newest]) != len(byLibDirs[libDir]) { + anyMissing = true + break + } + for dev := range byLibDirs[newest] { + if _, found := byLibDirs[libDir][dev]; !found { + anyMissing = true + break + } + } + } + if !anyMissing { + break + } + } + // Now we can mark overlaps for deletion + for _, libDir := range libDirs { + if libDir == newest { + continue + } + for dev, i := range byLibDirs[libDir] { + if _, found := byLibDirs[newest][dev]; found { + needsDelete[i] = true + } + } + } + } +} + +type bootstrapRunner struct { + port int + cmd *exec.Cmd +} + +func (r *bootstrapRunner) GetPort() int { + return r.port +} + +func (r *bootstrapRunner) HasExited() bool { + if r.cmd != nil && r.cmd.ProcessState != nil { + return true + } + return false +} + +func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []string) []ml.DeviceInfo { + // TODO DRY out with llm/server.go + slog.Debug("spawing runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs) + start := time.Now() + defer func() { + slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs) + }() + port := 0 + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + port = l.Addr().(*net.TCPAddr).Port + l.Close() + } + } + if port == 0 { + slog.Debug("ResolveTCPAddr failed, using random port") + port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range + } + params := []string{"runner", "--ollama-engine", "--port", 
strconv.Itoa(port)}
+    var pathEnv string
+    switch runtime.GOOS {
+    case "windows":
+        pathEnv = "PATH"
+    case "darwin":
+        pathEnv = "DYLD_LIBRARY_PATH"
+    default:
+        pathEnv = "LD_LIBRARY_PATH"
+    }
+    libraryPaths := append([]string{LibOllamaPath}, ollamaLibDirs...)
+    if rocmDir != "" {
+        libraryPaths = append(libraryPaths, rocmDir)
+    }
+    // Note: we always put our dependency paths first
+    // since these are the exact version we compiled/linked against
+    if libraryPath, ok := os.LookupEnv(pathEnv); ok {
+        libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
+    }
+
+    cmd := exec.Command(exe, params...)
+    cmd.Env = os.Environ()
+    cmd.Stdout = os.Stdout
+    errBuf := &bytes.Buffer{}
+    if envconfig.LogLevel() == slog.Level(-8) {
+        cmd.Stderr = os.Stderr
+    } else {
+        cmd.Stderr = errBuf
+    }
+    // cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored
+    cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
+    pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
+    pathNeeded := true
+    extraDone := make([]bool, len(extraEnvs))
+    for i := range cmd.Env {
+        cmp := strings.SplitN(cmd.Env[i], "=", 2)
+        if strings.EqualFold(cmp[0], pathEnv) {
+            cmd.Env[i] = pathEnv + "=" + pathEnvVal
+            pathNeeded = false
+        } else {
+            for j := range extraEnvs {
+                if extraDone[j] {
+                    continue
+                }
+                extra := strings.SplitN(extraEnvs[j], "=", 2)
+                if cmp[0] == extra[0] {
+                    cmd.Env[i] = extraEnvs[j]
+                    extraDone[j] = true // mark extraEnvs[j] as consumed (indexed by j, not the env slot i)
+                }
+            }
+        }
+    }
+    if pathNeeded {
+        cmd.Env = append(cmd.Env, pathEnv+"="+pathEnvVal)
+    }
+    for i := range extraDone {
+        if !extraDone[i] {
+            cmd.Env = append(cmd.Env, extraEnvs[i])
+        }
+    }
+    slog.Log(context.TODO(), logutil.LevelTrace, "starting runner for device discovery", "env", cmd.Env, "cmd", cmd)
+    if err := cmd.Start(); err != nil {
+        slog.Warn("unable to start discovery subprocess", "cmd", cmd, "error", err)
+        return nil
+    }
+    go func() {
+        cmd.Wait() // exit status ignored
+    }()
+
+    defer cmd.Process.Kill()
+    devices, err := GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd})
+    if err != nil {
+        if cmd.ProcessState != nil && cmd.ProcessState.ExitCode() >= 0 {
+            // Expected during bootstrapping while we filter out unsupported AMD GPUs
+            slog.Log(context.TODO(), logutil.LevelTrace, "runner exited", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "code", cmd.ProcessState.ExitCode())
+        } else {
+            slog.Info("failure during GPU discovery", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "error", err)
+        }
+    }
+    logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices)
+    return devices
+}
+
+func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]ml.DeviceInfo, error) {
+    var moreDevices []ml.DeviceInfo
+    port := runner.GetPort()
+    tick := time.Tick(500 * time.Millisecond)
+    for {
+        select {
+        case <-ctx.Done():
+            return nil, fmt.Errorf("failed to finish discovery before timeout")
+        case <-tick:
+            r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
+            if err != nil {
+                return nil, fmt.Errorf("failed to create request: %w", err)
+            }
+            r.Header.Set("Content-Type", "application/json")
+
+            resp, err := http.DefaultClient.Do(r)
+            if err != nil {
+                // slog.Warn("failed to send request", "error", err)
+                if runner.HasExited() {
+                    return nil, fmt.Errorf("runner crashed")
+                }
+                continue
+            }
+            defer resp.Body.Close()
+
+            if 
resp.StatusCode == http.StatusNotFound { + // old runner, fall back to bootstrapping model + return nil, fmt.Errorf("llamarunner free vram reporting not supported") + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + slog.Warn("failed to read response", "error", err) + continue + } + if resp.StatusCode != 200 { + logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body) + continue + } + + if err := json.Unmarshal(body, &moreDevices); err != nil { + slog.Warn("unmarshal encode response", "error", err) + continue + } + return moreDevices, nil + } + } +} diff --git a/discover/types.go b/discover/types.go index 1027aaac2..feb8c08e0 100644 --- a/discover/types.go +++ b/discover/types.go @@ -1,10 +1,13 @@ package discover import ( - "fmt" + "context" "log/slog" + "path/filepath" + "strings" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/ml" ) type memInfo struct { @@ -27,9 +30,6 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? // Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly DependencyPath []string `json:"lib_path,omitempty"` - // Extra environment variables specific to the GPU as list of [key=value] - EnvWorkarounds []string `json:"envs,omitempty"` - // Set to true if we can NOT reliably discover FreeMemory. A value of true indicates // the FreeMemory is best effort, and may over or under report actual memory usage // False indicates FreeMemory can generally be trusted on this GPU @@ -37,7 +37,7 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? // GPU information ID string `json:"gpu_id"` // string to use for selection of this specific GPU - filterID int //nolint:unused,nolintlint // AMD Workaround: The numeric ID of the device used to filter out other devices + filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices Name string `json:"name"` // user friendly name if available Compute string `json:"compute"` // Compute Capability or gfx @@ -70,37 +70,8 @@ type CPU struct { ThreadCount int } -type CudaGPUInfo struct { - GpuInfo - OSOverhead uint64 // Memory overhead between the driver library and management library - index int //nolint:unused,nolintlint - computeMajor int //nolint:unused,nolintlint - computeMinor int //nolint:unused,nolintlint -} -type CudaGPUInfoList []CudaGPUInfo - -type RocmGPUInfo struct { - GpuInfo - usedFilepath string //nolint:unused,nolintlint - index int //nolint:unused,nolintlint -} -type RocmGPUInfoList []RocmGPUInfo - -type OneapiGPUInfo struct { - GpuInfo - driverIndex int //nolint:unused,nolintlint - gpuIndex int //nolint:unused,nolintlint -} -type OneapiGPUInfoList []OneapiGPUInfo - type GpuInfoList []GpuInfo -type UnsupportedGPUInfo struct { - GpuInfo - Reason string `json:"reason"` -} - -// Split up the set of gpu info's by Library and variant func (l GpuInfoList) ByLibrary() []GpuInfoList { resp := []GpuInfoList{} libs := []string{} @@ -125,18 +96,48 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList { return resp } -// Report the GPU information into the log an Info level -func (l GpuInfoList) LogDetails() { - for _, g := range l { +func LogDetails(devices []ml.DeviceInfo) { + for _, dev := range devices { + var libs []string + for _, dir := range dev.LibraryPath { + if strings.Contains(dir, filepath.Join("lib", "ollama")) { + libs = append(libs, filepath.Base(dir)) + } + } + typeStr := "discrete" + if dev.Integrated { + typeStr = "iGPU" + } slog.Info("inference 
compute", - "id", g.ID, - "library", g.Library, - "variant", g.Variant, - "compute", g.Compute, - "driver", fmt.Sprintf("%d.%d", g.DriverMajor, g.DriverMinor), - "name", g.Name, - "total", format.HumanBytes2(g.TotalMemory), - "available", format.HumanBytes2(g.FreeMemory), + "id", dev.ID, + "library", dev.Library, + "compute", dev.Compute(), + "name", dev.Name, + "description", dev.Description, + "libdirs", strings.Join(libs, ","), + "driver", dev.Driver(), + "pci_id", dev.PCIID, + "type", typeStr, + "total", format.HumanBytes2(dev.TotalMemory), + "available", format.HumanBytes2(dev.FreeMemory), + ) + } + // CPU inference + if len(devices) == 0 { + dev, _ := GetCPUMem() + // TODO more details about CPU + slog.Info("inference compute", + "id", "cpu", + "library", "cpu", + "compute", "", + "name", "cpu", + "description", "cpu", + "libdirs", "ollama", + "driver", "", + "pci_id", "", + "type", "", + "total", format.HumanBytes2(dev.TotalMemory), + "available", format.HumanBytes2(dev.FreeMemory), ) } } @@ -149,10 +150,8 @@ func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory } type SystemInfo struct { - System CPUInfo `json:"system"` - GPUs []GpuInfo `json:"gpus"` - UnsupportedGPUs []UnsupportedGPUInfo `json:"unsupported_gpus"` - DiscoveryErrors []string `json:"discovery_errors"` + System CPUInfo `json:"system"` + GPUs []GpuInfo `json:"gpus"` } // Return the optimal number of threads to use for inference @@ -173,9 +172,9 @@ func (si SystemInfo) GetOptimalThreadCount() int { func (l GpuInfoList) FlashAttentionSupported() bool { for _, gpu := range l { supportsFA := gpu.Library == "cpu" || - gpu.Library == "metal" || - (gpu.Library == "cuda" && gpu.DriverMajor >= 7) || - gpu.Library == "rocm" + gpu.Name == "Metal" || + (gpu.Library == "CUDA" && gpu.DriverMajor >= 7) || + gpu.Library == "HIP" if !supportsFA { return false @@ -183,3 +182,31 @@ func (l GpuInfoList) FlashAttentionSupported() bool { } return true } + +type BaseRunner interface { + // GetPort returns the localhost port number the runner is running on + GetPort() int + + // HasExited indicates if the runner is no longer running. This can be used during + // bootstrap to detect if a given filtered device is incompatible and triggered an assert + HasExited() bool +} + +type RunnerDiscovery interface { + BaseRunner + + // GetDeviceInfos will perform a query of the underlying device libraries + // for device identification and free VRAM information + // During bootstrap scenarios, this routine may take seconds to complete + GetDeviceInfos(ctx context.Context) []ml.DeviceInfo +} + +type FilteredRunnerDiscovery interface { + RunnerDiscovery + + // GetActiveDeviceIDs returns the filtered set of devices actively in + // use by this runner for running models. If the runner is a bootstrap runner, no devices + // will be active yet so no device IDs are returned. + // This routine will not query the underlying device and will return immediately + GetActiveDeviceIDs() []ml.DeviceID +} diff --git a/docs/gpu.md b/docs/gpu.md index 464788ccb..ec5c9ccc3 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -65,6 +65,9 @@ With ROCm v6.1, the following GPUs are supported on Windows. 
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` | | AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` | +### Known Workarounds + +- The RX Vega 56 requires `HSA_ENABLE_SDMA=0` to disable SDMA ### Overrides on Linux Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In diff --git a/llama/patches/0026-GPU-discovery-enhancements.patch b/llama/patches/0026-GPU-discovery-enhancements.patch new file mode 100644 index 000000000..96d7ba20e --- /dev/null +++ b/llama/patches/0026-GPU-discovery-enhancements.patch @@ -0,0 +1,860 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen +Date: Tue, 26 Aug 2025 12:48:29 -0700 +Subject: [PATCH] GPU discovery enhancements + +Expose more information about the devices through backend props, and leverage +management libraries for more accurate VRAM usage reporting if available. +--- + ggml/include/ggml-backend.h | 9 + + ggml/src/CMakeLists.txt | 2 + + ggml/src/ggml-cuda/ggml-cuda.cu | 79 +++++- + ggml/src/ggml-cuda/vendors/hip.h | 1 + + ggml/src/ggml-impl.h | 8 + + ggml/src/mem_hip.cpp | 449 +++++++++++++++++++++++++++++++ + ggml/src/mem_nvml.cpp | 172 ++++++++++++ + 7 files changed, 719 insertions(+), 1 deletion(-) + create mode 100644 ggml/src/mem_hip.cpp + create mode 100644 ggml/src/mem_nvml.cpp + +diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h +index fda5ceb24..7c2d86703 100644 +--- a/ggml/include/ggml-backend.h ++++ b/ggml/include/ggml-backend.h +@@ -158,6 +158,15 @@ extern "C" { + size_t memory_total; + enum ggml_backend_dev_type type; + struct ggml_backend_dev_caps caps; ++ int driver_major; ++ int driver_minor; ++ int compute_major; ++ int compute_minor; ++ int integrated; ++ int pci_bus_id; ++ int pci_device_id; ++ int pci_domain_id; ++ const char *library; + }; + + GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); +diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt +index 5158acd6a..3a428a22d 100644 +--- a/ggml/src/CMakeLists.txt ++++ b/ggml/src/CMakeLists.txt +@@ -203,6 +203,8 @@ add_library(ggml-base + ggml-threading.h + ggml-quants.c + ggml-quants.h ++ mem_hip.cpp ++ mem_nvml.cpp + gguf.cpp) + + target_include_directories(ggml-base PRIVATE .) 
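[Editor's note] The ggml_backend_dev_props fields added above (owning library, driver and compute versions, PCI location, integrated flag) are what the Go discovery layer ultimately consumes when a runner reports its devices over the /info endpoint. A hedged sketch of a Go-side mirror of those fields — the struct and JSON tags below are illustrative, not the actual ml.DeviceInfo definition, and the PCI string is just one conventional domain:bus:device rendering:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// deviceProps is an illustrative mirror of the extra backend properties.
type deviceProps struct {
	Library      string `json:"library"` // "CUDA" or "HIP"
	ComputeMajor int    `json:"compute_major"`
	ComputeMinor int    `json:"compute_minor"`
	DriverMajor  int    `json:"driver_major"`
	DriverMinor  int    `json:"driver_minor"`
	Integrated   bool   `json:"integrated"`
	PCIDomain    int    `json:"pci_domain_id"`
	PCIBus       int    `json:"pci_bus_id"`
	PCIDevice    int    `json:"pci_device_id"`
}

// pciID renders a conventional domain:bus:device identifier, e.g. "0000:65:00".
func pciID(p deviceProps) string {
	return fmt.Sprintf("%04x:%02x:%02x", p.PCIDomain, p.PCIBus, p.PCIDevice)
}

func main() {
	p := deviceProps{Library: "CUDA", ComputeMajor: 8, ComputeMinor: 9,
		DriverMajor: 12, DriverMinor: 4, PCIBus: 0x65}
	b, _ := json.Marshal(p)
	fmt.Println(string(b))
	fmt.Println(pciID(p))
}
```

Carrying the owning library alongside the PCI location is what lets the discovery code recognize the same physical GPU when it is enumerated by two different backends and keep only one of them.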
+diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu +index e43fde523..352dae85d 100644 +--- a/ggml/src/ggml-cuda/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda/ggml-cuda.cu +@@ -279,6 +279,16 @@ static ggml_cuda_device_info ggml_cuda_init() { + for (int id = 0; id < info.device_count; ++id) { + int device_vmm = 0; + ++#if defined(GGML_USE_HIP) ++ if (std::getenv("GGML_CUDA_INIT") != NULL) { ++ GGML_LOG_INFO("%s: initializing rocBLAS on device %d\n", __func__, id); ++ CUDA_CHECK(cudaSetDevice(id)); ++ // rocblas_initialize will SIGABRT if the GPU isn't supported ++ rocblas_initialize(); ++ GGML_LOG_INFO("%s: rocBLAS initialized on device %d\n", __func__, id); ++ } ++#endif ++ + #if defined(GGML_USE_VMM) + CUdevice device; + CU_CHECK(cuDeviceGet(&device, id)); +@@ -332,9 +342,15 @@ static ggml_cuda_device_info ggml_cuda_init() { + #else + info.devices[id].smpbo = prop.sharedMemPerBlockOptin; + info.devices[id].cc = 100*prop.major + 10*prop.minor; ++#ifdef __CUDA_ARCH_LIST__ ++ if (std::getenv("GGML_CUDA_INIT") != NULL) { ++ GGML_ASSERT(ggml_cuda_has_arch(info.devices[id].cc) && "ggml was not compiled with support for this arch"); ++ } ++#endif // defined(__CUDA_ARCH_LIST__) + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", + ggml_cuda_parse_uuid(prop, id).c_str()); ++ + #endif // defined(GGML_USE_HIP) + } + +@@ -3215,6 +3231,14 @@ struct ggml_backend_cuda_device_context { + std::string name; + std::string description; + std::string id; ++ int major; ++ int minor; ++ int driver_major; ++ int driver_minor; ++ int integrated; ++ int pci_bus_id; ++ int pci_device_id; ++ int pci_domain_id; + }; + + static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { +@@ -3235,6 +3259,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { + static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + ggml_cuda_set_device(ctx->device); ++ ++#if defined(GGML_USE_HIP) ++ if (ggml_hip_mgmt_init() == 0) { ++ int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_hip_mgmt_release(); ++ return; ++ } ++ ggml_hip_mgmt_release(); ++ } ++#else ++ if (ggml_nvml_init() == 0) { ++ int status = ggml_nvml_get_device_memory(ctx->id.c_str(), free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_nvml_release(); ++ return; ++ } ++ ggml_nvml_release(); ++ } ++#endif + CUDA_CHECK(cudaMemGetInfo(free, total)); + } + +@@ -3243,6 +3289,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend + return GGML_BACKEND_DEVICE_TYPE_GPU; + } + ++#define GGML_HIP_NAME "HIP" + static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_cuda_device_get_name(dev); + props->description = ggml_backend_cuda_device_get_description(dev); +@@ -3253,6 +3300,27 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back + // If you need the memory data, call ggml_backend_dev_memory() explicitly. 
+ props->memory_total = props->memory_free = 0; + ++ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ++#if defined(GGML_USE_HIP) ++ int cc = ggml_cuda_info().devices[ctx->device].cc - GGML_CUDA_CC_OFFSET_AMD; ++ props->compute_major = cc / 0x100; ++ props->compute_minor = cc - (props->compute_major * 0x100); ++#else ++ props->compute_major = ctx->major; ++ props->compute_minor = ctx->minor; ++#endif ++ props->driver_major = ctx->driver_major; ++ props->driver_minor = ctx->driver_minor; ++ props->integrated = ctx->integrated; ++ props->pci_bus_id = ctx->pci_bus_id; ++ props->pci_device_id = ctx->pci_device_id; ++ props->pci_domain_id = ctx->pci_domain_id; ++#if defined(GGML_USE_HIP) ++ props->library = GGML_HIP_NAME; ++#else ++ props->library = GGML_CUDA_NAME; ++#endif ++ + bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; + #ifdef GGML_CUDA_NO_PEER_COPY + bool events = false; +@@ -3843,6 +3911,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { + std::lock_guard lock(mutex); + if (!initialized) { + ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; ++ int driverVersion = 0; ++ CUDA_CHECK(cudaDriverGetVersion(&driverVersion)); + + for (int i = 0; i < ggml_cuda_info().device_count; i++) { + ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; +@@ -3853,7 +3923,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { + CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); + dev_ctx->description = prop.name; + dev_ctx->id = ggml_cuda_parse_uuid(prop, i); +- ++ dev_ctx->major = prop.major; ++ dev_ctx->minor = prop.minor; ++ dev_ctx->driver_major = driverVersion / 1000; ++ dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10; ++ dev_ctx->integrated = prop.integrated; ++ dev_ctx->pci_bus_id = prop.pciBusID; ++ dev_ctx->pci_device_id = prop.pciDeviceID; ++ dev_ctx->pci_domain_id = prop.pciDomainID; + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_cuda_device_interface, + /* .reg = */ ®, +diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h +index cf22e60d2..957a795f2 100644 +--- a/ggml/src/ggml-cuda/vendors/hip.h ++++ b/ggml/src/ggml-cuda/vendors/hip.h +@@ -42,6 +42,7 @@ + #define cudaDeviceProp hipDeviceProp_t + #define cudaDeviceReset hipDeviceReset + #define cudaDeviceSynchronize hipDeviceSynchronize ++#define cudaDriverGetVersion hipDriverGetVersion + #define cudaError_t hipError_t + #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled + #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled +diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h +index 19a7adb2d..b9b102a5e 100644 +--- a/ggml/src/ggml-impl.h ++++ b/ggml/src/ggml-impl.h +@@ -602,6 +602,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx + return true; + } + ++// Management libraries for fetching more accurate free VRAM data ++GGML_API int ggml_nvml_init(); ++GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total); ++GGML_API void ggml_nvml_release(); ++GGML_API int ggml_hip_mgmt_init(); ++GGML_API int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total); ++GGML_API void ggml_hip_mgmt_release(); ++ + #ifdef __cplusplus + } + #endif +diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp +new file mode 100644 +index 000000000..8ef19b8cf +--- /dev/null ++++ b/ggml/src/mem_hip.cpp +@@ -0,0 +1,449 @@ ++#include 
"ggml.h" ++ ++#ifdef _WIN32 ++// AMD Device Library eXtra (ADLX) ++// ++// https://github.com/GPUOpen-LibrariesAndSDKs/ADLX ++// ++// This Windows-only library provides accurate VRAM reporting for AMD GPUs. ++// The runtime DLL is installed with every AMD Driver on Windows, however ++// the SDK isn't a part of the HIP SDK packaging. As such, we avoid including ++// the headers from the SDK to simplify building from source. ++// ++// ADLX relies heavily on function pointer tables. ++// Only the minimal set of types are defined below to facilitate ++// finding the target AMD GPU(s) and querying their current VRAM usage ++// Unused function parameters are commented out to avoid unnecessary type ++// definitions. ++ ++#include "ggml-impl.h" ++#include ++#include ++ ++#define WIN32_LEAN_AND_MEAN ++#ifndef NOMINMAX ++# define NOMINMAX ++#endif ++#include ++ ++namespace fs = std::filesystem; ++ ++#include ++#include ++ ++// Begin minimal ADLX definitions - derived from tag v1.0 (Dec 2022) ++typedef uint64_t adlx_uint64; ++typedef uint32_t adlx_uint32; ++typedef int32_t adlx_int32; ++typedef adlx_int32 adlx_int; ++typedef adlx_uint32 adlx_uint; ++typedef long adlx_long; ++typedef uint8_t adlx_uint8; ++typedef enum ++{ ++ ADLX_OK = 0, /**< @ENG_START_DOX This result indicates success. @ENG_END_DOX */ ++ ADLX_ALREADY_ENABLED, /**< @ENG_START_DOX This result indicates that the asked action is already enabled. @ENG_END_DOX */ ++ ADLX_ALREADY_INITIALIZED, /**< @ENG_START_DOX This result indicates that ADLX has a unspecified type of initialization. @ENG_END_DOX */ ++ ADLX_FAIL, /**< @ENG_START_DOX This result indicates an unspecified failure. @ENG_END_DOX */ ++ ADLX_INVALID_ARGS, /**< @ENG_START_DOX This result indicates that the arguments are invalid. @ENG_END_DOX */ ++ ADLX_BAD_VER, /**< @ENG_START_DOX This result indicates that the asked version is incompatible with the current version. @ENG_END_DOX */ ++ ADLX_UNKNOWN_INTERFACE, /**< @ENG_START_DOX This result indicates that an unknown interface was asked. @ENG_END_DOX */ ++ ADLX_TERMINATED, /**< @ENG_START_DOX This result indicates that the calls were made in an interface after ADLX was terminated. @ENG_END_DOX */ ++ ADLX_ADL_INIT_ERROR, /**< @ENG_START_DOX This result indicates that the ADL initialization failed. @ENG_END_DOX */ ++ ADLX_NOT_FOUND, /**< @ENG_START_DOX This result indicates that the item is not found. @ENG_END_DOX */ ++ ADLX_INVALID_OBJECT, /**< @ENG_START_DOX This result indicates that the method was called into an invalid object. @ENG_END_DOX */ ++ ADLX_ORPHAN_OBJECTS, /**< @ENG_START_DOX This result indicates that ADLX was terminated with outstanding ADLX objects. Any interface obtained from ADLX points to invalid memory and calls in their methods will result in unexpected behavior. @ENG_END_DOX */ ++ ADLX_NOT_SUPPORTED, /**< @ENG_START_DOX This result indicates that the asked feature is not supported. @ENG_END_DOX */ ++ ADLX_PENDING_OPERATION, /**< @ENG_START_DOX This result indicates a failure due to an operation currently in progress. @ENG_END_DOX */ ++ ADLX_GPU_INACTIVE /**< @ENG_START_DOX This result indicates that the GPU is inactive. 
@ENG_END_DOX */ ++} ADLX_RESULT; ++#define ADLX_SUCCEEDED(x) (ADLX_OK == (x) || ADLX_ALREADY_ENABLED == (x) || ADLX_ALREADY_INITIALIZED == (x)) ++#define ADLX_FAILED(x) (ADLX_OK != (x) && ADLX_ALREADY_ENABLED != (x) && ADLX_ALREADY_INITIALIZED != (x)) ++#define ADLX_VER_MAJOR 1 ++#define ADLX_VER_MINOR 0 ++#define ADLX_VER_RELEASE 5 ++#define ADLX_VER_BUILD_NUM 30 ++#define ADLX_MAKE_FULL_VER(VERSION_MAJOR, VERSION_MINOR, VERSION_RELEASE, VERSION_BUILD_NUM) ( ((adlx_uint64)(VERSION_MAJOR) << 48ull) | ((adlx_uint64)(VERSION_MINOR) << 32ull) | ((adlx_uint64)(VERSION_RELEASE) << 16ull) | (adlx_uint64)(VERSION_BUILD_NUM)) ++#define ADLX_FULL_VERSION ADLX_MAKE_FULL_VER(ADLX_VER_MAJOR, ADLX_VER_MINOR, ADLX_VER_RELEASE, ADLX_VER_BUILD_NUM) ++#define ADLX_CORE_LINK __declspec(dllexport) ++#define ADLX_STD_CALL __stdcall ++#define ADLX_CDECL_CALL __cdecl ++#define ADLX_FAST_CALL __fastcall ++#define ADLX_INLINE __inline ++#define ADLX_FORCEINLINE __forceinline ++#define ADLX_NO_VTABLE __declspec(novtable) ++ ++#if defined(__cplusplus) ++typedef bool adlx_bool; ++#else ++typedef adlx_uint8 adlx_bool; ++#define true 1 ++#define false 0 ++#endif ++ ++typedef struct IADLXSystem IADLXSystem; ++typedef struct IADLXGPUList IADLXGPUList; ++typedef struct IADLXGPU IADLXGPU; ++typedef struct IADLXInterface IADLXInterface; ++typedef struct IADLXPerformanceMonitoringServices IADLXPerformanceMonitoringServices; ++typedef struct IADLXGPUMetrics IADLXGPUMetrics; ++typedef struct IADLXGPUMetricsSupport IADLXGPUMetricsSupport; ++ ++typedef struct IADLXSystemVtbl ++{ ++ // IADLXSystem interface ++ ADLX_RESULT (ADLX_STD_CALL *GetHybridGraphicsType)(/* IADLXSystem* pThis, ADLX_HG_TYPE* hgType */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUs)(IADLXSystem* pThis, IADLXGPUList** ppGPUs); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXSystem* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ADLX_RESULT (ADLX_STD_CALL *GetDisplaysServices)(/* IADLXSystem* pThis, IADLXDisplayServices** ppDispServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetDesktopsServices)(/* IADLXSystem* pThis, IADLXDesktopServices** ppDeskServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUsChangedHandling)(/* IADLXSystem* pThis, IADLXGPUsChangedHandling** ppGPUsChangedHandling */); ++ ADLX_RESULT (ADLX_STD_CALL *EnableLog)(/* IADLXSystem* pThis, ADLX_LOG_DESTINATION mode, ADLX_LOG_SEVERITY severity, IADLXLog* pLogger, const wchar_t* fileName */); ++ ADLX_RESULT (ADLX_STD_CALL *Get3DSettingsServices)(/* IADLXSystem* pThis, IADLX3DSettingsServices** pp3DSettingsServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUTuningServices)(/* IADLXSystem* pThis, IADLXGPUTuningServices** ppGPUTuningServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetPerformanceMonitoringServices)(IADLXSystem* pThis, IADLXPerformanceMonitoringServices** ppPerformanceMonitoringServices); // Used ++ ADLX_RESULT (ADLX_STD_CALL *TotalSystemRAM)(/* IADLXSystem* pThis, adlx_uint* ramMB */); ++ ADLX_RESULT (ADLX_STD_CALL *GetI2C)(/* IADLXSystem* pThis, IADLXGPU* pGPU, IADLXI2C** ppI2C */); ++} IADLXSystemVtbl; ++struct IADLXSystem { const IADLXSystemVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPU* pThis */); ++ adlx_long (ADLX_STD_CALL *Release)(IADLXGPU* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPU* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXGPU ++ ADLX_RESULT (ADLX_STD_CALL *VendorId)(/* IADLXGPU* pThis, const char** vendorId */); ++ 
ADLX_RESULT (ADLX_STD_CALL *ASICFamilyType)(/* IADLXGPU* pThis, ADLX_ASIC_FAMILY_TYPE* asicFamilyType */); ++ ADLX_RESULT (ADLX_STD_CALL *Type)(/* IADLXGPU* pThis, ADLX_GPU_TYPE* gpuType */); ++ ADLX_RESULT (ADLX_STD_CALL *IsExternal)(/* IADLXGPU* pThis, adlx_bool* isExternal */); ++ ADLX_RESULT (ADLX_STD_CALL *Name)(/* IADLXGPU* pThis, const char** gpuName */); ++ ADLX_RESULT (ADLX_STD_CALL *DriverPath)(/* IADLXGPU* pThis, const char** driverPath */); ++ ADLX_RESULT (ADLX_STD_CALL *PNPString)(/* IADLXGPU* pThis, const char** pnpString */); ++ ADLX_RESULT (ADLX_STD_CALL *HasDesktops)(/* IADLXGPU* pThis, adlx_bool* hasDesktops */); ++ ADLX_RESULT (ADLX_STD_CALL *TotalVRAM)(IADLXGPU* pThis, adlx_uint* vramMB); // Used ++ ADLX_RESULT (ADLX_STD_CALL *VRAMType)(/* IADLXGPU* pThis, const char** type */); ++ ADLX_RESULT (ADLX_STD_CALL *BIOSInfo)(/* IADLXGPU* pThis, const char** partNumber, const char** version, const char** date */); ++ ADLX_RESULT (ADLX_STD_CALL *DeviceId)(/* IADLXGPU* pThis, const char** deviceId */); ++ ADLX_RESULT (ADLX_STD_CALL *RevisionId)(/* IADLXGPU* pThis, const char** revisionId */); ++ ADLX_RESULT (ADLX_STD_CALL *SubSystemId)(/* IADLXGPU* pThis, const char** subSystemId */); ++ ADLX_RESULT (ADLX_STD_CALL *SubSystemVendorId)(/* IADLXGPU* pThis, const char** subSystemVendorId */); ++ ADLX_RESULT (ADLX_STD_CALL *UniqueId)(IADLXGPU* pThis, adlx_int* uniqueId); // Used ++} IADLXGPUVtbl; ++struct IADLXGPU { const IADLXGPUVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUListVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPUList* pThis */); ++ adlx_long (ADLX_STD_CALL *Release)(IADLXGPUList* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPUList* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXList ++ adlx_uint (ADLX_STD_CALL *Size)(/* IADLXGPUList* pThis */); ++ adlx_uint8 (ADLX_STD_CALL *Empty)(/* IADLXGPUList* pThis */); ++ adlx_uint (ADLX_STD_CALL *Begin)(IADLXGPUList* pThis); // Used ++ adlx_uint (ADLX_STD_CALL *End)(IADLXGPUList* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *At)(/* IADLXGPUList* pThis, const adlx_uint location, IADLXInterface** ppItem */); ++ ADLX_RESULT (ADLX_STD_CALL *Clear)(/* IADLXGPUList* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *Remove_Back)(/* IADLXGPUList* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *Add_Back)(/* IADLXGPUList* pThis, IADLXInterface* pItem */); ++ ++ //IADLXGPUList ++ ADLX_RESULT (ADLX_STD_CALL *At_GPUList)(IADLXGPUList* pThis, const adlx_uint location, IADLXGPU** ppItem); // Used ++ ADLX_RESULT (ADLX_STD_CALL *Add_Back_GPUList)(/* IADLXGPUList* pThis, IADLXGPU* pItem */); ++ ++} IADLXGPUListVtbl; ++struct IADLXGPUList { const IADLXGPUListVtbl *pVtbl; }; ++ ++typedef struct IADLXPerformanceMonitoringServicesVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXPerformanceMonitoringServices* pThis */); ++ adlx_long (ADLX_STD_CALL *Release)(IADLXPerformanceMonitoringServices* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXPerformanceMonitoringServices* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXPerformanceMonitoringServices ++ ADLX_RESULT (ADLX_STD_CALL *GetSamplingIntervalRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); ++ ADLX_RESULT (ADLX_STD_CALL *SetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int intervalMs */); ++ ADLX_RESULT (ADLX_STD_CALL *GetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* intervalMs */); 
++ ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySizeRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); ++ ADLX_RESULT (ADLX_STD_CALL *SetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int sizeSec */); ++ ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); ++ ADLX_RESULT (ADLX_STD_CALL *ClearPerformanceMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); ++ ADLX_RESULT (ADLX_STD_CALL *StartPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *StopPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *GetAllMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXAllMetricsList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, adlx_int startMs, adlx_int stopMs, IADLXGPUMetricsList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetSystemMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXSystemMetricsList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetFPSHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXFPSList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentAllMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXAllMetrics** ppMetrics */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetrics** ppMetrics); // Used ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetrics** ppMetrics */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentFPS)(/* IADLXPerformanceMonitoringServices* pThis, IADLXFPS** ppMetrics */); ++ ADLX_RESULT (ADLX_STD_CALL *GetSupportedGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetricsSupport** ppMetricsSupported); // Used ++ ADLX_RESULT (ADLX_STD_CALL *GetSupportedSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetricsSupport** ppMetricsSupported */); ++}IADLXPerformanceMonitoringServicesVtbl; ++struct IADLXPerformanceMonitoringServices { const IADLXPerformanceMonitoringServicesVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUMetricsSupportVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetricsSupport* pThis */); ++ adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetricsSupport* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetricsSupport* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXGPUMetricsSupport ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUUsage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAMClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUHotspotTemperature)(/* 
IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTotalBoardPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUFanSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAM)(IADLXGPUMetricsSupport* pThis, adlx_bool* supported); // Used ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVoltage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUUsageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUHotspotTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUFanSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUVoltageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUTotalBoardPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++} IADLXGPUMetricsSupportVtbl; ++struct IADLXGPUMetricsSupport { const IADLXGPUMetricsSupportVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUMetricsVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetrics* pThis */); ++ adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetrics* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetrics* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXGPUMetrics ++ ADLX_RESULT (ADLX_STD_CALL* TimeStamp)(/* IADLXGPUMetrics* pThis, adlx_int64* ms */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUUsage)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUVRAMClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUHotspotTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUTotalBoardPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUFanSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUVRAM)(IADLXGPUMetrics* pThis, adlx_int* data); // Used ++ ADLX_RESULT (ADLX_STD_CALL* GPUVoltage)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++} IADLXGPUMetricsVtbl; ++struct IADLXGPUMetrics { const IADLXGPUMetricsVtbl *pVtbl; }; ++ ++struct { 
++ void *handle; ++ ADLX_RESULT (*ADLXInitialize)(adlx_uint64 version, IADLXSystem** ppSystem); ++ ADLX_RESULT (*ADLXInitializeWithIncompatibleDriver)(adlx_uint64 version, IADLXSystem** ppSystem); ++ ADLX_RESULT (*ADLXQueryVersion)(const char** version); ++ ADLX_RESULT (*ADLXTerminate)(); ++ IADLXSystem *sys; ++} adlx { NULL, NULL, NULL, NULL, NULL, NULL }; ++static std::mutex ggml_adlx_lock; ++ ++extern "C" { ++ ++int ggml_hip_mgmt_init() { ++ std::lock_guard lock(ggml_adlx_lock); ++ if (adlx.handle != NULL) { ++ // Already initialized ++ return 0; ++ } ++ DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); ++ SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); ++ fs::path libPath = fs::path("\\Windows") / fs::path("System32") / fs::path("amdadlx64.dll"); ++ ++ adlx.handle = (void*)LoadLibraryW(libPath.wstring().c_str()); ++ if (adlx.handle == NULL) { ++ return ADLX_NOT_FOUND; ++ } ++ ++ adlx.ADLXInitialize = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitialize"); ++ adlx.ADLXInitializeWithIncompatibleDriver = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitializeWithIncompatibleDriver"); ++ adlx.ADLXTerminate = (ADLX_RESULT (*)()) GetProcAddress((HMODULE)(adlx.handle), "ADLXTerminate"); ++ adlx.ADLXQueryVersion = (ADLX_RESULT (*)(const char **version)) GetProcAddress((HMODULE)(adlx.handle), "ADLXQueryVersion"); ++ if (adlx.ADLXInitialize == NULL || adlx.ADLXInitializeWithIncompatibleDriver == NULL || adlx.ADLXTerminate == NULL) { ++ GGML_LOG_INFO("%s unable to locate required symbols in amdadlx64.dll, falling back to hip free memory reporting", __func__); ++ FreeLibrary((HMODULE)(adlx.handle)); ++ adlx.handle = NULL; ++ return ADLX_NOT_FOUND; ++ } ++ ++ SetErrorMode(old_mode); ++ ++ // Aid in troubleshooting... ++ if (adlx.ADLXQueryVersion != NULL) { ++ const char *version = NULL; ++ ADLX_RESULT status = adlx.ADLXQueryVersion(&version); ++ if (ADLX_SUCCEEDED(status)) { ++ GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version); ++ } ++ } ++ ++ ADLX_RESULT status = adlx.ADLXInitialize(ADLX_FULL_VERSION, &adlx.sys); ++ if (ADLX_FAILED(status)) { ++ // GGML_LOG_DEBUG("%s failed to initialize ADLX error=%d - attempting with incompatible driver...\n", __func__, status); ++ // Try with the incompatible driver ++ status = adlx.ADLXInitializeWithIncompatibleDriver(ADLX_FULL_VERSION, &adlx.sys); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s failed to initialize ADLX error=%d\n", __func__, status); ++ FreeLibrary((HMODULE)(adlx.handle)); ++ adlx.handle = NULL; ++ adlx.sys = NULL; ++ return status; ++ } ++ // GGML_LOG_DEBUG("%s initialized ADLX with incpomatible driver\n", __func__); ++ } ++ return ADLX_OK; ++} ++ ++void ggml_hip_mgmt_release() { ++ std::lock_guard lock(ggml_adlx_lock); ++ if (adlx.handle == NULL) { ++ // Already free ++ return; ++ } ++ ADLX_RESULT status = adlx.ADLXTerminate(); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s failed to terminate Adlx %d\n", __func__, status); ++ // Unload anyway... 
++ } ++ FreeLibrary((HMODULE)(adlx.handle)); ++ adlx.handle = NULL; ++} ++ ++#define adlx_gdm_cleanup \ ++ if (gpuMetricsSupport != NULL) gpuMetricsSupport->pVtbl->Release(gpuMetricsSupport); \ ++ if (gpuMetrics != NULL) gpuMetrics->pVtbl->Release(gpuMetrics); \ ++ if (perfMonitoringServices != NULL) perfMonitoringServices->pVtbl->Release(perfMonitoringServices); \ ++ if (gpus != NULL) gpus->pVtbl->Release(gpus); \ ++ if (gpu != NULL) gpu->pVtbl->Release(gpu) ++ ++int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { ++ std::lock_guard lock(ggml_adlx_lock); ++ if (adlx.handle == NULL) { ++ GGML_LOG_INFO("%s ADLX was not initialized\n", __func__); ++ return ADLX_ADL_INIT_ERROR; ++ } ++ IADLXGPUMetricsSupport *gpuMetricsSupport = NULL; ++ IADLXPerformanceMonitoringServices *perfMonitoringServices = NULL; ++ IADLXGPUList* gpus = NULL; ++ IADLXGPU* gpu = NULL; ++ IADLXGPUMetrics *gpuMetrics = NULL; ++ ADLX_RESULT status; ++ // The "UniqueID" exposed in ADLX is the PCI Bus and Device IDs ++ adlx_int target = (pci_bus_id << 8) | (pci_device_id & 0xff); ++ ++ status = adlx.sys->pVtbl->GetPerformanceMonitoringServices(adlx.sys, &perfMonitoringServices); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetPerformanceMonitoringServices failed %d\n", __func__, status); ++ return status; ++ } ++ ++ status = adlx.sys->pVtbl->GetGPUs(adlx.sys, &gpus); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetGPUs failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ // Get GPU list ++ for (adlx_uint crt = gpus->pVtbl->Begin(gpus); crt != gpus->pVtbl->End(gpus); ++crt) ++ { ++ status = gpus->pVtbl->At_GPUList(gpus, crt, &gpu); ++ if (ADLX_FAILED(status)) ++ { ++ GGML_LOG_INFO("%s %d] At_GPUList failed %d\n", __func__, crt, status); ++ continue; ++ } ++ adlx_int id; ++ status = gpu->pVtbl->UniqueId(gpu, &id); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s %d] UniqueId lookup failed %d\n", __func__, crt, status); ++ gpu->pVtbl->Release(gpu); ++ gpu = NULL; ++ continue; ++ } ++ if (id != target) { ++ GGML_LOG_DEBUG("%s %d] GPU UniqueId: %x does not match target %02x %02x\n", __func__, crt, id, pci_bus_id, pci_device_id); ++ gpu->pVtbl->Release(gpu); ++ gpu = NULL; ++ continue; ++ } ++ // Any failures at this point should cause a fall-back to other APIs ++ status = perfMonitoringServices->pVtbl->GetSupportedGPUMetrics(perfMonitoringServices, gpu, &gpuMetricsSupport); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetSupportedGPUMetrics failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ status = perfMonitoringServices->pVtbl->GetCurrentGPUMetrics(perfMonitoringServices, gpu, &gpuMetrics); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetCurrentGPUMetrics failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ adlx_bool supported = false; ++ status = gpuMetricsSupport->pVtbl->IsSupportedGPUVRAM(gpuMetricsSupport, &supported); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s IsSupportedGPUVRAM failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ adlx_uint totalVRAM = 0; ++ status = gpu->pVtbl->TotalVRAM(gpu, &totalVRAM); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s TotalVRAM failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ adlx_int usedVRAM = 0; ++ status = gpuMetrics->pVtbl->GPUVRAM(gpuMetrics, &usedVRAM); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GPUVRAM failed %d\n", __func__, status); ++ 
adlx_gdm_cleanup; ++ return status; ++ } ++ *total = size_t(totalVRAM) * 1024 * 1024; ++ *free = size_t(totalVRAM-usedVRAM) * 1024 * 1024; ++ ++ adlx_gdm_cleanup; ++ return ADLX_OK; ++ } ++ adlx_gdm_cleanup; ++ return ADLX_NOT_FOUND; ++} ++ ++} // extern "C" ++ ++#else // #ifdef _WIN32 ++ ++extern "C" { ++ ++// TODO Linux implementation of accurate VRAM reporting ++int ggml_hip_mgmt_init() { ++ return -1; ++} ++void ggml_hip_mgmt_release() {} ++int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { ++ return -1; ++} ++ ++} // extern "C" ++ ++#endif // #ifdef _WIN32 +\ No newline at end of file +diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp +new file mode 100644 +index 000000000..aa05e9dc1 +--- /dev/null ++++ b/ggml/src/mem_nvml.cpp +@@ -0,0 +1,172 @@ ++// NVIDIA Management Library (NVML) ++// ++// https://developer.nvidia.com/management-library-nvml ++// ++// This library provides accurate VRAM reporting for NVIDIA GPUs, particularly ++// on Windows, where the cuda library provides inaccurate VRAM usage metrics. The ++// runtime DLL is installed with every driver on Windows, and most Linux ++// systems, and the headers are included in the standard CUDA SDK install. As ++// such, we can include the header here to simplify the code. ++ ++ ++#include "ggml-impl.h" ++#include ++#include ++ ++#ifdef _WIN32 ++# define WIN32_LEAN_AND_MEAN ++# ifndef NOMINMAX ++# define NOMINMAX ++# endif ++# include ++#else ++# include ++# include ++#endif ++ ++namespace fs = std::filesystem; ++ ++// Minimal definitions to avoid including the nvml.h header ++typedef enum nvmlReturn_enum ++{ ++ // cppcheck-suppress * ++ NVML_SUCCESS = 0, //!< The operation was successful ++ NVML_ERROR_UNINITIALIZED = 1, //!< NVML was not first initialized with nvmlInit() ++ NVML_ERROR_INVALID_ARGUMENT = 2, //!< A supplied argument is invalid ++ NVML_ERROR_NOT_SUPPORTED = 3, //!< The requested operation is not available on target device ++ NVML_ERROR_NO_PERMISSION = 4, //!< The current user does not have permission for operation ++ NVML_ERROR_ALREADY_INITIALIZED = 5, //!< Deprecated: Multiple initializations are now allowed through ref counting ++ NVML_ERROR_NOT_FOUND = 6, //!< A query to find an object was unsuccessful ++ NVML_ERROR_INSUFFICIENT_SIZE = 7, //!< An input argument is not large enough ++ NVML_ERROR_INSUFFICIENT_POWER = 8, //!< A device's external power cables are not properly attached ++ NVML_ERROR_DRIVER_NOT_LOADED = 9, //!< NVIDIA driver is not loaded ++ NVML_ERROR_TIMEOUT = 10, //!< User provided timeout passed ++ NVML_ERROR_IRQ_ISSUE = 11, //!< NVIDIA Kernel detected an interrupt issue with a GPU ++ NVML_ERROR_LIBRARY_NOT_FOUND = 12, //!< NVML Shared Library couldn't be found or loaded ++ NVML_ERROR_FUNCTION_NOT_FOUND = 13, //!< Local version of NVML doesn't implement this function ++ NVML_ERROR_CORRUPTED_INFOROM = 14, //!< infoROM is corrupted ++ NVML_ERROR_GPU_IS_LOST = 15, //!< The GPU has fallen off the bus or has otherwise become inaccessible ++ NVML_ERROR_RESET_REQUIRED = 16, //!< The GPU requires a reset before it can be used again ++ NVML_ERROR_OPERATING_SYSTEM = 17, //!< The GPU control device has been blocked by the operating system/cgroups ++ NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18, //!< RM detects a driver/library version mismatch ++ NVML_ERROR_IN_USE = 19, //!< An operation cannot be performed because the GPU is currently in use ++ NVML_ERROR_MEMORY = 20, //!< Insufficient memory ++ NVML_ERROR_NO_DATA = 21, //!< No data ++ 
NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, becasue ECC is enabled ++ NVML_ERROR_INSUFFICIENT_RESOURCES = 23, //!< Ran out of critical resources, other than memory ++ NVML_ERROR_FREQ_NOT_SUPPORTED = 24, //!< Ran out of critical resources, other than memory ++ NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25, //!< The provided version is invalid/unsupported ++ NVML_ERROR_DEPRECATED = 26, //!< The requested functionality has been deprecated ++ NVML_ERROR_NOT_READY = 27, //!< The system is not ready for the request ++ NVML_ERROR_GPU_NOT_FOUND = 28, //!< No GPUs were found ++ NVML_ERROR_INVALID_STATE = 29, //!< Resource not in correct state to perform requested operation ++ NVML_ERROR_UNKNOWN = 999 //!< An internal driver error occurred ++} nvmlReturn_t; ++typedef struct nvmlDevice_st* nvmlDevice_t; ++typedef struct nvmlMemory_st ++{ ++ unsigned long long total; //!< Total physical device memory (in bytes) ++ unsigned long long free; //!< Unallocated device memory (in bytes) ++ unsigned long long used; //!< Sum of Reserved and Allocated device memory (in bytes). ++ //!< Note that the driver/GPU always sets aside a small amount of memory for bookkeeping ++} nvmlMemory_t; ++// end nvml.h definitions ++ ++struct { ++ void *handle; ++ nvmlReturn_t (*nvmlInit_v2)(void); ++ nvmlReturn_t (*nvmlShutdown)(void); ++ nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); ++ nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); ++} nvml { NULL, NULL, NULL, NULL, NULL }; ++static std::mutex ggml_nvml_lock; ++ ++extern "C" { ++ ++int ggml_nvml_init() { ++ std::lock_guard lock(ggml_nvml_lock); ++ if (nvml.handle != NULL) { ++ // Already initialized ++ return 0; ++ } ++#ifdef _WIN32 ++ DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); ++ SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); ++ fs::path libPath[2]; ++ const char * programDir = std::getenv("ProgramW6432"); ++ if (programDir == NULL) { ++ libPath[0] = fs::path("Program Files") / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); ++ } else { ++ libPath[0] = fs::path(programDir) / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); ++ } ++ libPath[1] = fs::path("\\Windows") / fs::path("System32") / fs::path("NVML.dll"); ++ ++ for (int i = 0; i < 2; i++) { ++ nvml.handle = (void*)LoadLibraryW(libPath[i].wstring().c_str()); ++ if (nvml.handle != NULL) { ++ break; ++ } ++ } ++ if (nvml.handle == NULL) { ++ return NVML_ERROR_NOT_FOUND; ++ } ++ ++ nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlInit_v2"); ++ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown"); ++ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID"); ++ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo"); ++ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) { ++ GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__); ++ FreeLibrary((HMODULE)(nvml.handle)); ++ nvml.handle = NULL; ++ return NVML_ERROR_NOT_FOUND; ++ } ++ ++ SetErrorMode(old_mode); ++ ++#else ++ // Not currently wired up on Linux ++ return NVML_ERROR_NOT_SUPPORTED; ++#endif ++ int status = 
nvml.nvmlInit_v2(); ++ return NVML_SUCCESS; ++} ++ ++void ggml_nvml_release() { ++ std::lock_guard lock(ggml_nvml_lock); ++ if (nvml.handle == NULL) { ++ // Already free ++ return; ++ } ++ nvmlReturn_enum status = nvml.nvmlShutdown(); ++ if (status != NVML_SUCCESS) { ++ GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status); ++ } ++#ifdef _WIN32 ++ FreeLibrary((HMODULE)(nvml.handle)); ++ nvml.handle = NULL; ++#else ++ // Not currently wired up on Linux ++#endif ++} ++ ++int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) { ++ std::lock_guard lock(ggml_nvml_lock); ++ if (nvml.handle == NULL) { ++ return NVML_ERROR_UNINITIALIZED; ++ } ++ nvmlDevice_t device; ++ auto status = nvml.nvmlDeviceGetHandleByUUID(uuid, &device); ++ if (status != NVML_SUCCESS) { ++ return status; ++ } ++ nvmlMemory_t memInfo = {0}; ++ status = nvml.nvmlDeviceGetMemoryInfo(device, &memInfo); ++ if (status == NVML_SUCCESS) { ++ *free = memInfo.free; ++ *total = memInfo.total; ++ } ++ return status; ++} ++ ++} +\ No newline at end of file diff --git a/llm/memory.go b/llm/memory.go index 7a87b28fe..4c6003183 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -196,7 +196,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin } useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) && - discover.GetGPUInfo().FlashAttentionSupported() && + (discover.GpuInfoList)(gpus).FlashAttentionSupported() && f.SupportsFlashAttention() var kvct string diff --git a/llm/server.go b/llm/server.go index 75f049bc0..638970299 100644 --- a/llm/server.go +++ b/llm/server.go @@ -78,6 +78,8 @@ type LlamaServer interface { TotalSize() uint64 VRAMByGPU(gpuID string) uint64 Pid() int + GetDeviceInfos(ctx context.Context) []ml.DeviceInfo + HasExited() bool } // llmServer is an instance of a runner hosting a single model @@ -361,12 +363,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator))) - envWorkarounds := []string{} - for _, gpu := range gpus { - envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...) - } // Always filter down the set of GPUs in case there are any unsupported devices that might crash - envWorkarounds = append(envWorkarounds, gpus.GetVisibleDevicesEnv()...) 
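A minimal usage sketch (Go, illustrative only, not part of the diffs above) of the GetDeviceInfos and HasExited methods added to the LlamaServer interface: scheduler-side code could use them to refresh per-device VRAM reported by a live runner while skipping runners that have already exited. The helper name refreshDeviceInfo and its placement in package llm are assumptions made for the sketch.

package llm

import (
	"context"
	"log/slog"

	"github.com/ollama/ollama/ml"
)

// refreshDeviceInfo is a hypothetical helper: it returns the current device list
// reported by a runner, or nil if the runner process has already exited.
func refreshDeviceInfo(ctx context.Context, srv LlamaServer) []ml.DeviceInfo {
	if srv.HasExited() {
		// A dead runner holds no VRAM; skip the query entirely.
		return nil
	}
	devices := srv.GetDeviceInfos(ctx)
	for _, d := range devices {
		// Log the fields most relevant to placement decisions.
		slog.Debug("runner device", "id", d.ID, "library", d.Library,
			"free", d.FreeMemory, "total", d.TotalMemory)
	}
	return devices
}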
+ envWorkarounds := gpus.GetVisibleDevicesEnv() pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) // Update or add the path variable with our adjusted version @@ -524,7 +522,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi s.options.NumGPU = 0 case gpus[0].Library != "metal" && s.estimate.Layers == 0: // Don't bother loading into the GPU if no layers can fit - gpus = discover.GetCPUInfo() + gpus = discover.GpuInfoList{discover.GetCPUInfo()} case s.options.NumGPU < 0 && s.estimate.Layers > 0 && gpus[0].Library != "cpu": s.options.NumGPU = s.estimate.Layers } @@ -1312,6 +1310,30 @@ func (s *llmServer) Pid() int { return -1 } +func (s *llmServer) GetPort() int { + return s.port +} + +func (s *llmServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { + // llama engine does not currently support VRAM query, short circuit + if s.textProcessor == nil { + slog.Debug("llamarunner free vram reporting not supported") + return nil + } + devices, err := discover.GetDevicesFromRunner(ctx, s) + if err != nil { + slog.Debug("failure refreshing GPU information", "error", err) + } + return devices +} + +func (s *llmServer) HasExited() bool { + if s.cmd != nil && s.cmd.ProcessState != nil && s.cmd.ProcessState.ExitCode() >= 0 { + return true + } + return false +} + var grammarJSON = ` root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws diff --git a/ml/backend.go b/ml/backend.go index 455715b0d..43ef2c1d9 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -9,6 +9,7 @@ import ( "log/slog" "math" "slices" + "sort" "strconv" "strings" @@ -29,6 +30,9 @@ type Backend interface { Get(name string) Tensor NewContext() Context NewContextSize(size int) Context + + // Enumerate the devices available for inference via this backend + BackendDevices() []DeviceInfo } // BackendCacheConfig should be implemented by backends that need special output @@ -301,6 +305,116 @@ func sumMemory(mem []Memory) uint64 { return sum } +// Minimal unique device identification +type DeviceID struct { + // ID is an identifier for the device for matching with system + // management libraries. + // This ID represents a "post filtered" view of the enumerated devices + // if the ID is numeric + ID string `json:"id"` + + // Library identifies which library is used for the device (e.g. CUDA, HIP, etc.) + Library string `json:"backend,omitempty"` +} + +type DeviceInfo struct { + DeviceID + + // Name is the name of the device as labeled by the backend. It + // may not be persistent across instances of the runner. + Name string `json:"name"` + + // Description is the longer user-friendly identification of the device + Description string `json:"description"` + + // FilterID is populated with the unfiltered device ID if a numeric ID is used + // so the device can be included. 
+ FilteredID string `json:"filtered_id,omitempty"` + + // Integrated is set true for integrated GPUs, false for Discrete GPUs + Integrated bool `json:"integration,omitempty"` + + // PCIID is the bus, device and domain ID of the device for deduplication + // when discovered by multiple backends + PCIID string `json:"pci_id,omitempty"` + + // TotalMemory is the total amount of memory the device can use for loading models + TotalMemory uint64 `json:"total_memory"` + + // FreeMemory is the amount of memory currently available on the device for loading models + FreeMemory uint64 `json:"free_memory,omitempty"` + + // ComputeMajor is the major version of capabilities of the device + // if unsupported by the backend, -1 will be returned + ComputeMajor int + + // ComputeMinor is the minor version of capabilities of the device + // if unsupported by the backend, -1 will be returned + ComputeMinor int + + // Driver Information + DriverMajor int `json:"driver_major,omitempty"` + DriverMinor int `json:"driver_minor,omitempty"` + + // Where backends were loaded from + LibraryPath []string +} + +func (d DeviceInfo) Compute() string { + // AMD gfx is encoded into the major minor in hex form + if strings.EqualFold(d.Library, "HIP") { + return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor) + } + return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor) +} + +func (d DeviceInfo) Driver() string { + return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor) +} + +type DeviceComparison int + +const ( + UniqueDevice DeviceComparison = iota + SameBackendDevice // The device is the same, and the library/backend is the same + DuplicateDevice // The same physical device but different library/backend (overlapping device) +) + +func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison { + if a.PCIID != b.PCIID { + return UniqueDevice + } + if a.Library == b.Library { + return SameBackendDevice + } + return DuplicateDevice +} + +// For a SameBackendDevice, return true if b is better than a +// e.g. 
newer GPU library version +func (a DeviceInfo) IsBetter(b DeviceInfo) bool { + aLib := a.LibraryPath[len(a.LibraryPath)-1] + bLib := b.LibraryPath[len(b.LibraryPath)-1] + if aLib == bLib { + return false + } + aLibSplit := strings.SplitN(aLib, "_", 2) + bLibSplit := strings.SplitN(bLib, "_", 2) + if len(aLibSplit) < 2 || len(bLibSplit) < 2 { + return false + } + if aLibSplit[0] != bLibSplit[0] { + slog.Debug("unexpected libraries", "a", aLib, "b", bLib) + return false + } + if aLibSplit[1] == bLibSplit[1] { + return false + } + cmp := []string{aLibSplit[1], bLibSplit[1]} + sort.Sort(sort.Reverse(sort.StringSlice(cmp))) + return cmp[0] == bLibSplit[1] +} + // Log prints a high level summary of the memory (allocated or not) func (m BackendMemory) Log(level slog.Level) { var total uint64 diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 49dc3e1ab..9e6f148e3 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1,5 +1,7 @@ package ggml +// #cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm +// #cgo windows LDFLAGS: -lpthread // #cgo CPPFLAGS: -I${SRCDIR}/ggml/include // #include // #include @@ -16,6 +18,7 @@ import ( "log/slog" "maps" "os" + "path/filepath" "runtime" "slices" "strconv" @@ -696,6 +699,77 @@ func (b *Backend) CacheConfig() ml.CacheConfig { } } +func (b *Backend) BackendDevices() []ml.DeviceInfo { + // TODO DRY out with ./ggml/src/ggml.go + exe, err := os.Executable() + if err != nil { + slog.Warn("failed to get executable path", "error", err) + exe = "." + } + + var value string + switch runtime.GOOS { + case "darwin": + value = filepath.Dir(exe) + case "windows": + value = filepath.Join(filepath.Dir(exe), "lib", "ollama") + default: + value = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama") + } + + paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH") + if !ok { + slog.Debug("OLLAMA_LIBRARY_PATH not set, falling back to default", "search", value) + paths = value + } + + split := filepath.SplitList(paths) + + deviceInfos := []ml.DeviceInfo{} + + for _, dev := range gpus { + // If we have a model loaded, and it's only loaded on a subset of the devices + // skip idle/unused devices to avoid initializing them and causing VRAM allocations + if b.allocMemory { + idleDev := true + for _, backend := range b.schedBackends { + if dev == C.ggml_backend_get_device(backend) { + idleDev = false + break + } + } + if idleDev { + slog.Debug("skipping unused backend device", "description", C.GoString(C.ggml_backend_dev_description(dev))) + continue + } + } + + info := ml.DeviceInfo{} + props := C.struct_ggml_backend_dev_props{} + C.ggml_backend_dev_get_props(dev, &props) + info.Name = C.GoString(props.name) + info.Description = C.GoString(props.description) + info.ID = C.GoString(props.id) + info.ComputeMajor = (int)(props.compute_major) + info.ComputeMinor = (int)(props.compute_minor) + info.DriverMajor = (int)(props.driver_major) + info.DriverMinor = (int)(props.driver_minor) + info.Integrated = props.integrated != 0 + if props.library != nil { + info.Library = C.GoString(props.library) + } + info.PCIID = fmt.Sprintf("%02x:%02x.%x", props.pci_bus_id, props.pci_device_id, props.pci_domain_id) + info.LibraryPath = split + + C.ggml_backend_dev_memory(dev, &props.memory_free, &props.memory_total) + info.TotalMemory = (uint64)(props.memory_total) + info.FreeMemory = (uint64)(props.memory_free) + + deviceInfos = append(deviceInfos, info) + } + return deviceInfos +} + type Context struct { b *Backend diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h 
b/ml/backend/ggml/ggml/include/ggml-backend.h index fda5ceb24..7c2d86703 100644 --- a/ml/backend/ggml/ggml/include/ggml-backend.h +++ b/ml/backend/ggml/ggml/include/ggml-backend.h @@ -158,6 +158,15 @@ extern "C" { size_t memory_total; enum ggml_backend_dev_type type; struct ggml_backend_dev_caps caps; + int driver_major; + int driver_minor; + int compute_major; + int compute_minor; + int integrated; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; + const char *library; }; GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt index 5158acd6a..3a428a22d 100644 --- a/ml/backend/ggml/ggml/src/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/CMakeLists.txt @@ -203,6 +203,8 @@ add_library(ggml-base ggml-threading.h ggml-quants.c ggml-quants.h + mem_hip.cpp + mem_nvml.cpp gguf.cpp) target_include_directories(ggml-base PRIVATE .) diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index e43fde523..352dae85d 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -279,6 +279,16 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; +#if defined(GGML_USE_HIP) + if (std::getenv("GGML_CUDA_INIT") != NULL) { + GGML_LOG_INFO("%s: initializing rocBLAS on device %d\n", __func__, id); + CUDA_CHECK(cudaSetDevice(id)); + // rocblas_initialize will SIGABRT if the GPU isn't supported + rocblas_initialize(); + GGML_LOG_INFO("%s: rocBLAS initialized on device %d\n", __func__, id); + } +#endif + #if defined(GGML_USE_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); @@ -332,9 +342,15 @@ static ggml_cuda_device_info ggml_cuda_init() { #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; +#ifdef __CUDA_ARCH_LIST__ + if (std::getenv("GGML_CUDA_INIT") != NULL) { + GGML_ASSERT(ggml_cuda_has_arch(info.devices[id].cc) && "ggml was not compiled with support for this arch"); + } +#endif // defined(__CUDA_ARCH_LIST__) GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? 
"yes" : "no", ggml_cuda_parse_uuid(prop, id).c_str()); + #endif // defined(GGML_USE_HIP) } @@ -3215,6 +3231,14 @@ struct ggml_backend_cuda_device_context { std::string name; std::string description; std::string id; + int major; + int minor; + int driver_major; + int driver_minor; + int integrated; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { @@ -3235,6 +3259,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); + +#if defined(GGML_USE_HIP) + if (ggml_hip_mgmt_init() == 0) { + int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_hip_mgmt_release(); + return; + } + ggml_hip_mgmt_release(); + } +#else + if (ggml_nvml_init() == 0) { + int status = ggml_nvml_get_device_memory(ctx->id.c_str(), free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_nvml_release(); + return; + } + ggml_nvml_release(); + } +#endif CUDA_CHECK(cudaMemGetInfo(free, total)); } @@ -3243,6 +3289,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend return GGML_BACKEND_DEVICE_TYPE_GPU; } +#define GGML_HIP_NAME "HIP" static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { props->name = ggml_backend_cuda_device_get_name(dev); props->description = ggml_backend_cuda_device_get_description(dev); @@ -3253,6 +3300,27 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back // If you need the memory data, call ggml_backend_dev_memory() explicitly. 
props->memory_total = props->memory_free = 0; + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; +#if defined(GGML_USE_HIP) + int cc = ggml_cuda_info().devices[ctx->device].cc - GGML_CUDA_CC_OFFSET_AMD; + props->compute_major = cc / 0x100; + props->compute_minor = cc - (props->compute_major * 0x100); +#else + props->compute_major = ctx->major; + props->compute_minor = ctx->minor; +#endif + props->driver_major = ctx->driver_major; + props->driver_minor = ctx->driver_minor; + props->integrated = ctx->integrated; + props->pci_bus_id = ctx->pci_bus_id; + props->pci_device_id = ctx->pci_device_id; + props->pci_domain_id = ctx->pci_domain_id; +#if defined(GGML_USE_HIP) + props->library = GGML_HIP_NAME; +#else + props->library = GGML_CUDA_NAME; +#endif + bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY bool events = false; @@ -3843,6 +3911,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; + int driverVersion = 0; + CUDA_CHECK(cudaDriverGetVersion(&driverVersion)); for (int i = 0; i < ggml_cuda_info().device_count; i++) { ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; @@ -3853,7 +3923,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; dev_ctx->id = ggml_cuda_parse_uuid(prop, i); - + dev_ctx->major = prop.major; + dev_ctx->minor = prop.minor; + dev_ctx->driver_major = driverVersion / 1000; + dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10; + dev_ctx->integrated = prop.integrated; + dev_ctx->pci_bus_id = prop.pciBusID; + dev_ctx->pci_device_id = prop.pciDeviceID; + dev_ctx->pci_domain_id = prop.pciDomainID; ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_cuda_device_interface, /* .reg = */ ®, diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h index cf22e60d2..957a795f2 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h @@ -42,6 +42,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceReset hipDeviceReset #define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaDriverGetVersion hipDriverGetVersion #define cudaError_t hipError_t #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled diff --git a/ml/backend/ggml/ggml/src/ggml-impl.h b/ml/backend/ggml/ggml/src/ggml-impl.h index 19a7adb2d..b9b102a5e 100644 --- a/ml/backend/ggml/ggml/src/ggml-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-impl.h @@ -602,6 +602,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx return true; } +// Management libraries for fetching more accurate free VRAM data +GGML_API int ggml_nvml_init(); +GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total); +GGML_API void ggml_nvml_release(); +GGML_API int ggml_hip_mgmt_init(); +GGML_API int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total); +GGML_API void ggml_hip_mgmt_release(); + #ifdef __cplusplus } #endif diff --git a/ml/backend/ggml/ggml/src/mem_hip.cpp b/ml/backend/ggml/ggml/src/mem_hip.cpp new file mode 100644 index 000000000..8ef19b8cf --- /dev/null +++ 
b/ml/backend/ggml/ggml/src/mem_hip.cpp @@ -0,0 +1,449 @@ +#include "ggml.h" + +#ifdef _WIN32 +// AMD Device Library eXtra (ADLX) +// +// https://github.com/GPUOpen-LibrariesAndSDKs/ADLX +// +// This Windows-only library provides accurate VRAM reporting for AMD GPUs. +// The runtime DLL is installed with every AMD Driver on Windows, however +// the SDK isn't a part of the HIP SDK packaging. As such, we avoid including +// the headers from the SDK to simplify building from source. +// +// ADLX relies heavily on function pointer tables. +// Only the minimal set of types are defined below to facilitate +// finding the target AMD GPU(s) and querying their current VRAM usage +// Unused function parameters are commented out to avoid unnecessary type +// definitions. + +#include "ggml-impl.h" +#include +#include + +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include + +namespace fs = std::filesystem; + +#include +#include + +// Begin minimal ADLX definitions - derived from tag v1.0 (Dec 2022) +typedef uint64_t adlx_uint64; +typedef uint32_t adlx_uint32; +typedef int32_t adlx_int32; +typedef adlx_int32 adlx_int; +typedef adlx_uint32 adlx_uint; +typedef long adlx_long; +typedef uint8_t adlx_uint8; +typedef enum +{ + ADLX_OK = 0, /**< @ENG_START_DOX This result indicates success. @ENG_END_DOX */ + ADLX_ALREADY_ENABLED, /**< @ENG_START_DOX This result indicates that the asked action is already enabled. @ENG_END_DOX */ + ADLX_ALREADY_INITIALIZED, /**< @ENG_START_DOX This result indicates that ADLX has a unspecified type of initialization. @ENG_END_DOX */ + ADLX_FAIL, /**< @ENG_START_DOX This result indicates an unspecified failure. @ENG_END_DOX */ + ADLX_INVALID_ARGS, /**< @ENG_START_DOX This result indicates that the arguments are invalid. @ENG_END_DOX */ + ADLX_BAD_VER, /**< @ENG_START_DOX This result indicates that the asked version is incompatible with the current version. @ENG_END_DOX */ + ADLX_UNKNOWN_INTERFACE, /**< @ENG_START_DOX This result indicates that an unknown interface was asked. @ENG_END_DOX */ + ADLX_TERMINATED, /**< @ENG_START_DOX This result indicates that the calls were made in an interface after ADLX was terminated. @ENG_END_DOX */ + ADLX_ADL_INIT_ERROR, /**< @ENG_START_DOX This result indicates that the ADL initialization failed. @ENG_END_DOX */ + ADLX_NOT_FOUND, /**< @ENG_START_DOX This result indicates that the item is not found. @ENG_END_DOX */ + ADLX_INVALID_OBJECT, /**< @ENG_START_DOX This result indicates that the method was called into an invalid object. @ENG_END_DOX */ + ADLX_ORPHAN_OBJECTS, /**< @ENG_START_DOX This result indicates that ADLX was terminated with outstanding ADLX objects. Any interface obtained from ADLX points to invalid memory and calls in their methods will result in unexpected behavior. @ENG_END_DOX */ + ADLX_NOT_SUPPORTED, /**< @ENG_START_DOX This result indicates that the asked feature is not supported. @ENG_END_DOX */ + ADLX_PENDING_OPERATION, /**< @ENG_START_DOX This result indicates a failure due to an operation currently in progress. @ENG_END_DOX */ + ADLX_GPU_INACTIVE /**< @ENG_START_DOX This result indicates that the GPU is inactive. 
@ENG_END_DOX */ +} ADLX_RESULT; +#define ADLX_SUCCEEDED(x) (ADLX_OK == (x) || ADLX_ALREADY_ENABLED == (x) || ADLX_ALREADY_INITIALIZED == (x)) +#define ADLX_FAILED(x) (ADLX_OK != (x) && ADLX_ALREADY_ENABLED != (x) && ADLX_ALREADY_INITIALIZED != (x)) +#define ADLX_VER_MAJOR 1 +#define ADLX_VER_MINOR 0 +#define ADLX_VER_RELEASE 5 +#define ADLX_VER_BUILD_NUM 30 +#define ADLX_MAKE_FULL_VER(VERSION_MAJOR, VERSION_MINOR, VERSION_RELEASE, VERSION_BUILD_NUM) ( ((adlx_uint64)(VERSION_MAJOR) << 48ull) | ((adlx_uint64)(VERSION_MINOR) << 32ull) | ((adlx_uint64)(VERSION_RELEASE) << 16ull) | (adlx_uint64)(VERSION_BUILD_NUM)) +#define ADLX_FULL_VERSION ADLX_MAKE_FULL_VER(ADLX_VER_MAJOR, ADLX_VER_MINOR, ADLX_VER_RELEASE, ADLX_VER_BUILD_NUM) +#define ADLX_CORE_LINK __declspec(dllexport) +#define ADLX_STD_CALL __stdcall +#define ADLX_CDECL_CALL __cdecl +#define ADLX_FAST_CALL __fastcall +#define ADLX_INLINE __inline +#define ADLX_FORCEINLINE __forceinline +#define ADLX_NO_VTABLE __declspec(novtable) + +#if defined(__cplusplus) +typedef bool adlx_bool; +#else +typedef adlx_uint8 adlx_bool; +#define true 1 +#define false 0 +#endif + +typedef struct IADLXSystem IADLXSystem; +typedef struct IADLXGPUList IADLXGPUList; +typedef struct IADLXGPU IADLXGPU; +typedef struct IADLXInterface IADLXInterface; +typedef struct IADLXPerformanceMonitoringServices IADLXPerformanceMonitoringServices; +typedef struct IADLXGPUMetrics IADLXGPUMetrics; +typedef struct IADLXGPUMetricsSupport IADLXGPUMetricsSupport; + +typedef struct IADLXSystemVtbl +{ + // IADLXSystem interface + ADLX_RESULT (ADLX_STD_CALL *GetHybridGraphicsType)(/* IADLXSystem* pThis, ADLX_HG_TYPE* hgType */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUs)(IADLXSystem* pThis, IADLXGPUList** ppGPUs); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXSystem* pThis, const wchar_t* interfaceId, void** ppInterface */); + ADLX_RESULT (ADLX_STD_CALL *GetDisplaysServices)(/* IADLXSystem* pThis, IADLXDisplayServices** ppDispServices */); + ADLX_RESULT (ADLX_STD_CALL *GetDesktopsServices)(/* IADLXSystem* pThis, IADLXDesktopServices** ppDeskServices */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUsChangedHandling)(/* IADLXSystem* pThis, IADLXGPUsChangedHandling** ppGPUsChangedHandling */); + ADLX_RESULT (ADLX_STD_CALL *EnableLog)(/* IADLXSystem* pThis, ADLX_LOG_DESTINATION mode, ADLX_LOG_SEVERITY severity, IADLXLog* pLogger, const wchar_t* fileName */); + ADLX_RESULT (ADLX_STD_CALL *Get3DSettingsServices)(/* IADLXSystem* pThis, IADLX3DSettingsServices** pp3DSettingsServices */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUTuningServices)(/* IADLXSystem* pThis, IADLXGPUTuningServices** ppGPUTuningServices */); + ADLX_RESULT (ADLX_STD_CALL *GetPerformanceMonitoringServices)(IADLXSystem* pThis, IADLXPerformanceMonitoringServices** ppPerformanceMonitoringServices); // Used + ADLX_RESULT (ADLX_STD_CALL *TotalSystemRAM)(/* IADLXSystem* pThis, adlx_uint* ramMB */); + ADLX_RESULT (ADLX_STD_CALL *GetI2C)(/* IADLXSystem* pThis, IADLXGPU* pGPU, IADLXI2C** ppI2C */); +} IADLXSystemVtbl; +struct IADLXSystem { const IADLXSystemVtbl *pVtbl; }; + +typedef struct IADLXGPUVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPU* pThis */); + adlx_long (ADLX_STD_CALL *Release)(IADLXGPU* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPU* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXGPU + ADLX_RESULT (ADLX_STD_CALL *VendorId)(/* IADLXGPU* pThis, const char** vendorId */); + ADLX_RESULT (ADLX_STD_CALL *ASICFamilyType)(/* IADLXGPU* pThis, 
ADLX_ASIC_FAMILY_TYPE* asicFamilyType */); + ADLX_RESULT (ADLX_STD_CALL *Type)(/* IADLXGPU* pThis, ADLX_GPU_TYPE* gpuType */); + ADLX_RESULT (ADLX_STD_CALL *IsExternal)(/* IADLXGPU* pThis, adlx_bool* isExternal */); + ADLX_RESULT (ADLX_STD_CALL *Name)(/* IADLXGPU* pThis, const char** gpuName */); + ADLX_RESULT (ADLX_STD_CALL *DriverPath)(/* IADLXGPU* pThis, const char** driverPath */); + ADLX_RESULT (ADLX_STD_CALL *PNPString)(/* IADLXGPU* pThis, const char** pnpString */); + ADLX_RESULT (ADLX_STD_CALL *HasDesktops)(/* IADLXGPU* pThis, adlx_bool* hasDesktops */); + ADLX_RESULT (ADLX_STD_CALL *TotalVRAM)(IADLXGPU* pThis, adlx_uint* vramMB); // Used + ADLX_RESULT (ADLX_STD_CALL *VRAMType)(/* IADLXGPU* pThis, const char** type */); + ADLX_RESULT (ADLX_STD_CALL *BIOSInfo)(/* IADLXGPU* pThis, const char** partNumber, const char** version, const char** date */); + ADLX_RESULT (ADLX_STD_CALL *DeviceId)(/* IADLXGPU* pThis, const char** deviceId */); + ADLX_RESULT (ADLX_STD_CALL *RevisionId)(/* IADLXGPU* pThis, const char** revisionId */); + ADLX_RESULT (ADLX_STD_CALL *SubSystemId)(/* IADLXGPU* pThis, const char** subSystemId */); + ADLX_RESULT (ADLX_STD_CALL *SubSystemVendorId)(/* IADLXGPU* pThis, const char** subSystemVendorId */); + ADLX_RESULT (ADLX_STD_CALL *UniqueId)(IADLXGPU* pThis, adlx_int* uniqueId); // Used +} IADLXGPUVtbl; +struct IADLXGPU { const IADLXGPUVtbl *pVtbl; }; + +typedef struct IADLXGPUListVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPUList* pThis */); + adlx_long (ADLX_STD_CALL *Release)(IADLXGPUList* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPUList* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXList + adlx_uint (ADLX_STD_CALL *Size)(/* IADLXGPUList* pThis */); + adlx_uint8 (ADLX_STD_CALL *Empty)(/* IADLXGPUList* pThis */); + adlx_uint (ADLX_STD_CALL *Begin)(IADLXGPUList* pThis); // Used + adlx_uint (ADLX_STD_CALL *End)(IADLXGPUList* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *At)(/* IADLXGPUList* pThis, const adlx_uint location, IADLXInterface** ppItem */); + ADLX_RESULT (ADLX_STD_CALL *Clear)(/* IADLXGPUList* pThis */); + ADLX_RESULT (ADLX_STD_CALL *Remove_Back)(/* IADLXGPUList* pThis */); + ADLX_RESULT (ADLX_STD_CALL *Add_Back)(/* IADLXGPUList* pThis, IADLXInterface* pItem */); + + //IADLXGPUList + ADLX_RESULT (ADLX_STD_CALL *At_GPUList)(IADLXGPUList* pThis, const adlx_uint location, IADLXGPU** ppItem); // Used + ADLX_RESULT (ADLX_STD_CALL *Add_Back_GPUList)(/* IADLXGPUList* pThis, IADLXGPU* pItem */); + +} IADLXGPUListVtbl; +struct IADLXGPUList { const IADLXGPUListVtbl *pVtbl; }; + +typedef struct IADLXPerformanceMonitoringServicesVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXPerformanceMonitoringServices* pThis */); + adlx_long (ADLX_STD_CALL *Release)(IADLXPerformanceMonitoringServices* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXPerformanceMonitoringServices* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXPerformanceMonitoringServices + ADLX_RESULT (ADLX_STD_CALL *GetSamplingIntervalRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); + ADLX_RESULT (ADLX_STD_CALL *SetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int intervalMs */); + ADLX_RESULT (ADLX_STD_CALL *GetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* intervalMs */); + ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySizeRange)(/* IADLXPerformanceMonitoringServices* 
pThis, ADLX_IntRange* range */); + ADLX_RESULT (ADLX_STD_CALL *SetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int sizeSec */); + ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); + ADLX_RESULT (ADLX_STD_CALL *ClearPerformanceMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); + ADLX_RESULT (ADLX_STD_CALL *StartPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); + ADLX_RESULT (ADLX_STD_CALL *StopPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); + ADLX_RESULT (ADLX_STD_CALL *GetAllMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXAllMetricsList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, adlx_int startMs, adlx_int stopMs, IADLXGPUMetricsList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetSystemMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXSystemMetricsList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetFPSHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXFPSList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentAllMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXAllMetrics** ppMetrics */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetrics** ppMetrics); // Used + ADLX_RESULT (ADLX_STD_CALL *GetCurrentSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetrics** ppMetrics */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentFPS)(/* IADLXPerformanceMonitoringServices* pThis, IADLXFPS** ppMetrics */); + ADLX_RESULT (ADLX_STD_CALL *GetSupportedGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetricsSupport** ppMetricsSupported); // Used + ADLX_RESULT (ADLX_STD_CALL *GetSupportedSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetricsSupport** ppMetricsSupported */); +}IADLXPerformanceMonitoringServicesVtbl; +struct IADLXPerformanceMonitoringServices { const IADLXPerformanceMonitoringServicesVtbl *pVtbl; }; + +typedef struct IADLXGPUMetricsSupportVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetricsSupport* pThis */); + adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetricsSupport* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetricsSupport* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXGPUMetricsSupport + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUUsage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAMClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUHotspotTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* 
supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTotalBoardPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUFanSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAM)(IADLXGPUMetricsSupport* pThis, adlx_bool* supported); // Used + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVoltage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + + ADLX_RESULT (ADLX_STD_CALL* GetGPUUsageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUHotspotTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUFanSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUVoltageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUTotalBoardPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); +} IADLXGPUMetricsSupportVtbl; +struct IADLXGPUMetricsSupport { const IADLXGPUMetricsSupportVtbl *pVtbl; }; + +typedef struct IADLXGPUMetricsVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetrics* pThis */); + adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetrics* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetrics* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXGPUMetrics + ADLX_RESULT (ADLX_STD_CALL* TimeStamp)(/* IADLXGPUMetrics* pThis, adlx_int64* ms */); + ADLX_RESULT (ADLX_STD_CALL* GPUUsage)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUVRAMClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUHotspotTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUTotalBoardPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUFanSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUVRAM)(IADLXGPUMetrics* pThis, adlx_int* data); // Used + ADLX_RESULT (ADLX_STD_CALL* GPUVoltage)(/* IADLXGPUMetrics* pThis, adlx_int* data */); +} IADLXGPUMetricsVtbl; +struct IADLXGPUMetrics { const IADLXGPUMetricsVtbl *pVtbl; }; + +struct { + void *handle; + ADLX_RESULT (*ADLXInitialize)(adlx_uint64 version, IADLXSystem** ppSystem); + ADLX_RESULT (*ADLXInitializeWithIncompatibleDriver)(adlx_uint64 version, IADLXSystem** ppSystem); + 
ADLX_RESULT (*ADLXQueryVersion)(const char** version); + ADLX_RESULT (*ADLXTerminate)(); + IADLXSystem *sys; +} adlx { NULL, NULL, NULL, NULL, NULL, NULL }; +static std::mutex ggml_adlx_lock; + +extern "C" { + +int ggml_hip_mgmt_init() { + std::lock_guard lock(ggml_adlx_lock); + if (adlx.handle != NULL) { + // Already initialized + return 0; + } + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + fs::path libPath = fs::path("\\Windows") / fs::path("System32") / fs::path("amdadlx64.dll"); + + adlx.handle = (void*)LoadLibraryW(libPath.wstring().c_str()); + if (adlx.handle == NULL) { + return ADLX_NOT_FOUND; + } + + adlx.ADLXInitialize = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitialize"); + adlx.ADLXInitializeWithIncompatibleDriver = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitializeWithIncompatibleDriver"); + adlx.ADLXTerminate = (ADLX_RESULT (*)()) GetProcAddress((HMODULE)(adlx.handle), "ADLXTerminate"); + adlx.ADLXQueryVersion = (ADLX_RESULT (*)(const char **version)) GetProcAddress((HMODULE)(adlx.handle), "ADLXQueryVersion"); + if (adlx.ADLXInitialize == NULL || adlx.ADLXInitializeWithIncompatibleDriver == NULL || adlx.ADLXTerminate == NULL) { + GGML_LOG_INFO("%s unable to locate required symbols in amdadlx64.dll, falling back to hip free memory reporting", __func__); + FreeLibrary((HMODULE)(adlx.handle)); + adlx.handle = NULL; + return ADLX_NOT_FOUND; + } + + SetErrorMode(old_mode); + + // Aid in troubleshooting... + if (adlx.ADLXQueryVersion != NULL) { + const char *version = NULL; + ADLX_RESULT status = adlx.ADLXQueryVersion(&version); + if (ADLX_SUCCEEDED(status)) { + GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version); + } + } + + ADLX_RESULT status = adlx.ADLXInitialize(ADLX_FULL_VERSION, &adlx.sys); + if (ADLX_FAILED(status)) { + // GGML_LOG_DEBUG("%s failed to initialize ADLX error=%d - attempting with incompatible driver...\n", __func__, status); + // Try with the incompatible driver + status = adlx.ADLXInitializeWithIncompatibleDriver(ADLX_FULL_VERSION, &adlx.sys); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s failed to initialize ADLX error=%d\n", __func__, status); + FreeLibrary((HMODULE)(adlx.handle)); + adlx.handle = NULL; + adlx.sys = NULL; + return status; + } + // GGML_LOG_DEBUG("%s initialized ADLX with incpomatible driver\n", __func__); + } + return ADLX_OK; +} + +void ggml_hip_mgmt_release() { + std::lock_guard lock(ggml_adlx_lock); + if (adlx.handle == NULL) { + // Already free + return; + } + ADLX_RESULT status = adlx.ADLXTerminate(); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s failed to terminate Adlx %d\n", __func__, status); + // Unload anyway... 
+ } + FreeLibrary((HMODULE)(adlx.handle)); + adlx.handle = NULL; +} + +#define adlx_gdm_cleanup \ + if (gpuMetricsSupport != NULL) gpuMetricsSupport->pVtbl->Release(gpuMetricsSupport); \ + if (gpuMetrics != NULL) gpuMetrics->pVtbl->Release(gpuMetrics); \ + if (perfMonitoringServices != NULL) perfMonitoringServices->pVtbl->Release(perfMonitoringServices); \ + if (gpus != NULL) gpus->pVtbl->Release(gpus); \ + if (gpu != NULL) gpu->pVtbl->Release(gpu) + +int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { + std::lock_guard lock(ggml_adlx_lock); + if (adlx.handle == NULL) { + GGML_LOG_INFO("%s ADLX was not initialized\n", __func__); + return ADLX_ADL_INIT_ERROR; + } + IADLXGPUMetricsSupport *gpuMetricsSupport = NULL; + IADLXPerformanceMonitoringServices *perfMonitoringServices = NULL; + IADLXGPUList* gpus = NULL; + IADLXGPU* gpu = NULL; + IADLXGPUMetrics *gpuMetrics = NULL; + ADLX_RESULT status; + // The "UniqueID" exposed in ADLX is the PCI Bus and Device IDs + adlx_int target = (pci_bus_id << 8) | (pci_device_id & 0xff); + + status = adlx.sys->pVtbl->GetPerformanceMonitoringServices(adlx.sys, &perfMonitoringServices); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetPerformanceMonitoringServices failed %d\n", __func__, status); + return status; + } + + status = adlx.sys->pVtbl->GetGPUs(adlx.sys, &gpus); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetGPUs failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + // Get GPU list + for (adlx_uint crt = gpus->pVtbl->Begin(gpus); crt != gpus->pVtbl->End(gpus); ++crt) + { + status = gpus->pVtbl->At_GPUList(gpus, crt, &gpu); + if (ADLX_FAILED(status)) + { + GGML_LOG_INFO("%s %d] At_GPUList failed %d\n", __func__, crt, status); + continue; + } + adlx_int id; + status = gpu->pVtbl->UniqueId(gpu, &id); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s %d] UniqueId lookup failed %d\n", __func__, crt, status); + gpu->pVtbl->Release(gpu); + gpu = NULL; + continue; + } + if (id != target) { + GGML_LOG_DEBUG("%s %d] GPU UniqueId: %x does not match target %02x %02x\n", __func__, crt, id, pci_bus_id, pci_device_id); + gpu->pVtbl->Release(gpu); + gpu = NULL; + continue; + } + // Any failures at this point should cause a fall-back to other APIs + status = perfMonitoringServices->pVtbl->GetSupportedGPUMetrics(perfMonitoringServices, gpu, &gpuMetricsSupport); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetSupportedGPUMetrics failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + status = perfMonitoringServices->pVtbl->GetCurrentGPUMetrics(perfMonitoringServices, gpu, &gpuMetrics); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetCurrentGPUMetrics failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + adlx_bool supported = false; + status = gpuMetricsSupport->pVtbl->IsSupportedGPUVRAM(gpuMetricsSupport, &supported); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s IsSupportedGPUVRAM failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + adlx_uint totalVRAM = 0; + status = gpu->pVtbl->TotalVRAM(gpu, &totalVRAM); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s TotalVRAM failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + adlx_int usedVRAM = 0; + status = gpuMetrics->pVtbl->GPUVRAM(gpuMetrics, &usedVRAM); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GPUVRAM failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + *total = size_t(totalVRAM) * 1024 * 1024; + *free = 
size_t(totalVRAM-usedVRAM) * 1024 * 1024; + + adlx_gdm_cleanup; + return ADLX_OK; + } + adlx_gdm_cleanup; + return ADLX_NOT_FOUND; +} + +} // extern "C" + +#else // #ifdef _WIN32 + +extern "C" { + +// TODO Linux implementation of accurate VRAM reporting +int ggml_hip_mgmt_init() { + return -1; +} +void ggml_hip_mgmt_release() {} +int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { + return -1; +} + +} // extern "C" + +#endif // #ifdef _WIN32 \ No newline at end of file diff --git a/ml/backend/ggml/ggml/src/mem_nvml.cpp b/ml/backend/ggml/ggml/src/mem_nvml.cpp new file mode 100644 index 000000000..aa05e9dc1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/mem_nvml.cpp @@ -0,0 +1,172 @@ +// NVIDIA Management Library (NVML) +// +// https://developer.nvidia.com/management-library-nvml +// +// This library provides accurate VRAM reporting for NVIDIA GPUs, particularly +// on Windows, where the cuda library provides inaccurate VRAM usage metrics. The +// runtime DLL is installed with every driver on Windows, and most Linux +// systems, and the headers are included in the standard CUDA SDK install. As +// such, we can include the header here to simplify the code. + + +#include "ggml-impl.h" +#include +#include + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +#else +# include +# include +#endif + +namespace fs = std::filesystem; + +// Minimal definitions to avoid including the nvml.h header +typedef enum nvmlReturn_enum +{ + // cppcheck-suppress * + NVML_SUCCESS = 0, //!< The operation was successful + NVML_ERROR_UNINITIALIZED = 1, //!< NVML was not first initialized with nvmlInit() + NVML_ERROR_INVALID_ARGUMENT = 2, //!< A supplied argument is invalid + NVML_ERROR_NOT_SUPPORTED = 3, //!< The requested operation is not available on target device + NVML_ERROR_NO_PERMISSION = 4, //!< The current user does not have permission for operation + NVML_ERROR_ALREADY_INITIALIZED = 5, //!< Deprecated: Multiple initializations are now allowed through ref counting + NVML_ERROR_NOT_FOUND = 6, //!< A query to find an object was unsuccessful + NVML_ERROR_INSUFFICIENT_SIZE = 7, //!< An input argument is not large enough + NVML_ERROR_INSUFFICIENT_POWER = 8, //!< A device's external power cables are not properly attached + NVML_ERROR_DRIVER_NOT_LOADED = 9, //!< NVIDIA driver is not loaded + NVML_ERROR_TIMEOUT = 10, //!< User provided timeout passed + NVML_ERROR_IRQ_ISSUE = 11, //!< NVIDIA Kernel detected an interrupt issue with a GPU + NVML_ERROR_LIBRARY_NOT_FOUND = 12, //!< NVML Shared Library couldn't be found or loaded + NVML_ERROR_FUNCTION_NOT_FOUND = 13, //!< Local version of NVML doesn't implement this function + NVML_ERROR_CORRUPTED_INFOROM = 14, //!< infoROM is corrupted + NVML_ERROR_GPU_IS_LOST = 15, //!< The GPU has fallen off the bus or has otherwise become inaccessible + NVML_ERROR_RESET_REQUIRED = 16, //!< The GPU requires a reset before it can be used again + NVML_ERROR_OPERATING_SYSTEM = 17, //!< The GPU control device has been blocked by the operating system/cgroups + NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18, //!< RM detects a driver/library version mismatch + NVML_ERROR_IN_USE = 19, //!< An operation cannot be performed because the GPU is currently in use + NVML_ERROR_MEMORY = 20, //!< Insufficient memory + NVML_ERROR_NO_DATA = 21, //!< No data + NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, becasue ECC is enabled + 
NVML_ERROR_INSUFFICIENT_RESOURCES = 23, //!< Ran out of critical resources, other than memory + NVML_ERROR_FREQ_NOT_SUPPORTED = 24, //!< Ran out of critical resources, other than memory + NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25, //!< The provided version is invalid/unsupported + NVML_ERROR_DEPRECATED = 26, //!< The requested functionality has been deprecated + NVML_ERROR_NOT_READY = 27, //!< The system is not ready for the request + NVML_ERROR_GPU_NOT_FOUND = 28, //!< No GPUs were found + NVML_ERROR_INVALID_STATE = 29, //!< Resource not in correct state to perform requested operation + NVML_ERROR_UNKNOWN = 999 //!< An internal driver error occurred +} nvmlReturn_t; +typedef struct nvmlDevice_st* nvmlDevice_t; +typedef struct nvmlMemory_st +{ + unsigned long long total; //!< Total physical device memory (in bytes) + unsigned long long free; //!< Unallocated device memory (in bytes) + unsigned long long used; //!< Sum of Reserved and Allocated device memory (in bytes). + //!< Note that the driver/GPU always sets aside a small amount of memory for bookkeeping +} nvmlMemory_t; +// end nvml.h definitions + +struct { + void *handle; + nvmlReturn_t (*nvmlInit_v2)(void); + nvmlReturn_t (*nvmlShutdown)(void); + nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); + nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); +} nvml { NULL, NULL, NULL, NULL, NULL }; +static std::mutex ggml_nvml_lock; + +extern "C" { + +int ggml_nvml_init() { + std::lock_guard lock(ggml_nvml_lock); + if (nvml.handle != NULL) { + // Already initialized + return 0; + } +#ifdef _WIN32 + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + fs::path libPath[2]; + const char * programDir = std::getenv("ProgramW6432"); + if (programDir == NULL) { + libPath[0] = fs::path("Program Files") / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); + } else { + libPath[0] = fs::path(programDir) / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); + } + libPath[1] = fs::path("\\Windows") / fs::path("System32") / fs::path("NVML.dll"); + + for (int i = 0; i < 2; i++) { + nvml.handle = (void*)LoadLibraryW(libPath[i].wstring().c_str()); + if (nvml.handle != NULL) { + break; + } + } + if (nvml.handle == NULL) { + return NVML_ERROR_NOT_FOUND; + } + + nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlInit_v2"); + nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown"); + nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID"); + nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo"); + if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) { + GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__); + FreeLibrary((HMODULE)(nvml.handle)); + nvml.handle = NULL; + return NVML_ERROR_NOT_FOUND; + } + + SetErrorMode(old_mode); + +#else + // Not currently wired up on Linux + return NVML_ERROR_NOT_SUPPORTED; +#endif + int status = nvml.nvmlInit_v2(); + return NVML_SUCCESS; +} + +void ggml_nvml_release() { + std::lock_guard lock(ggml_nvml_lock); + if (nvml.handle == NULL) { + // Already free + return; + } + nvmlReturn_enum status = nvml.nvmlShutdown(); 
+ if (status != NVML_SUCCESS) { + GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status); + } +#ifdef _WIN32 + FreeLibrary((HMODULE)(nvml.handle)); + nvml.handle = NULL; +#else + // Not currently wired up on Linux +#endif +} + +int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) { + std::lock_guard lock(ggml_nvml_lock); + if (nvml.handle == NULL) { + return NVML_ERROR_UNINITIALIZED; + } + nvmlDevice_t device; + auto status = nvml.nvmlDeviceGetHandleByUUID(uuid, &device); + if (status != NVML_SUCCESS) { + return status; + } + nvmlMemory_t memInfo = {0}; + status = nvml.nvmlDeviceGetMemoryInfo(device, &memInfo); + if (status == NVML_SUCCESS) { + *free = memInfo.free; + *total = memInfo.total; + } + return status; +} + +} \ No newline at end of file diff --git a/ml/nn/pooling/pooling_test.go b/ml/nn/pooling/pooling_test.go index c80019459..e27727462 100644 --- a/ml/nn/pooling/pooling_test.go +++ b/ml/nn/pooling/pooling_test.go @@ -3,11 +3,9 @@ package pooling_test import ( "bytes" "os" - "slices" "testing" "github.com/google/go-cmp/cmp" - "github.com/ollama/ollama/discover" fsggml "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" @@ -32,20 +30,7 @@ func setup(tb testing.TB, n int) ml.Backend { tb.Fatal(err) } - var gpuLayers ml.GPULayersList - if gpus := discover.GetGPUInfo(); len(gpus) > 0 { - gpuLayers = append(gpuLayers, ml.GPULayers{ - ID: gpus[0].ID, - Layers: slices.Collect(func(yield func(int) bool) { - for i := range n { - if !yield(i) { - return - } - } - }), - }) - } - b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers}) + b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true}) if err != nil { tb.Fatal(err) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 480cfc19b..c86d3c2b9 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -28,6 +28,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" @@ -1235,6 +1236,46 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) { } } +// info is the handler called by the Ollama server to report the backend +// devices visible to this runner +func (s *Server) info(w http.ResponseWriter, r *http.Request) { + s.loadMu.Lock() + defer s.loadMu.Unlock() + + w.Header().Set("Content-Type", "application/json") + + m := s.model + + if m == nil { + // Dummy load to get the backend wired up + f, err := os.CreateTemp("", "*.bin") + if err != nil { + http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError) + return + } + defer f.Close() + + if err := ggml.WriteGGUF(f, ggml.KV{ + "general.architecture": "llama", + "tokenizer.ggml.model": "gpt2", + }, nil); err != nil { + http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError) + return + } + + m, err = model.New(f.Name(), ml.BackendParams{AllocMemory: false, GPULayers: ml.GPULayersList{{}}}) + if err != nil { + http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError) + return + } + } + + infos := m.Backend().BackendDevices() + if err := json.NewEncoder(w).Encode(&infos); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + func Execute(args []string) error { fs
:= flag.NewFlagSet("runner", flag.ExitOnError) mpath := fs.String("model", "", "Path to model binary file") @@ -1275,6 +1316,7 @@ func Execute(args []string) error { mux := http.NewServeMux() // TODO: support embeddings + mux.HandleFunc("GET /info", server.info) mux.HandleFunc("POST /load", server.load) mux.HandleFunc("POST /embedding", server.embeddings) mux.HandleFunc("POST /completion", server.completion) diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 37fe87961..b4a4b4235 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -279,7 +279,7 @@ function distZip() { write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip" Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") { - Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" + Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" } } diff --git a/server/routes.go b/server/routes.go index a08a72898..a073fb76d 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1548,8 +1548,8 @@ func Serve(ln net.Listener) error { // At startup we retrieve GPU information so we can get log messages before loading a model // This will log warnings to the log in case we have problems with detected GPUs - gpus := discover.GetGPUInfo() - gpus.LogDetails() + gpus := discover.GPUDevices(ctx, nil) + discover.LogDetails(gpus) var totalVRAM uint64 for _, gpu := range gpus { diff --git a/server/routes_debug_test.go b/server/routes_debug_test.go index 6507284ef..cc3522109 100644 --- a/server/routes_debug_test.go +++ b/server/routes_debug_test.go @@ -36,8 +36,8 @@ func TestGenerateDebugRenderOnly(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading @@ -229,8 +229,8 @@ func TestChatDebugRenderOnly(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index a3b83fc1a..8385cb17b 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -74,8 +74,8 @@ func TestGenerateChat(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading @@ -618,8 +618,8 @@ func TestGenerate(t *testing.T) { unloadedCh: make(chan any, 1), loaded: 
make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading @@ -994,8 +994,8 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { time.Sleep(time.Millisecond) diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go index b1ede4e39..caadcb872 100644 --- a/server/routes_harmony_streaming_test.go +++ b/server/routes_harmony_streaming_test.go @@ -274,8 +274,8 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 100 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { req.successCh <- &runnerRef{ @@ -425,8 +425,8 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 100 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { req.successCh <- &runnerRef{ @@ -607,8 +607,8 @@ func TestChatHarmonyParserStreaming(t *testing.T) { unloadedCh: make(chan any, 1), loaded: make(map[string]*runnerRef), newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { req.successCh <- &runnerRef{ diff --git a/server/sched.go b/server/sched.go index 74aa406af..b0b33a911 100644 --- a/server/sched.go +++ b/server/sched.go @@ -21,6 +21,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/ml" "github.com/ollama/ollama/types/model" ) @@ -52,8 +53,8 @@ type Scheduler struct { loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool newServerFn func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) - getGpuFn func() discover.GpuInfoList - getCpuFn func() discover.GpuInfoList + getGpuFn func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList + getCpuFn func() discover.GpuInfo reschedDelay time.Duration } @@ -148,7 +149,12 @@ func (s *Scheduler) processPending(ctx context.Context) { s.loadedMu.Lock() runner := s.loaded[pending.model.ModelPath] loadedCount := len(s.loaded) + runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded)) + for _, r := range s.loaded { + runnersSnapshot = append(runnersSnapshot, r) + } s.loadedMu.Unlock() + if runner != 
nil { if runner.needsReload(ctx, pending) { slog.Debug("reloading", "runner", runner) @@ -166,9 +172,9 @@ func (s *Scheduler) processPending(ctx context.Context) { // Get a refreshed GPU list var gpus discover.GpuInfoList if pending.opts.NumGPU == 0 { - gpus = s.getCpuFn() + gpus = discover.GpuInfoList{s.getCpuFn()} } else { - gpus = s.getGpuFn() + gpus = s.getGpuFn(ctx, runnersSnapshot) } if envconfig.MaxRunners() <= 0 { @@ -343,7 +349,11 @@ func (s *Scheduler) processCompleted(ctx context.Context) { runner.refMu.Unlock() } else { slog.Debug("starting background wait for VRAM recovery", "runner", runner) - finished := runner.waitForVRAMRecovery() + runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded)) + for _, r := range s.loaded { + runnersSnapshot = append(runnersSnapshot, r) + } + finished := runner.waitForVRAMRecovery(runnersSnapshot) runner.unload() delete(s.loaded, runner.modelPath) s.loadedMu.Unlock() @@ -571,7 +581,6 @@ func (runner *runnerRef) unload() { runner.llama.Close() } runner.model = nil - runner.llama = nil runner.Options = nil runner.gpus = nil } @@ -618,7 +627,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool // a before and after GPU memory allocation. The returned channel // will be notified when we're done waiting, or have timed out and should // proceed anyway -func (runner *runnerRef) waitForVRAMRecovery() chan any { +func (runner *runnerRef) waitForVRAMRecovery(runners []discover.FilteredRunnerDiscovery) chan any { finished := make(chan any, 1) // CPU or Metal don't need checking, so no waiting required @@ -633,7 +642,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any { start := time.Now() // Establish a baseline before we unload - gpusBefore := discover.GetGPUInfo() + gpusBefore := discover.GetGPUInfo(context.Background(), runners) var totalMemoryBefore, freeMemoryBefore uint64 for _, gpu := range gpusBefore { totalMemoryBefore += gpu.TotalMemory @@ -651,7 +660,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any { } // Query GPUs, look for free to go back up - gpusNow := discover.GetGPUInfo() + gpusNow := discover.GetGPUInfo(context.Background(), runners) var totalMemoryNow, freeMemoryNow uint64 for _, gpu := range gpusNow { totalMemoryNow += gpu.TotalMemory @@ -695,6 +704,37 @@ func (runner *runnerRef) LogValue() slog.Value { return slog.GroupValue(attrs...) 
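The GetDeviceInfos call used by the scheduler snapshot above is ultimately served by the runner's new GET /info endpoint added earlier in this series, which returns a JSON-encoded []ml.DeviceInfo. A minimal single-shot client sketch follows, assuming the runner listens on localhost on the port reported by GetPort(); the function name is illustrative only, and the real GetDevicesFromRunner in discover/runner.go layers polling and retries on top of this.

package discover // hypothetical placement, shown only as an illustration

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"github.com/ollama/ollama/ml"
)

// fetchDeviceInfos is a hypothetical, single-shot version of the lookup that
// GetDevicesFromRunner performs: query the runner's GET /info endpoint and
// decode the JSON-encoded []ml.DeviceInfo it returns.
func fetchDeviceInfos(ctx context.Context, port int) ([]ml.DeviceInfo, error) {
	url := fmt.Sprintf("http://127.0.0.1:%d/info", port)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("runner error: %s", string(body))
	}
	var infos []ml.DeviceInfo
	if err := json.NewDecoder(resp.Body).Decode(&infos); err != nil {
		return nil, err
	}
	return infos, nil
}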
} +// Implements discover.RunnerDiscovery +func (runner *runnerRef) GetPort() int { + if runner.llama != nil { + return runner.llama.Pid() + } + return -1 +} + +func (runner *runnerRef) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { + if runner.llama != nil { + return runner.llama.GetDeviceInfos(ctx) + } + return nil +} + +func (runner *runnerRef) GetActiveDeviceIDs() []ml.DeviceID { + devIDs := make([]ml.DeviceID, len(runner.gpus)) + for i := range devIDs { + devIDs[i].ID = runner.gpus[i].ID + devIDs[i].Library = runner.gpus[i].Library + } + return devIDs +} + +func (runner *runnerRef) HasExited() bool { + if runner.llama != nil { + return runner.llama.HasExited() + } + return true +} + type ByDurationAndName []*runnerRef func (a ByDurationAndName) Len() int { return len(a) } diff --git a/server/sched_test.go b/server/sched_test.go index 0acd59118..3fabfec32 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -17,6 +17,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/ml" ) func TestMain(m *testing.M) { @@ -150,18 +151,18 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vra return b } -func getGpuFn() discover.GpuInfoList { +func getGpuFn(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList { g := discover.GpuInfo{Library: "metal"} g.TotalMemory = 24 * format.GigaByte g.FreeMemory = 12 * format.GigaByte return []discover.GpuInfo{g} } -func getCpuFn() discover.GpuInfoList { +func getCpuFn() discover.GpuInfo { g := discover.GpuInfo{Library: "cpu"} g.TotalMemory = 32 * format.GigaByte g.FreeMemory = 26 * format.GigaByte - return []discover.GpuInfo{g} + return g } func TestRequestsSameModelSameRequest(t *testing.T) { @@ -460,7 +461,7 @@ func TestPrematureExpired(t *testing.T) { // Same model, same request scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil) s := InitScheduler(ctx) - s.getGpuFn = func() discover.GpuInfoList { + s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList { g := discover.GpuInfo{Library: "metal"} g.TotalMemory = 24 * format.GigaByte g.FreeMemory = 12 * format.GigaByte @@ -732,7 +733,11 @@ func (s *mockLlm) Close() error { s.closeCalled = true return s.closeResp } -func (s *mockLlm) VRAMSize() uint64 { return s.vramSize } -func (s *mockLlm) TotalSize() uint64 { return s.totalSize } -func (s *mockLlm) VRAMByGPU(gpuid string) uint64 { return s.vramByGPU[gpuid] } -func (s *mockLlm) Pid() int { return -1 } +func (s *mockLlm) VRAMSize() uint64 { return s.vramSize } +func (s *mockLlm) TotalSize() uint64 { return s.totalSize } +func (s *mockLlm) VRAMByGPU(gpuid string) uint64 { return s.vramByGPU[gpuid] } +func (s *mockLlm) Pid() int { return -1 } +func (s *mockLlm) GetPort() int { return -1 } +func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil } +func (s *mockLlm) HasExited() bool { return false } +func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil } From 3566fe0e7b87824769f9a97988d00a077b8e63c3 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 19 Sep 2025 12:57:01 -0700 Subject: [PATCH 116/172] timing info for runner --- discover/runner.go | 4 ++-- runner/ollamarunner/runner.go | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/discover/runner.go b/discover/runner.go index 4c0bce75b..5e4e05f95 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -496,7 
+496,7 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]ml.DeviceInfo, error) { var moreDevices []ml.DeviceInfo port := runner.GetPort() - tick := time.Tick(500 * time.Millisecond) + tick := time.Tick(10 * time.Millisecond) for { select { case <-ctx.Done(): @@ -530,7 +530,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]ml.DeviceIn } if resp.StatusCode != 200 { logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body) - continue + return nil, fmt.Errorf("runner error: %s", string(body)) } if err := json.Unmarshal(body, &moreDevices); err != nil { diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index c86d3c2b9..a97ef7c18 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -1247,6 +1247,8 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) { m := s.model if m == nil { + startLoad := time.Now() + // Dummy load to get the backend wired up f, err := os.CreateTemp("", "*.bin") if err != nil { @@ -1268,9 +1270,12 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError) return } + slog.Debug("dummy model load took", "duration", time.Since(startLoad)) } + startDevices := time.Now() infos := m.Backend().BackendDevices() + slog.Debug("gathering device infos took", "duration", time.Since(startDevices)) if err := json.NewEncoder(w).Encode(&infos); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } From c86af47ac0a8788a187c602377fc3911e9ce630f Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 5 Sep 2025 08:25:03 -0700 Subject: [PATCH 117/172] WIP - wire up Vulkan with the new engine based discovery Not a complete implementation - free VRAM reporting is improved, but not yet accurate on Windows --- CMakeLists.txt | 12 +- Dockerfile | 20 +- discover/gpu.go | 38 +- discover/gpu_info_vulkan.c | 241 ----------- discover/gpu_info_vulkan.h | 394 ------------------ discover/runner.go | 5 +- discover/types.go | 12 +- llm/server.go | 3 +- .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 148 ++++++- scripts/build_windows.ps1 | 7 +- 10 files changed, 192 insertions(+), 688 deletions(-) delete mode 100644 discover/gpu_info_vulkan.c delete mode 100644 discover/gpu_info_vulkan.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 29fbd00cd..94114a709 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -114,7 +114,6 @@ if(CMAKE_HIP_COMPILER) target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM) - set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm) install(TARGETS ggml-hip RUNTIME_DEPENDENCY_SET rocm RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP @@ -125,13 +124,13 @@ if(CMAKE_HIP_COMPILER) PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf PRE_EXCLUDE_REGEXES ".*" POST_EXCLUDE_REGEXES "system32" - RUNTIME DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP - LIBRARY DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP + RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP + LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP ) foreach(HIP_LIB_BIN_INSTALL_DIR IN ITEMS ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}) if(EXISTS ${HIP_LIB_BIN_INSTALL_DIR}/rocblas) - install(DIRECTORY ${HIP_LIB_BIN_INSTALL_DIR}/rocblas
DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP) + install(DIRECTORY ${HIP_LIB_BIN_INSTALL_DIR}/rocblas DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP) break() endif() endforeach() @@ -141,12 +140,11 @@ endif() find_package(Vulkan) if(Vulkan_FOUND) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) - set(OLLAMA_VULKAN_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/vulkan) install(TARGETS ggml-vulkan RUNTIME_DEPENDENCIES PRE_INCLUDE_REGEXES vulkan PRE_EXCLUDE_REGEXES ".*" - RUNTIME DESTINATION ${OLLAMA_VULKAN_INSTALL_DIR} COMPONENT Vulkan - LIBRARY DESTINATION ${OLLAMA_VULKAN_INSTALL_DIR} COMPONENT Vulkan + RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan + LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan ) endif() diff --git a/Dockerfile b/Dockerfile index aeab5947f..7478fbd95 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ ARG ROCMVERSION=6.3.3 ARG JETPACK5VERSION=r35.4.1 ARG JETPACK6VERSION=r36.4.0 ARG CMAKEVERSION=3.31.2 -ARG VULKANVERSION=1.4.313.2 +ARG VULKANVERSION=1.4.321.1 # We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 @@ -88,7 +88,7 @@ FROM base AS rocm-6 ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'ROCm 6' \ + cmake --preset 'ROCm 6' -DOLLAMA_RUNNER_DIR="rocm" \ && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \ && cmake --install build --component HIP --strip --parallel ${PARALLEL} @@ -100,7 +100,7 @@ COPY CMakeLists.txt CMakePresets.json . COPY ml/backend/ggml/ggml ml/backend/ggml/ggml ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'JetPack 5' \ + cmake --preset 'JetPack 5' -DOLLAMA_RUNNER_DIR="cuda_jetpack5" \ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \ && cmake --install build --component CUDA --strip --parallel ${PARALLEL} @@ -112,13 +112,13 @@ COPY CMakeLists.txt CMakePresets.json . 
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'JetPack 6' \ + cmake --preset 'JetPack 6' -DOLLAMA_RUNNER_DIR="cuda_jetpack6" \ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \ && cmake --install build --component CUDA --strip --parallel ${PARALLEL} FROM base AS vulkan RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'Vulkan' \ + cmake --preset 'Vulkan' -DOLLAMA_RUNNER_DIR="vulkan" \ && cmake --build --parallel --preset 'Vulkan' \ && cmake --install build --component Vulkan --strip --parallel 8 @@ -140,15 +140,15 @@ RUN --mount=type=cache,target=/root/.cache/go-build \ FROM --platform=linux/amd64 scratch AS amd64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ -COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ -COPY --from=vulkan dist/lib/ollama/vulkan /lib/ollama/vulkan +COPY --from=cuda-13 dist/lib/ollama /lib/ollama/ +COPY --from=vulkan dist/lib/ollama /lib/ollama/ FROM --platform=linux/arm64 scratch AS arm64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ -COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ -COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5 -COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6 +COPY --from=cuda-13 dist/lib/ollama /lib/ollama/ +COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/ +COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/ FROM scratch AS rocm COPY --from=rocm-6 dist/lib/ollama /lib/ollama diff --git a/discover/gpu.go b/discover/gpu.go index 872b06c64..0cae79005 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -71,11 +71,9 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList { } else { info.Compute = fmt.Sprintf("%d.%d", dev.ComputeMajor, dev.ComputeMinor) } + // TODO any special processing of Vulkan devices? resp = append(resp, info) } - for _, gpu := range vulkanGPUs { - resp = append(resp, gpu.GpuInfo) - } if len(resp) == 0 { mem, err := GetCPUMem() if err != nil { @@ -93,18 +91,20 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList { // Given the list of GPUs this instantiation is targeted for, // figure out the visible devices environment variable -// -// # If different libraries are detected, the first one is what we use -// -// TODO once we're purely running on the new runner, this level of device -// filtering will no longer be necessary. 
Instead the runner can be told which -// of the set of GPUs to utilize and handle filtering itself, instead of relying -// on the env var to hide devices from the underlying GPU libraries func (l GpuInfoList) GetVisibleDevicesEnv() []string { if len(l) == 0 { return nil } - return []string{rocmGetVisibleDevicesEnv(l)} + res := []string{} + envVar := rocmGetVisibleDevicesEnv(l) + if envVar != "" { + res = append(res, envVar) + } + envVar = vkGetVisibleDevicesEnv(l) + if envVar != "" { + res = append(res, envVar) + } + return res } func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { @@ -134,6 +134,22 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { return envVar + strings.Join(ids, ",") } +func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "VULKAN" { + continue + } + ids = append(ids, info.ID) + + } + if len(ids) == 0 { + return "" + } + envVar := "GGML_VK_VISIBLE_DEVICES=" + return envVar + strings.Join(ids, ",") +} + // GetSystemInfo returns the last cached state of the GPUs on the system func GetSystemInfo() SystemInfo { deviceMu.Lock() diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c deleted file mode 100644 index 65033ad8a..000000000 --- a/discover/gpu_info_vulkan.c +++ /dev/null @@ -1,241 +0,0 @@ -#ifndef __APPLE__ -#include "gpu_info_vulkan.h" - -#include - -int is_extension_supported(vk_handle_t* rh, VkPhysicalDevice device, char* extension) { - VkPhysicalDeviceProperties properties = {}; - (*rh->vkGetPhysicalDeviceProperties)(device, &properties); - - uint32_t extensionCount; - (*rh->vkEnumerateDeviceExtensionProperties)(device, NULL, &extensionCount, NULL); - - if (extensionCount == 0) { - return 0; - } - - VkExtensionProperties* extensions = malloc(extensionCount * sizeof(VkExtensionProperties)); - if (extensions == NULL) { - return 0; - } - - (*rh->vkEnumerateDeviceExtensionProperties)(device, NULL, &extensionCount, extensions); - - for (int j = 0; j < extensionCount; j++) { - if (strcmp(extensions[j].extensionName, extension) == 0) { - free(extensions); - return 1; - } - } - - free(extensions); - return 0; -} - -void vk_init(char* vk_lib_path, vk_init_resp_t *resp) { - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - {"vkGetPhysicalDeviceProperties", (void *)&resp->ch.vkGetPhysicalDeviceProperties}, - {"vkGetPhysicalDeviceProperties2", (void *)&resp->ch.vkGetPhysicalDeviceProperties2}, - {"vkEnumerateDeviceExtensionProperties", (void *)&resp->ch.vkEnumerateDeviceExtensionProperties}, - {"vkCreateInstance", (void *)&resp->ch.vkCreateInstance}, - {"vkEnumeratePhysicalDevices", (void *)&resp->ch.vkEnumeratePhysicalDevices}, - {"vkGetPhysicalDeviceMemoryProperties2", (void *)&resp->ch.vkGetPhysicalDeviceMemoryProperties2}, - {"vkDestroyInstance", (void *)&resp->ch.vkDestroyInstance}, - {NULL, NULL}, - }; - - resp->ch.vk_handle = LOAD_LIBRARY(vk_lib_path, RTLD_LAZY); - if (!resp->ch.vk_handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", vk_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Vulkan GPUs: %s", - vk_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - for (i = 0; l[i].s != NULL; i++) { - *l[i].p = LOAD_SYMBOL(resp->ch.vk_handle, l[i].s); - if (!*l[i].p) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.vk_handle); - resp->ch.vk_handle = NULL; - 
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - VkInstance instance; - - VkApplicationInfo appInfo = {}; - appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - appInfo.pNext = NULL; - appInfo.pApplicationName = "Ollama"; - appInfo.applicationVersion = VK_MAKE_VERSION(1, 0, 0); - appInfo.pEngineName = "No Engine"; - appInfo.engineVersion = VK_MAKE_VERSION(1, 0, 0); - appInfo.apiVersion = VK_API_VERSION_1_2; - - VkInstanceCreateInfo createInfo = {}; - createInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - createInfo.pNext = NULL; - createInfo.flags = 0; - createInfo.enabledExtensionCount = 1; - const char* extensions[] = { VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME }; - createInfo.ppEnabledExtensionNames = extensions; - createInfo.pApplicationInfo = &appInfo; - - VkResult result = (*resp->ch.vkCreateInstance)(&createInfo, NULL, &instance); - if (result != VK_SUCCESS) { - resp->err = strdup("failed to create instance"); - return; - } - - uint32_t deviceCount; - result = (*resp->ch.vkEnumeratePhysicalDevices)(instance, &deviceCount, NULL); - if (result != VK_SUCCESS) { - resp->err = strdup("failed to enumerate physical devices"); - return; - } - - resp->err = NULL; - resp->ch.vk = instance; - resp->ch.num_devices = deviceCount; - resp->num_devices = deviceCount; -} - -int vk_check_flash_attention(vk_handle_t rh, int i) { - VkInstance instance = rh.vk; - uint32_t deviceCount = rh.num_devices; - - VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); - if (devices == NULL) { - return 0; - } - - VkResult result = (*rh.vkEnumeratePhysicalDevices)(instance, &deviceCount, devices); - if (result != VK_SUCCESS) { - free(devices); - return 0; - } - - VkPhysicalDeviceProperties properties = {}; - (*rh.vkGetPhysicalDeviceProperties)(devices[i], &properties); - - int supports_nv_coopmat2 = is_extension_supported(&rh, devices[i], VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME); - if (!supports_nv_coopmat2) { - free(devices); - return 1; - } - - free(devices); - return 0; -} - -void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { - VkInstance instance = rh.vk; - uint32_t deviceCount = rh.num_devices; - - VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); - if (devices == NULL) { - resp->err = strdup("memory allocation failed for devices array"); - return; - } - - VkResult result = (*rh.vkEnumeratePhysicalDevices)(instance, &deviceCount, devices); - if (result != VK_SUCCESS) { - free(devices); - resp->err = strdup("failed to enumerate physical devices"); - return; - } - - VkPhysicalDeviceProperties properties = {}; - (*rh.vkGetPhysicalDeviceProperties)(devices[i], &properties); - - int supports_budget = is_extension_supported(&rh, devices[i], VK_EXT_MEMORY_BUDGET_EXTENSION_NAME); - if (!supports_budget) { - free(devices); - resp->err = strdup("device does not support memory budget"); - return; - } - - if (properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_CPU) { - free(devices); - resp->err = strdup("device is a CPU"); - return; - } - - VkPhysicalDeviceProperties2 device_props2 = {}; - device_props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; - - VkPhysicalDeviceIDProperties id_props = {}; - id_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES; - - device_props2.pNext = &id_props; - (*rh.vkGetPhysicalDeviceProperties2)(devices[i], &device_props2); - - VkPhysicalDeviceMemoryBudgetPropertiesEXT physical_device_memory_budget_properties 
= {}; - physical_device_memory_budget_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT; - physical_device_memory_budget_properties.pNext = NULL; - - VkPhysicalDeviceMemoryProperties2 device_memory_properties = {}; - device_memory_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2; - device_memory_properties.pNext = &physical_device_memory_budget_properties; - - (*rh.vkGetPhysicalDeviceMemoryProperties2)(devices[i], &device_memory_properties); - - VkDeviceSize device_memory_total_size = 0; - VkDeviceSize device_memory_heap_budget = 0; - - for (uint32_t j = 0; j < device_memory_properties.memoryProperties.memoryHeapCount; j++) { - VkMemoryHeap heap = device_memory_properties.memoryProperties.memoryHeaps[j]; - if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) { - device_memory_total_size += heap.size; - device_memory_heap_budget += physical_device_memory_budget_properties.heapBudget[j]; - } - } - - free(devices); - - resp->err = NULL; - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; - strncpy(&resp->gpu_name[0], properties.deviceName, GPU_NAME_LEN - 1); - resp->gpu_name[GPU_NAME_LEN - 1] = '\0'; - const uint8_t *uuid = id_props.deviceUUID; - snprintf(&resp->gpu_id[0], GPU_ID_LEN, - "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", - uuid[0], uuid[1], uuid[2], uuid[3], - uuid[4], uuid[5], - uuid[6], uuid[7], - uuid[8], uuid[9], - uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15] - ); - resp->total = (uint64_t) device_memory_total_size; - resp->free = (uint64_t) device_memory_heap_budget; - resp->major = VK_API_VERSION_MAJOR(properties.apiVersion); - resp->minor = VK_API_VERSION_MINOR(properties.apiVersion); - resp->patch = VK_API_VERSION_PATCH(properties.apiVersion); -} - -void vk_release(vk_handle_t rh) { - LOG(rh.verbose, "releasing vulkan library\n"); - (*rh.vkDestroyInstance)(rh.vk, NULL); - UNLOAD_LIBRARY(rh.vk_handle); - rh.vk_handle = NULL; -} - -#endif // __APPLE__ diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h deleted file mode 100644 index 42e4b1610..000000000 --- a/discover/gpu_info_vulkan.h +++ /dev/null @@ -1,394 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_VULKAN_H__ -#define __GPU_INFO_VULKAN_H__ - -#include "gpu_info.h" - -#define VK_DEFINE_HANDLE(object) typedef struct object##_T* object; -VK_DEFINE_HANDLE(VkInstance) -VK_DEFINE_HANDLE(VkPhysicalDevice) - -#define VK_MAX_EXTENSION_NAME_SIZE 256U -#define VK_MAX_DESCRIPTION_SIZE 256U -#define VK_LUID_SIZE 8U -#define VK_UUID_SIZE 16U -#define VK_MAX_MEMORY_TYPES 32U -#define VK_MAX_MEMORY_HEAPS 16U -#define VK_MAX_PHYSICAL_DEVICE_NAME_SIZE 256U - -#define VK_MAKE_VERSION(major, minor, patch) \ - ((((uint32_t)(major)) << 22U) | (((uint32_t)(minor)) << 12U) | ((uint32_t)(patch))) - -#define VK_MAKE_API_VERSION(variant, major, minor, patch) \ - ((((uint32_t)(variant)) << 29U) | (((uint32_t)(major)) << 22U) | (((uint32_t)(minor)) << 12U) | ((uint32_t)(patch))) - -#define VK_API_VERSION_1_0 VK_MAKE_API_VERSION(0, 1, 0, 0)// Patch version should always be set to 0 -#define VK_API_VERSION_1_1 VK_MAKE_API_VERSION(0, 1, 1, 0)// Patch version should always be set to 0 -#define VK_API_VERSION_1_2 VK_MAKE_API_VERSION(0, 1, 2, 0)// Patch version should always be set to 0 -#define VK_API_VERSION_1_3 VK_MAKE_API_VERSION(0, 1, 3, 0)// Patch version should always be set to 0 -#define VK_API_VERSION_MAJOR(version) (((uint32_t)(version) >> 22U) & 0x7FU) -#define 
VK_API_VERSION_MINOR(version) (((uint32_t)(version) >> 12U) & 0x3FFU) -#define VK_API_VERSION_PATCH(version) ((uint32_t)(version) & 0xFFFU) - -#define VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME "VK_KHR_get_physical_device_properties2" -#define VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME "VK_NV_cooperative_matrix2" -#define VK_EXT_MEMORY_BUDGET_EXTENSION_NAME "VK_EXT_memory_budget" - -typedef uint32_t VkFlags; -typedef uint32_t VkBool32; -typedef uint64_t VkDeviceSize; -typedef uint32_t VkSampleMask; -typedef VkFlags VkSampleCountFlags; -typedef VkFlags VkMemoryPropertyFlags; -typedef VkFlags VkMemoryHeapFlags; -typedef VkFlags VkInstanceCreateFlags; - -typedef enum VkResult { - VK_SUCCESS = 0, - VK_NOT_READY = 1, - VK_TIMEOUT = 2, - VK_EVENT_SET = 3, - VK_EVENT_RESET = 4, - VK_INCOMPLETE = 5, - VK_ERROR_OUT_OF_HOST_MEMORY = -1, - VK_ERROR_OUT_OF_DEVICE_MEMORY = -2, - VK_ERROR_INITIALIZATION_FAILED = -3, - VK_ERROR_DEVICE_LOST = -4, - VK_ERROR_MEMORY_MAP_FAILED = -5, - VK_ERROR_LAYER_NOT_PRESENT = -6, - VK_ERROR_EXTENSION_NOT_PRESENT = -7, - VK_ERROR_FEATURE_NOT_PRESENT = -8, - VK_ERROR_INCOMPATIBLE_DRIVER = -9, - VK_ERROR_TOO_MANY_OBJECTS = -10, - VK_ERROR_FORMAT_NOT_SUPPORTED = -11, - VK_ERROR_FRAGMENTED_POOL = -12, - VK_ERROR_UNKNOWN = -13, - VK_ERROR_OUT_OF_POOL_MEMORY = -1000069000, - VK_ERROR_INVALID_EXTERNAL_HANDLE = -1000072003, - VK_ERROR_FRAGMENTATION = -1000168000, - VK_ERROR_INVALID_OPAQUE_CAPTURE_ADDRESS = -1000257000, - VK_PIPELINE_COMPILE_REQUIRED = 1000297000, - VK_ERROR_SURFACE_LOST_KHR = -1000000000, - VK_ERROR_NATIVE_WINDOW_IN_USE_KHR = -1000000001, - VK_SUBOPTIMAL_KHR = 1000001003, - VK_ERROR_OUT_OF_DATE_KHR = -1000001004, - VK_ERROR_INCOMPATIBLE_DISPLAY_KHR = -1000003001, - VK_ERROR_VALIDATION_FAILED_EXT = -1000011001, - VK_ERROR_INVALID_SHADER_NV = -1000012000, - VK_ERROR_IMAGE_USAGE_NOT_SUPPORTED_KHR = -1000158000, - VK_ERROR_VIDEO_PICTURE_LAYOUT_NOT_SUPPORTED_KHR = -1000158001, - VK_ERROR_VIDEO_PROFILE_OPERATION_NOT_SUPPORTED_KHR = -1000158002, - VK_ERROR_VIDEO_PROFILE_FORMAT_NOT_SUPPORTED_KHR = -1000158003, - VK_ERROR_VIDEO_PROFILE_CODEC_NOT_SUPPORTED_KHR = -1000158004, - VK_ERROR_VIDEO_STD_VERSION_NOT_SUPPORTED_KHR = -1000158005, - VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT = -1000158006, - VK_ERROR_NOT_PERMITTED_KHR = -1000174001, - VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT = -1000255000, - VK_THREAD_IDLE_KHR = 1000268000, - VK_THREAD_DONE_KHR = 1000268001, - VK_OPERATION_DEFERRED_KHR = 1000268002, - VK_OPERATION_NOT_DEFERRED_KHR = 1000268003, - VK_ERROR_COMPRESSION_EXHAUSTED_EXT = -1000338000, - VK_RESULT_MAX_ENUM = 0x7FFFFFFF -} VkResult; - -typedef enum VkStructureType { - VK_STRUCTURE_TYPE_APPLICATION_INFO = 0, - VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO = 1, - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2 = 1000059001, - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2 = 1000059006, - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES = 1000071004, - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT = 1000237000, - VK_STRUCTURE_TYPE_MAX_ENUM = 0x7FFFFFFF -} VkStructureType; - -typedef enum VkPhysicalDeviceType { - VK_PHYSICAL_DEVICE_TYPE_OTHER = 0, - VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU = 1, - VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU = 2, - VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU = 3, - VK_PHYSICAL_DEVICE_TYPE_CPU = 4, - VK_PHYSICAL_DEVICE_TYPE_MAX_ENUM = 0x7FFFFFFF -} VkPhysicalDeviceType; - -typedef enum VkSystemAllocationScope { - VK_SYSTEM_ALLOCATION_SCOPE_COMMAND = 0, - VK_SYSTEM_ALLOCATION_SCOPE_OBJECT = 1, - 
VK_SYSTEM_ALLOCATION_SCOPE_CACHE = 2, - VK_SYSTEM_ALLOCATION_SCOPE_DEVICE = 3, - VK_SYSTEM_ALLOCATION_SCOPE_INSTANCE = 4, - VK_SYSTEM_ALLOCATION_SCOPE_MAX_ENUM = 0x7FFFFFFF -} VkSystemAllocationScope; - -typedef enum VkInternalAllocationType { - VK_INTERNAL_ALLOCATION_TYPE_EXECUTABLE = 0, - VK_INTERNAL_ALLOCATION_TYPE_NON_EXECUTABLE = 1, - VK_INTERNAL_ALLOCATION_TYPE_MAX_ENUM = 0x7FFFFFFF -} VkInternalAllocationType; - -typedef enum VkMemoryHeapFlagBits { - VK_MEMORY_HEAP_DEVICE_LOCAL_BIT = 0x00000001, - VK_MEMORY_HEAP_MULTI_INSTANCE_BIT = 0x00000002, - VK_MEMORY_HEAP_TILE_MEMORY_BIT_QCOM = 0x00000008, - VK_MEMORY_HEAP_MULTI_INSTANCE_BIT_KHR = VK_MEMORY_HEAP_MULTI_INSTANCE_BIT, - VK_MEMORY_HEAP_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF -} VkMemoryHeapFlagBits; - -typedef struct VkExtensionProperties { - char extensionName[VK_MAX_EXTENSION_NAME_SIZE]; - uint32_t specVersion; -} VkExtensionProperties; - -typedef struct VkPhysicalDeviceLimits { - uint32_t maxImageDimension1D; - uint32_t maxImageDimension2D; - uint32_t maxImageDimension3D; - uint32_t maxImageDimensionCube; - uint32_t maxImageArrayLayers; - uint32_t maxTexelBufferElements; - uint32_t maxUniformBufferRange; - uint32_t maxStorageBufferRange; - uint32_t maxPushConstantsSize; - uint32_t maxMemoryAllocationCount; - uint32_t maxSamplerAllocationCount; - VkDeviceSize bufferImageGranularity; - VkDeviceSize sparseAddressSpaceSize; - uint32_t maxBoundDescriptorSets; - uint32_t maxPerStageDescriptorSamplers; - uint32_t maxPerStageDescriptorUniformBuffers; - uint32_t maxPerStageDescriptorStorageBuffers; - uint32_t maxPerStageDescriptorSampledImages; - uint32_t maxPerStageDescriptorStorageImages; - uint32_t maxPerStageDescriptorInputAttachments; - uint32_t maxPerStageResources; - uint32_t maxDescriptorSetSamplers; - uint32_t maxDescriptorSetUniformBuffers; - uint32_t maxDescriptorSetUniformBuffersDynamic; - uint32_t maxDescriptorSetStorageBuffers; - uint32_t maxDescriptorSetStorageBuffersDynamic; - uint32_t maxDescriptorSetSampledImages; - uint32_t maxDescriptorSetStorageImages; - uint32_t maxDescriptorSetInputAttachments; - uint32_t maxVertexInputAttributes; - uint32_t maxVertexInputBindings; - uint32_t maxVertexInputAttributeOffset; - uint32_t maxVertexInputBindingStride; - uint32_t maxVertexOutputComponents; - uint32_t maxTessellationGenerationLevel; - uint32_t maxTessellationPatchSize; - uint32_t maxTessellationControlPerVertexInputComponents; - uint32_t maxTessellationControlPerVertexOutputComponents; - uint32_t maxTessellationControlPerPatchOutputComponents; - uint32_t maxTessellationControlTotalOutputComponents; - uint32_t maxTessellationEvaluationInputComponents; - uint32_t maxTessellationEvaluationOutputComponents; - uint32_t maxGeometryShaderInvocations; - uint32_t maxGeometryInputComponents; - uint32_t maxGeometryOutputComponents; - uint32_t maxGeometryOutputVertices; - uint32_t maxGeometryTotalOutputComponents; - uint32_t maxFragmentInputComponents; - uint32_t maxFragmentOutputAttachments; - uint32_t maxFragmentDualSrcAttachments; - uint32_t maxFragmentCombinedOutputResources; - uint32_t maxComputeSharedMemorySize; - uint32_t maxComputeWorkGroupCount[3]; - uint32_t maxComputeWorkGroupInvocations; - uint32_t maxComputeWorkGroupSize[3]; - uint32_t subPixelPrecisionBits; - uint32_t subTexelPrecisionBits; - uint32_t mipmapPrecisionBits; - uint32_t maxDrawIndexedIndexValue; - uint32_t maxDrawIndirectCount; - float maxSamplerLodBias; - float maxSamplerAnisotropy; - uint32_t maxViewports; - uint32_t maxViewportDimensions[2]; - float 
viewportBoundsRange[2]; - uint32_t viewportSubPixelBits; - size_t minMemoryMapAlignment; - VkDeviceSize minTexelBufferOffsetAlignment; - VkDeviceSize minUniformBufferOffsetAlignment; - VkDeviceSize minStorageBufferOffsetAlignment; - int32_t minTexelOffset; - uint32_t maxTexelOffset; - int32_t minTexelGatherOffset; - uint32_t maxTexelGatherOffset; - float minInterpolationOffset; - float maxInterpolationOffset; - uint32_t subPixelInterpolationOffsetBits; - uint32_t maxFramebufferWidth; - uint32_t maxFramebufferHeight; - uint32_t maxFramebufferLayers; - VkSampleCountFlags framebufferColorSampleCounts; - VkSampleCountFlags framebufferDepthSampleCounts; - VkSampleCountFlags framebufferStencilSampleCounts; - VkSampleCountFlags framebufferNoAttachmentsSampleCounts; - uint32_t maxColorAttachments; - VkSampleCountFlags sampledImageColorSampleCounts; - VkSampleCountFlags sampledImageIntegerSampleCounts; - VkSampleCountFlags sampledImageDepthSampleCounts; - VkSampleCountFlags sampledImageStencilSampleCounts; - VkSampleCountFlags storageImageSampleCounts; - uint32_t maxSampleMaskWords; - VkBool32 timestampComputeAndGraphics; - float timestampPeriod; - uint32_t maxClipDistances; - uint32_t maxCullDistances; - uint32_t maxCombinedClipAndCullDistances; - uint32_t discreteQueuePriorities; - float pointSizeRange[2]; - float lineWidthRange[2]; - float pointSizeGranularity; - float lineWidthGranularity; - VkBool32 strictLines; - VkBool32 standardSampleLocations; - VkDeviceSize optimalBufferCopyOffsetAlignment; - VkDeviceSize optimalBufferCopyRowPitchAlignment; - VkDeviceSize nonCoherentAtomSize; -} VkPhysicalDeviceLimits; - -typedef struct VkPhysicalDeviceSparseProperties { - VkBool32 residencyStandard2DBlockShape; - VkBool32 residencyStandard2DMultisampleBlockShape; - VkBool32 residencyStandard3DBlockShape; - VkBool32 residencyAlignedMipSize; - VkBool32 residencyNonResidentStrict; -} VkPhysicalDeviceSparseProperties; - -typedef struct VkPhysicalDeviceProperties { - uint32_t apiVersion; - uint32_t driverVersion; - uint32_t vendorID; - uint32_t deviceID; - uint32_t deviceType; - char deviceName[VK_MAX_PHYSICAL_DEVICE_NAME_SIZE]; - uint8_t pipelineCacheUUID[VK_UUID_SIZE]; - VkPhysicalDeviceLimits limits; - VkPhysicalDeviceSparseProperties sparseProperties; -} VkPhysicalDeviceProperties; - -typedef struct VkPhysicalDeviceProperties2 { - VkStructureType sType; - void* pNext; - VkPhysicalDeviceProperties properties; -} VkPhysicalDeviceProperties2; - -typedef struct VkPhysicalDeviceIDProperties { - VkStructureType sType; - void* pNext; - uint8_t deviceUUID[VK_UUID_SIZE]; - uint8_t driverUUID[VK_UUID_SIZE]; - uint8_t deviceLUID[VK_LUID_SIZE]; - uint32_t deviceNodeMask; - VkBool32 deviceLUIDValid; -} VkPhysicalDeviceIDProperties; - -typedef struct VkMemoryType { - VkMemoryPropertyFlags propertyFlags; - uint32_t heapIndex; -} VkMemoryType; - -typedef struct VkMemoryHeap { - VkDeviceSize size; - VkMemoryHeapFlags flags; -} VkMemoryHeap; - -typedef struct VkPhysicalDeviceMemoryProperties { - uint32_t memoryTypeCount; - VkMemoryType memoryTypes[VK_MAX_MEMORY_TYPES]; - uint32_t memoryHeapCount; - VkMemoryHeap memoryHeaps[VK_MAX_MEMORY_HEAPS]; -} VkPhysicalDeviceMemoryProperties; - -typedef struct VkPhysicalDeviceMemoryProperties2 { - VkStructureType sType; - void* pNext; - VkPhysicalDeviceMemoryProperties memoryProperties; -} VkPhysicalDeviceMemoryProperties2; - -typedef struct VkPhysicalDeviceMemoryBudgetPropertiesEXT { - VkStructureType sType; - void* pNext; - VkDeviceSize heapBudget[VK_MAX_MEMORY_HEAPS]; - 
VkDeviceSize heapUsage[VK_MAX_MEMORY_HEAPS]; -} VkPhysicalDeviceMemoryBudgetPropertiesEXT; - -typedef struct VkApplicationInfo { - VkStructureType sType; - const void* pNext; - const char* pApplicationName; - uint32_t applicationVersion; - const char* pEngineName; - uint32_t engineVersion; - uint32_t apiVersion; -} VkApplicationInfo; - -typedef struct VkInstanceCreateInfo { - VkStructureType sType; - const void* pNext; - VkInstanceCreateFlags flags; - const VkApplicationInfo* pApplicationInfo; - uint32_t enabledLayerCount; - const char* const* ppEnabledLayerNames; - uint32_t enabledExtensionCount; - const char* const* ppEnabledExtensionNames; -} VkInstanceCreateInfo; - -typedef struct VkAllocationCallbacks { - void* pUserData; - void* (*pfnAllocation)(void* pUserData, size_t size, size_t alignment, VkSystemAllocationScope allocationScope); - void* (*pfnReallocation)(void* pUserData, void* pOriginal, size_t size, size_t alignment, VkSystemAllocationScope allocationScope); - void (*pfnFree)(void* pUserData, void* pMemory); - void (*pfnInternalAllocation)(void* pUserData, size_t size, VkInternalAllocationType allocationType, VkSystemAllocationScope allocationScope); - void (*pfnInternalFree)(void* pUserData, size_t size, VkInternalAllocationType allocationType, VkSystemAllocationScope allocationScope); -} VkAllocationCallbacks; - -typedef struct { - void* vk_handle; - uint16_t verbose; - - VkInstance vk; - int num_devices; - - void (*vkGetPhysicalDeviceProperties)( - VkPhysicalDevice physicalDevice, - VkPhysicalDeviceProperties* pProperties); - void (*vkGetPhysicalDeviceProperties2)( - VkPhysicalDevice physicalDevice, - VkPhysicalDeviceProperties2* pProperties); - VkResult (*vkEnumerateDeviceExtensionProperties)( - VkPhysicalDevice physicalDevice, - const char* pLayerName, - uint32_t* pPropertyCount, - VkExtensionProperties* pProperties); - VkResult (*vkCreateInstance)( - const VkInstanceCreateInfo* pCreateInfo, - const VkAllocationCallbacks* pAllocator, - VkInstance* pInstance); - VkResult (*vkEnumeratePhysicalDevices)( - VkInstance instance, - uint32_t* pPhysicalDeviceCount, - VkPhysicalDevice* pPhysicalDevices); - void (*vkGetPhysicalDeviceMemoryProperties2)( - VkPhysicalDevice physicalDevice, - VkPhysicalDeviceMemoryProperties2* pMemoryProperties); - void (*vkDestroyInstance)( - VkInstance instance, - const VkAllocationCallbacks* pAllocator); -} vk_handle_t; - -typedef struct vk_init_resp -{ - char *err; // If err is non-null handle is invalid - int num_devices; - vk_handle_t ch; -} vk_init_resp_t; - -void vk_init(char* vk_lib_path, vk_init_resp_t *resp); -void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp); -int vk_check_flash_attention(vk_handle_t rh, int i); -void vk_release(vk_handle_t rh); - -#endif -#endif \ No newline at end of file diff --git a/discover/runner.go b/discover/runner.go index 5e4e05f95..c9bbef7f5 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -92,6 +92,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev // are enumerated, but not actually supported. 
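One detail that carries across this deletion: the removed C helper reported Vulkan devices with IDs of the form GPU-<device UUID>, and the replacement discovery path in ggml-vulkan.cpp later in this patch emits the same shape (hex groups of 4-2-2-2-6 bytes). A small Go illustration of that formatting, assuming a raw 16-byte Vulkan deviceUUID; the function name is illustrative only.

package main

import "fmt"

// formatVulkanGPUID renders a 16-byte Vulkan deviceUUID the way both the old
// C helper and the new ggml-vulkan code do: "GPU-" followed by the UUID in
// 4-2-2-2-6 byte groups, lower-case hex.
func formatVulkanGPUID(uuid [16]byte) string {
	return fmt.Sprintf("GPU-%x-%x-%x-%x-%x",
		uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16])
}

func main() {
	var uuid [16]byte
	copy(uuid[:], []byte{0xde, 0xad, 0xbe, 0xef, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b})
	fmt.Println(formatVulkanGPUID(uuid)) // GPU-deadbeef-0001-0203-0405-060708090a0b
}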
// We run this in serial to avoid potentially initializing a GPU multiple // times concurrently leading to memory contention + // TODO refactor so we group the lib dirs and do serial per version, but parallel for different libs for dir := range libDirs { var dirs []string if dir == "" { @@ -125,8 +126,10 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev } else { envVar = "ROCR_VISIBLE_DEVICES" } - } else { + } else if devices[i].Library == "CUDA" { envVar = "CUDA_VISIBLE_DEVICES" + } else if devices[i].Library == "VULKAN" { + envVar = "GGML_VK_VISIBLE_DEVICES" } extraEnvs := []string{ diff --git a/discover/types.go b/discover/types.go index feb8c08e0..5a9ce1865 100644 --- a/discover/types.go +++ b/discover/types.go @@ -36,10 +36,11 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? UnreliableFreeMemory bool // GPU information - ID string `json:"gpu_id"` // string to use for selection of this specific GPU - filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices - Name string `json:"name"` // user friendly name if available - Compute string `json:"compute"` // Compute Capability or gfx + ID string `json:"gpu_id"` // string to use for selection of this specific GPU + filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices + Name string `json:"name"` // user friendly name if available + Compute string `json:"compute"` // Compute Capability or gfx + FlashAttention bool `json:"flash_attention"` // is flash attention supported // Driver Information - TODO no need to put this on each GPU DriverMajor int `json:"driver_major,omitempty"` @@ -174,7 +175,8 @@ func (l GpuInfoList) FlashAttentionSupported() bool { supportsFA := gpu.Library == "cpu" || gpu.Name == "Metal" || (gpu.Library == "CUDA" && gpu.DriverMajor >= 7) || - gpu.Library == "HIP" + gpu.Library == "HIP" || + gpu.Library == "VULKAN" if !supportsFA { return false diff --git a/llm/server.go b/llm/server.go index d3438e6c2..8757d4d56 100644 --- a/llm/server.go +++ b/llm/server.go @@ -561,10 +561,11 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi // Windows CUDA should not use mmap for best performance // Linux with a model larger than free space, mmap leads to thrashing // For CPU loads we want the memory to be allocated, not FS cache - if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && s.options.UseMMap == nil) || + if (runtime.GOOS == "windows" && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) || (runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) || (gpus[0].Library == "vulkan" && s.options.UseMMap == nil) || (gpus[0].Library == "cpu" && s.options.UseMMap == nil) || + (gpus[0].Library == "VULKAN" && s.options.UseMMap == nil) || (s.options.UseMMap != nil && !*s.options.UseMMap) { s.loadRequest.UseMmap = false } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d73cdf176..3b0a0891e 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4123,7 +4123,6 @@ static void ggml_vk_instance_init() { } } else { std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); - // If no vulkan devices are found, return early if (devices.empty()) { GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); @@ -10821,14 +10820,90 @@ std::string 
ggml_backend_vk_get_device_id(int device) { return ggml_vk_get_device_id(dev_idx); } -void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { - GGML_ASSERT(device < (int) vk_instance.device_indices.size()); +////////////////////////// - vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; +struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; + std::string id; + std::string uuid; + int major; + int minor; + int driver_major; + int driver_minor; + int integrated; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; +}; + +void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { + GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + vk::PhysicalDeviceProperties2 props2; + vkdev.getProperties2(&props2); - for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { + // Use vendor specific management libraries for best VRAM reporting if available + switch (props2.properties.vendorID) { + case VK_VENDOR_ID_AMD: + if (ggml_hip_mgmt_init() == 0) { + int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_hip_mgmt_release(); + return; + } + ggml_hip_mgmt_release(); + } + break; + case VK_VENDOR_ID_NVIDIA: + if (ggml_nvml_init() == 0) { + int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_nvml_release(); + return; + } + ggml_nvml_release(); + } + break; + } + // else fallback to memory budget if supported + + *total = 0; + *free = 0; + vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; + vk::PhysicalDeviceMemoryProperties2 memprops2; + memprops2.pNext = &mem_budget_props; + vkdev.getMemoryProperties2(&memprops2); + for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { + if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total += memprops2.memoryProperties.memoryHeaps[i].size; + } else if (ctx->integrated) { + // Include shared memory on iGPUs + *total += memprops2.memoryProperties.memoryHeaps[i].size; + } + } + for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { + if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *free += mem_budget_props.heapBudget[i]; + } else if (ctx->integrated) { + *free += mem_budget_props.heapBudget[i]; + } + } + if (*total > 0 && *free > 0) { + return; + } else if (*total > 0) { + *free = *total; + return; + } + + // else just report the physical memory + for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { *total = heap.size; *free = heap.size; @@ -10837,14 +10912,6 @@ void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total } } -////////////////////////// - -struct ggml_backend_vk_device_context { - size_t device; - std::string name; - std::string description; - std::string id; -}; static const char * 
ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; @@ -10863,7 +10930,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; - ggml_backend_vk_get_device_memory(ctx->device, free, total); + ggml_backend_vk_get_device_memory(ctx, free, total); } static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { @@ -10881,6 +10948,7 @@ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_d return GGML_BACKEND_DEVICE_TYPE_GPU; } +#define GGML_VULKAN_NAME "VULKAN" static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); @@ -10893,6 +10961,18 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml /* .buffer_from_host_ptr = */ false, /* .events = */ false, }; + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + props->id = ctx->id.c_str(); + props->compute_major = ctx->major; + props->compute_minor = ctx->minor; + props->driver_major = ctx->driver_major; + props->driver_minor = ctx->driver_minor; + props->integrated = ctx->integrated; + props->pci_bus_id = ctx->pci_bus_id; + props->pci_device_id = ctx->pci_device_id; + props->pci_domain_id = ctx->pci_domain_id; + props->library = GGML_VULKAN_NAME; } static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { @@ -11296,6 +11376,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { + std::vector vk_devices = vk_instance.instance.enumeratePhysicalDevices(); + for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; char desc[256]; @@ -11309,6 +11391,44 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, /* .reg = */ reg, /* .context = */ ctx, }); + + // Gather additional information about the device + int dev_idx = vk_instance.device_indices[i]; + vk::PhysicalDeviceProperties props1; + vk_devices[dev_idx].getProperties(&props1); + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceIDProperties device_id_props; + vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props; + vk::PhysicalDeviceDriverProperties driver_props; + props2.pNext = &device_id_props; + device_id_props.pNext = &pci_bus_props; + pci_bus_props.pNext = &driver_props; + vk_devices[dev_idx].getProperties2(&props2); + std::ostringstream oss; + oss << std::hex << std::setfill('0'); + oss << "GPU-"; + int byteIdx = 0; + for (int i = 0; i < 16; ++i, ++byteIdx) { + oss << std::setw(2) << static_cast(device_id_props.deviceUUID[i]); + if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) { + oss << '-'; + } + } + ctx->uuid = oss.str(); + ctx->pci_bus_id = pci_bus_props.pciBus; + ctx->pci_device_id = pci_bus_props.pciDevice; + ctx->pci_domain_id = pci_bus_props.pciDomain; + ctx->id = std::to_string(i); + if (props1.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) { + ctx->integrated = 1; + } else { + ctx->integrated 
= 0; + } + ctx->major = 0; + ctx->minor = 0; + // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string + ctx->driver_major = 0; + ctx->driver_minor = 0; } initialized = true; } diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 3ca25a13c..0a3d7c888 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -165,12 +165,11 @@ function buildROCm() { $env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe" $env:HIP_PLATFORM="amd" $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" - & cmake --fresh --preset "ROCm 6" -G Ninja ` + & cmake --fresh --preset "ROCm 6" -G Ninja --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="rocm" ` -DCMAKE_C_COMPILER=clang ` -DCMAKE_CXX_COMPILER=clang++ ` -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` - -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` - --install-prefix $script:DIST_DIR + -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} $env:HIPCXX="" $env:HIP_PLATFORM="" @@ -186,7 +185,7 @@ function buildROCm() { function buildVulkan(){ if ($env:VULKAN_SDK) { write-host "Building Vulkan backend libraries" - & cmake --fresh --preset Vulkan --install-prefix $script:DIST_DIR + & cmake --fresh --preset Vulkan --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="vulkan" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --build --preset Vulkan --config Release --parallel $script:JOBS if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} From c57cd59be7e4946ac3fbeece08794bb450f63b3f Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 23 Sep 2025 11:49:56 -0700 Subject: [PATCH 118/172] fix - trust the library paths from discovery when starting runner --- llm/server.go | 316 ++++++++++++++++++++++---------------------------- 1 file changed, 137 insertions(+), 179 deletions(-) diff --git a/llm/server.go b/llm/server.go index 8757d4d56..08861165b 100644 --- a/llm/server.go +++ b/llm/server.go @@ -238,35 +238,29 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a } var gpuLibs []string + newLibs := func(libDirs []string) (ret []string) { + for _, d := range libDirs { + found := false + for _, t := range gpuLibs { + if d == t { + found = true + break + } + } + if !found { + ret = append(ret, d) + } + } + return + } for _, gpu := range gpus { - gpuLibs = append(gpuLibs, gpu.RunnerName()) + gpuLibs = append(gpuLibs, newLibs(gpu.DependencyPath)...) } requested := envconfig.LLMLibrary() if availableLibs[requested] != "" { slog.Info("using requested gpu library", "requested", requested) - gpuLibs = []string{requested} - } - - var compatible []string - for _, gpuLib := range gpuLibs { - var matchingLibs []string - for k := range availableLibs { - // exact match first - if k == gpuLib { - matchingLibs = append([]string{k}, matchingLibs...) - continue - } - - // then match the family (e.g. 'cuda') - if strings.Split(k, "_")[0] == strings.Split(gpuLib, "_")[0] { - matchingLibs = append(matchingLibs, k) - } - } - - if len(matchingLibs) > 0 { - compatible = append(compatible, matchingLibs[0]) - } + gpuLibs = []string{availableLibs[requested]} } exe, err := os.Executable() @@ -278,164 +272,128 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a exe = eval } - // iterate through compatible GPU libraries such as 'cuda_v12', 'rocm', etc. 
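Regarding the "TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string" left in the Vulkan device setup above: driverInfo is a free-form string, so pulling out the first dotted version token is enough to populate driver_major/driver_minor. A hedged Go sketch of that parse follows (the C++ TODO would do the equivalent, e.g. with std::regex); the function name and sample strings are illustrative only.

package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// driverVersionRE matches the first X.Y or X.Y.Z sequence in a free-form
// driver info string, e.g. "NVIDIA 560.35.03" or "Mesa 24.1.2".
var driverVersionRE = regexp.MustCompile(`(\d+)\.(\d+)(?:\.(\d+))?`)

// parseDriverVersion extracts major/minor from the driver info string,
// returning ok=false when no version-looking token is present.
func parseDriverVersion(info string) (major, minor int, ok bool) {
	m := driverVersionRE.FindStringSubmatch(info)
	if m == nil {
		return 0, 0, false
	}
	major, _ = strconv.Atoi(m[1])
	minor, _ = strconv.Atoi(m[2])
	return major, minor, true
}

func main() {
	major, minor, ok := parseDriverVersion("Mesa 24.1.2")
	fmt.Println(major, minor, ok) // 24 1 true
}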
- // adding each library's respective path to the LD_LIBRARY_PATH, until finally running - // without any LD_LIBRARY_PATH flags - for { - port := 0 - if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { - var l *net.TCPListener - if l, err = net.ListenTCP("tcp", a); err == nil { - port = l.Addr().(*net.TCPAddr).Port - l.Close() - } - } - if port == 0 { - slog.Debug("ResolveTCPAddr failed, using random port") - port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range - } - params := []string{"runner"} - if textProcessor != nil { - // New engine - // TODO - if we have failure to load scenarios, add logic to retry with the old runner - params = append(params, "--ollama-engine") - } - params = append(params, "--model", modelPath) - params = append(params, "--port", strconv.Itoa(port)) - - var pathEnv string - switch runtime.GOOS { - case "windows": - pathEnv = "PATH" - case "darwin": - pathEnv = "DYLD_LIBRARY_PATH" - default: - pathEnv = "LD_LIBRARY_PATH" - } - - // Note: we always put our dependency paths first - // since these are the exact version we compiled/linked against - libraryPaths := []string{discover.LibOllamaPath} - if libraryPath, ok := os.LookupEnv(pathEnv); ok { - libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...) - } - - ggmlPaths := []string{discover.LibOllamaPath} - for _, c := range compatible { - if libpath, ok := availableLibs[c]; ok { - slog.Debug("adding gpu library", "path", libpath) - libraryPaths = append([]string{libpath}, libraryPaths...) - ggmlPaths = append(ggmlPaths, libpath) - } - } - - for _, gpu := range gpus { - if gpu.DependencyPath != nil { - slog.Debug("adding gpu dependency paths", "paths", gpu.DependencyPath) - libraryPaths = append(gpu.DependencyPath, libraryPaths...) 
- } - } - - // finally, add the root library path - libraryPaths = append(libraryPaths, discover.LibOllamaPath) - - s := llmServer{ - port: port, - cmd: exec.Command(exe, params...), - status: NewStatusWriter(os.Stderr), - options: opts, - modelPath: modelPath, - loadRequest: loadRequest, - llamaModel: llamaModel, - llamaModelLock: &sync.Mutex{}, - textProcessor: textProcessor, - numParallel: numParallel, - sem: semaphore.NewWeighted(int64(numParallel)), - totalLayers: f.KV().BlockCount() + 1, - loadStart: time.Now(), - done: make(chan error, 1), - } - - s.cmd.Env = os.Environ() - s.cmd.Stdout = os.Stdout - s.cmd.Stderr = s.status - s.cmd.SysProcAttr = LlamaServerSysProcAttr - - s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator))) - - // Always filter down the set of GPUs in case there are any unsupported devices that might crash - envWorkarounds := gpus.GetVisibleDevicesEnv() - pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) - - // Update or add the path variable with our adjusted version - pathNeeded := true - envWorkaroundDone := make([]bool, len(envWorkarounds)) - for i := range s.cmd.Env { - cmp := strings.SplitN(s.cmd.Env[i], "=", 2) - if strings.EqualFold(cmp[0], pathEnv) { - s.cmd.Env[i] = pathEnv + "=" + pathEnvVal - pathNeeded = false - } else if len(envWorkarounds) != 0 { - for j, kv := range envWorkarounds { - tmp := strings.SplitN(kv, "=", 2) - if strings.EqualFold(cmp[0], tmp[0]) { - s.cmd.Env[i] = kv - envWorkaroundDone[j] = true - } - } - } - } - if pathNeeded { - s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal) - } - for i, done := range envWorkaroundDone { - if !done { - s.cmd.Env = append(s.cmd.Env, envWorkarounds[i]) - } - } - - slog.Info("starting runner", "cmd", s.cmd) - slog.Debug("subprocess", "", filteredEnv(s.cmd.Env)) - - if err = s.cmd.Start(); err != nil { - var msg string - if s.status != nil && s.status.LastErrMsg != "" { - msg = s.status.LastErrMsg - } - err := fmt.Errorf("error starting runner: %v %s", err, msg) - if len(compatible) == 0 { - if llamaModel != nil { - llama.FreeModel(llamaModel) - } - return nil, err - } - - slog.Warn("unable to start runner with compatible gpu", "error", err, "compatible", compatible) - compatible = compatible[1:] - continue - } - - // reap subprocess when it exits - go func() { - err := s.cmd.Wait() - // Favor a more detailed message over the process exit status - if err != nil && s.status != nil && s.status.LastErrMsg != "" { - slog.Error("llama runner terminated", "error", err) - if strings.Contains(s.status.LastErrMsg, "unknown model") { - s.status.LastErrMsg = "this model is not supported by your version of Ollama. 
You may need to upgrade" - } - s.done <- errors.New(s.status.LastErrMsg) - } else { - s.done <- err - } - }() - - if textProcessor != nil { - return &ollamaServer{llmServer: s}, nil - } else { - return &llamaServer{llmServer: s, ggml: f}, nil + port := 0 + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + port = l.Addr().(*net.TCPAddr).Port + l.Close() } } + if port == 0 { + slog.Debug("ResolveTCPAddr failed, using random port") + port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range + } + params := []string{"runner"} + if textProcessor != nil { + // New engine + // TODO - if we have failure to load scenarios, add logic to retry with the old runner + params = append(params, "--ollama-engine") + } + params = append(params, "--model", modelPath) + params = append(params, "--port", strconv.Itoa(port)) + + var pathEnv string + switch runtime.GOOS { + case "windows": + pathEnv = "PATH" + case "darwin": + pathEnv = "DYLD_LIBRARY_PATH" + default: + pathEnv = "LD_LIBRARY_PATH" + } + + s := llmServer{ + port: port, + cmd: exec.Command(exe, params...), + status: NewStatusWriter(os.Stderr), + options: opts, + modelPath: modelPath, + loadRequest: loadRequest, + llamaModel: llamaModel, + llamaModelLock: &sync.Mutex{}, + textProcessor: textProcessor, + numParallel: numParallel, + sem: semaphore.NewWeighted(int64(numParallel)), + totalLayers: f.KV().BlockCount() + 1, + loadStart: time.Now(), + done: make(chan error, 1), + } + + s.cmd.Env = os.Environ() + s.cmd.Stdout = os.Stdout + s.cmd.Stderr = s.status + s.cmd.SysProcAttr = LlamaServerSysProcAttr + + s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(gpuLibs, string(filepath.ListSeparator))) + + // Always filter down the set of GPUs in case there are any unsupported devices that might crash + envWorkarounds := gpus.GetVisibleDevicesEnv() + pathEnvVal := strings.Join(gpuLibs, string(filepath.ListSeparator)) + + // Update or add the path variable with our adjusted version + pathNeeded := true + envWorkaroundDone := make([]bool, len(envWorkarounds)) + for i := range s.cmd.Env { + cmp := strings.SplitN(s.cmd.Env[i], "=", 2) + if strings.EqualFold(cmp[0], pathEnv) { + s.cmd.Env[i] = pathEnv + "=" + pathEnvVal + pathNeeded = false + } else if len(envWorkarounds) != 0 { + for j, kv := range envWorkarounds { + tmp := strings.SplitN(kv, "=", 2) + if strings.EqualFold(cmp[0], tmp[0]) { + s.cmd.Env[i] = kv + envWorkaroundDone[j] = true + } + } + } + } + if pathNeeded { + s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal) + } + for i, done := range envWorkaroundDone { + if !done { + s.cmd.Env = append(s.cmd.Env, envWorkarounds[i]) + } + } + + slog.Info("starting runner", "cmd", s.cmd) + slog.Debug("subprocess", "", filteredEnv(s.cmd.Env)) + + if err = s.cmd.Start(); err != nil { + var msg string + if s.status != nil && s.status.LastErrMsg != "" { + msg = s.status.LastErrMsg + } + err := fmt.Errorf("error starting runner: %v %s", err, msg) + if llamaModel != nil { + llama.FreeModel(llamaModel) + } + return nil, err + } + + // reap subprocess when it exits + go func() { + err := s.cmd.Wait() + // Favor a more detailed message over the process exit status + if err != nil && s.status != nil && s.status.LastErrMsg != "" { + slog.Error("llama runner terminated", "error", err) + if strings.Contains(s.status.LastErrMsg, "unknown model") { + s.status.LastErrMsg = "this model is not supported by your version of Ollama. 
You may need to upgrade" + } + s.done <- errors.New(s.status.LastErrMsg) + } else { + s.done <- err + } + }() + + if textProcessor != nil { + return &ollamaServer{llmServer: s}, nil + } else { + return &llamaServer{llmServer: s, ggml: f}, nil + } + } func (s *llmServer) ModelPath() string { From 26893578908e9ce0275e15fdca8506af2186aa86 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 24 Sep 2025 12:22:46 -0700 Subject: [PATCH 119/172] fix index bug --- discover/runner.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/runner.go b/discover/runner.go index c9bbef7f5..610b97f22 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -460,7 +460,7 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s extra := strings.SplitN(extraEnvs[j], "=", 2) if cmp[0] == extra[0] { cmd.Env[i] = extraEnvs[j] - extraDone[i] = true + extraDone[j] = true } } } From 5c18fb456cac3a364e5b1a57d728d6d94e164199 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 24 Sep 2025 15:48:35 -0700 Subject: [PATCH 120/172] fix vulkan ids to be underlying --- ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3b0a0891e..f694f1cd8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10828,6 +10828,7 @@ struct ggml_backend_vk_device_context { std::string description; std::string id; std::string uuid; + std::string dev_idx; int major; int minor; int driver_major; @@ -10952,7 +10953,7 @@ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_d static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); - props->id = ggml_backend_vk_device_get_id(dev); + // props->id = ggml_backend_vk_device_get_id(dev); props->type = ggml_backend_vk_device_get_type(dev); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { @@ -10963,7 +10964,8 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml }; ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - props->id = ctx->id.c_str(); + // Use the unfiltered ID so round-trip through env var works + props->id = ctx->dev_idx.c_str(); props->compute_major = ctx->major; props->compute_minor = ctx->minor; props->driver_major = ctx->driver_major; @@ -11394,6 +11396,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, // Gather additional information about the device int dev_idx = vk_instance.device_indices[i]; + ctx->dev_idx = std::to_string(dev_idx); vk::PhysicalDeviceProperties props1; vk_devices[dev_idx].getProperties(&props1); vk::PhysicalDeviceProperties2 props2; From 5f9f312bdb8abb468901b6165c575240a3257301 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 24 Sep 2025 16:25:56 -0700 Subject: [PATCH 121/172] fix - give bootstrapping more time on slow systems --- discover/runner.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/discover/runner.go b/discover/runner.go index 610b97f22..4c78b0bfe 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -78,13 +78,6 @@ func GPUDevices(ctx 
context.Context, runners []FilteredRunnerDiscovery) []ml.Dev libDirs[""] = struct{}{} } - // Typically bootstrapping takes < 1s, but on some systems, with devices - // in low power/idle mode, initialization can take multiple seconds. We - // set a long timeout just for bootstrap discovery to reduce the chance - // of giving up too quickly - ctx1stPass, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - slog.Info("discovering available GPUs...") // For our initial discovery pass, we gather all the known GPUs through @@ -100,6 +93,13 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev } else { dirs = []string{LibOllamaPath, dir} } + // Typically bootstrapping takes < 1s, but on some systems, with devices + // in low power/idle mode, initialization can take multiple seconds. We + // set a long timeout just for bootstrap discovery to reduce the chance + // of giving up too quickly + ctx1stPass, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + // For this pass, we retain duplicates in case any are incompatible with some libraries devices = append(devices, bootstrapDevices(ctx1stPass, dirs, nil)...) } From 3a45922c018781d09c8836888eef308119abb697 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Thu, 25 Sep 2025 03:22:01 +0200 Subject: [PATCH 122/172] Test if Vulkan device is supported --- discover/gpu.go | 9 ++++ discover/gpu_info_vulkan.c | 33 +++++++++++++++ discover/gpu_info_vulkan.h | 87 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+) diff --git a/discover/gpu.go b/discover/gpu.go index f6152bf0d..2cb77e1e5 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -466,6 +466,15 @@ func GetGPUInfo() GpuInfoList { continue } + if C.vk_device_is_supported(*vHandles.vulkan, C.int(i)) == 0 { + unsupportedGPUs = append(unsupportedGPUs, + UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + }) + slog.Info(fmt.Sprintf("[%d] Vulkan GPU does not support required Vulkan features. 
(StorageBuffer16BitAccess)", i)) + continue + } + gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index 65033ad8a..7179ec9a3 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -48,6 +48,7 @@ void vk_init(char* vk_lib_path, vk_init_resp_t *resp) { {"vkEnumeratePhysicalDevices", (void *)&resp->ch.vkEnumeratePhysicalDevices}, {"vkGetPhysicalDeviceMemoryProperties2", (void *)&resp->ch.vkGetPhysicalDeviceMemoryProperties2}, {"vkDestroyInstance", (void *)&resp->ch.vkDestroyInstance}, + {"vkGetPhysicalDeviceFeatures2", (void *)&resp->ch.vkGetPhysicalDeviceFeatures2}, {NULL, NULL}, }; @@ -117,6 +118,38 @@ void vk_init(char* vk_lib_path, vk_init_resp_t *resp) { resp->num_devices = deviceCount; } +int vk_device_is_supported(vk_handle_t rh, int i) { + VkInstance instance = rh.vk; + uint32_t deviceCount = rh.num_devices; + + VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); + if (devices == NULL) { + return 0; + } + + VkResult result = (*rh.vkEnumeratePhysicalDevices)(instance, &deviceCount, devices); + if (result != VK_SUCCESS) { + free(devices); + return 0; + } + + VkPhysicalDeviceVulkan11Features vk11_features = {}; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + vk11_features.pNext = NULL; + + VkPhysicalDeviceFeatures2 device_features2 = {}; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = &vk11_features; + + // make sure you have the right function pointer from your loader + (*rh.vkGetPhysicalDeviceFeatures2)(devices[i], &device_features2); + + int supported = vk11_features.storageBuffer16BitAccess ? 
1 : 0; + + free(devices); + return supported; +} + int vk_check_flash_attention(vk_handle_t rh, int i) { VkInstance instance = rh.vk; uint32_t deviceCount = rh.num_devices; diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index 42e4b1610..26d00d601 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -95,6 +95,8 @@ typedef enum VkResult { typedef enum VkStructureType { VK_STRUCTURE_TYPE_APPLICATION_INFO = 0, VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO = 1, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES = 49, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2 = 1000059000, VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2 = 1000059001, VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2 = 1000059006, VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES = 1000071004, @@ -284,6 +286,87 @@ typedef struct VkPhysicalDeviceIDProperties { VkBool32 deviceLUIDValid; } VkPhysicalDeviceIDProperties; +typedef struct VkPhysicalDeviceFeatures { + VkBool32 robustBufferAccess; + VkBool32 fullDrawIndexUint32; + VkBool32 imageCubeArray; + VkBool32 independentBlend; + VkBool32 geometryShader; + VkBool32 tessellationShader; + VkBool32 sampleRateShading; + VkBool32 dualSrcBlend; + VkBool32 logicOp; + VkBool32 multiDrawIndirect; + VkBool32 drawIndirectFirstInstance; + VkBool32 depthClamp; + VkBool32 depthBiasClamp; + VkBool32 fillModeNonSolid; + VkBool32 depthBounds; + VkBool32 wideLines; + VkBool32 largePoints; + VkBool32 alphaToOne; + VkBool32 multiViewport; + VkBool32 samplerAnisotropy; + VkBool32 textureCompressionETC2; + VkBool32 textureCompressionASTC_LDR; + VkBool32 textureCompressionBC; + VkBool32 occlusionQueryPrecise; + VkBool32 pipelineStatisticsQuery; + VkBool32 vertexPipelineStoresAndAtomics; + VkBool32 fragmentStoresAndAtomics; + VkBool32 shaderTessellationAndGeometryPointSize; + VkBool32 shaderImageGatherExtended; + VkBool32 shaderStorageImageExtendedFormats; + VkBool32 shaderStorageImageMultisample; + VkBool32 shaderStorageImageReadWithoutFormat; + VkBool32 shaderStorageImageWriteWithoutFormat; + VkBool32 shaderUniformBufferArrayDynamicIndexing; + VkBool32 shaderSampledImageArrayDynamicIndexing; + VkBool32 shaderStorageBufferArrayDynamicIndexing; + VkBool32 shaderStorageImageArrayDynamicIndexing; + VkBool32 shaderClipDistance; + VkBool32 shaderCullDistance; + VkBool32 shaderFloat64; + VkBool32 shaderInt64; + VkBool32 shaderInt16; + VkBool32 shaderResourceResidency; + VkBool32 shaderResourceMinLod; + VkBool32 sparseBinding; + VkBool32 sparseResidencyBuffer; + VkBool32 sparseResidencyImage2D; + VkBool32 sparseResidencyImage3D; + VkBool32 sparseResidency2Samples; + VkBool32 sparseResidency4Samples; + VkBool32 sparseResidency8Samples; + VkBool32 sparseResidency16Samples; + VkBool32 sparseResidencyAliased; + VkBool32 variableMultisampleRate; + VkBool32 inheritedQueries; +} VkPhysicalDeviceFeatures; + +typedef struct VkPhysicalDeviceFeatures2 { + VkStructureType sType; + void* pNext; + VkPhysicalDeviceFeatures features; +} VkPhysicalDeviceFeatures2; + +typedef struct VkPhysicalDeviceVulkan11Features { + VkStructureType sType; + void* pNext; + VkBool32 storageBuffer16BitAccess; + VkBool32 uniformAndStorageBuffer16BitAccess; + VkBool32 storagePushConstant16; + VkBool32 storageInputOutput16; + VkBool32 multiview; + VkBool32 multiviewGeometryShader; + VkBool32 multiviewTessellationShader; + VkBool32 variablePointersStorageBuffer; + VkBool32 variablePointers; + VkBool32 protectedMemory; + VkBool32 samplerYcbcrConversion; + VkBool32 shaderDrawParameters; +} 
VkPhysicalDeviceVulkan11Features; + typedef struct VkMemoryType { VkMemoryPropertyFlags propertyFlags; uint32_t heapIndex; @@ -376,6 +459,9 @@ typedef struct { void (*vkDestroyInstance)( VkInstance instance, const VkAllocationCallbacks* pAllocator); + void (*vkGetPhysicalDeviceFeatures2)( + VkPhysicalDevice physicalDevice, + VkPhysicalDeviceFeatures2* pFeatures); } vk_handle_t; typedef struct vk_init_resp @@ -388,6 +474,7 @@ typedef struct vk_init_resp void vk_init(char* vk_lib_path, vk_init_resp_t *resp); void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp); int vk_check_flash_attention(vk_handle_t rh, int i); +int vk_device_is_supported(vk_handle_t rh, int i); void vk_release(vk_handle_t rh); #endif From a7e2d21f598ce218f0a8073bfe17be5df1f46fd1 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Thu, 25 Sep 2025 06:33:15 +0200 Subject: [PATCH 123/172] vk_check_flash_attention is not needed (coompat2 coopmapt and scalar implementation exist) --- discover/gpu_info_vulkan.c | 28 ---------------------------- discover/gpu_info_vulkan.h | 1 - 2 files changed, 29 deletions(-) diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index 7179ec9a3..0929fdee5 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -150,34 +150,6 @@ int vk_device_is_supported(vk_handle_t rh, int i) { return supported; } -int vk_check_flash_attention(vk_handle_t rh, int i) { - VkInstance instance = rh.vk; - uint32_t deviceCount = rh.num_devices; - - VkPhysicalDevice* devices = malloc(deviceCount * sizeof(VkPhysicalDevice)); - if (devices == NULL) { - return 0; - } - - VkResult result = (*rh.vkEnumeratePhysicalDevices)(instance, &deviceCount, devices); - if (result != VK_SUCCESS) { - free(devices); - return 0; - } - - VkPhysicalDeviceProperties properties = {}; - (*rh.vkGetPhysicalDeviceProperties)(devices[i], &properties); - - int supports_nv_coopmat2 = is_extension_supported(&rh, devices[i], VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME); - if (!supports_nv_coopmat2) { - free(devices); - return 1; - } - - free(devices); - return 0; -} - void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { VkInstance instance = rh.vk; uint32_t deviceCount = rh.num_devices; diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index 26d00d601..3cd8b0b39 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -473,7 +473,6 @@ typedef struct vk_init_resp void vk_init(char* vk_lib_path, vk_init_resp_t *resp); void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp); -int vk_check_flash_attention(vk_handle_t rh, int i); int vk_device_is_supported(vk_handle_t rh, int i); void vk_release(vk_handle_t rh); From 05bdfedb56c61f331d263898cdf0d31c20c272e7 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Thu, 25 Sep 2025 08:23:13 +0200 Subject: [PATCH 124/172] Handle GGML_VK_VISIBLE_DEVICES --- discover/gpu_info_vulkan.c | 112 +++++++++++++++++++++++++++++++++++++ discover/gpu_info_vulkan.h | 3 + 2 files changed, 115 insertions(+) diff --git a/discover/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c index 0929fdee5..8bb0a9d3a 100644 --- a/discover/gpu_info_vulkan.c +++ b/discover/gpu_info_vulkan.c @@ -2,6 +2,85 @@ #include "gpu_info_vulkan.h" #include +#include +#include + +#define INITIAL_ARRAY_SIZE 10 + +// Function to parse an environment variable into a list of int values. +// Returns a pointer to the allocated array, and stores the count in out_count. +// Returns NULL in case of any error. 
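+// For example, GGML_VK_VISIBLE_DEVICES="0,2" yields the list {0, 2} with *out_count == 2.
+// The returned array is heap-allocated; it is stored in vk_handle_t.visible_devices and
+// released in vk_release().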
+int* parse_envvar_to_int_list(const char* envvar_name, size_t *out_count) { + char *env_str = getenv(envvar_name); + if (env_str == NULL) { + *out_count = 0; + return NULL; + } + + // Duplicate the string since strtok modifies it. + char *tmp = strdup(env_str); + if (!tmp) { + *out_count = 0; + return NULL; + } + + size_t capacity = INITIAL_ARRAY_SIZE; + size_t count = 0; + int *list = malloc(capacity * sizeof(uint32_t)); + if (!list) { + free(tmp); + *out_count = 0; + return NULL; + } + + char *token = strtok(tmp, ","); + while (token != NULL) { + char *endptr = NULL; + errno = 0; + unsigned long val = strtoul(token, &endptr, 10); + if (errno != 0 || endptr == token) { + free(list); + free(tmp); + *out_count = 0; + return NULL; + } + // Optional: Check trailing characters. + while (*endptr != '\0') { + if (!isspace((unsigned char)*endptr)) { + free(list); + free(tmp); + *out_count = 0; + return NULL; + } + endptr++; + } + if (val > UINT32_MAX) { + free(list); + free(tmp); + *out_count = 0; + return NULL; + } + + // Save the value, reallocating if necessary. + if (count == capacity) { + capacity *= 2; + int *temp = realloc(list, capacity * sizeof(uint32_t)); + if (!temp) { + free(list); + free(tmp); + *out_count = 0; + return NULL; + } + list = temp; + } + list[count++] = (int)val; + token = strtok(NULL, ","); + } + + free(tmp); + *out_count = count; + return list; +} int is_extension_supported(vk_handle_t* rh, VkPhysicalDevice device, char* extension) { VkPhysicalDeviceProperties properties = {}; @@ -112,10 +191,21 @@ void vk_init(char* vk_lib_path, vk_init_resp_t *resp) { return; } + size_t visDevIdCount; + int* visDevIds = parse_envvar_to_int_list("GGML_VK_VISIBLE_DEVICES", &visDevIdCount); + resp->err = NULL; resp->ch.vk = instance; resp->ch.num_devices = deviceCount; resp->num_devices = deviceCount; + + if (visDevIds && visDevIdCount > 0) { + resp->ch.num_visible_devices = visDevIdCount; + resp->ch.visible_devices = visDevIds; + } else { + resp->ch.num_visible_devices = -1; + resp->ch.visible_devices = NULL; + } } int vk_device_is_supported(vk_handle_t rh, int i) { @@ -192,6 +282,24 @@ void vk_check_vram(vk_handle_t rh, int i, mem_info_t *resp) { device_props2.pNext = &id_props; (*rh.vkGetPhysicalDeviceProperties2)(devices[i], &device_props2); + if (rh.num_visible_devices > 0) { + LOG(rh.verbose, "Checking if device %d is visible\n", i); + int is_visible = 0; + for (uint32_t visDevId = 0; visDevId < rh.num_visible_devices; visDevId++) { + if (i == rh.visible_devices[visDevId]) { + LOG(rh.verbose, "Device %d is visible!\n", i); + is_visible = 1; + break; + } + } + if (!is_visible) { + LOG(rh.verbose, "Device %d is NOT visible!\n", i); + free(devices); + resp->err = strdup("device is hidden with GGML_VK_VISIBLE_DEVICES"); + return; + } + } + VkPhysicalDeviceMemoryBudgetPropertiesEXT physical_device_memory_budget_properties = {}; physical_device_memory_budget_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT; physical_device_memory_budget_properties.pNext = NULL; @@ -241,6 +349,10 @@ void vk_release(vk_handle_t rh) { (*rh.vkDestroyInstance)(rh.vk, NULL); UNLOAD_LIBRARY(rh.vk_handle); rh.vk_handle = NULL; + + if (rh.visible_devices) { + free(rh.visible_devices); + } } #endif // __APPLE__ diff --git a/discover/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h index 3cd8b0b39..c249d9855 100644 --- a/discover/gpu_info_vulkan.h +++ b/discover/gpu_info_vulkan.h @@ -434,6 +434,9 @@ typedef struct { VkInstance vk; int num_devices; + int num_visible_devices; + int* 
visible_devices; + void (*vkGetPhysicalDeviceProperties)( VkPhysicalDevice physicalDevice, VkPhysicalDeviceProperties* pProperties); From 82f0c7e6a518a975dc103d75932f9cf1d1285349 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Thu, 25 Sep 2025 08:47:04 +0200 Subject: [PATCH 125/172] ask for supported first --- discover/gpu.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/discover/gpu.go b/discover/gpu.go index 2cb77e1e5..aba6dae0a 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -459,13 +459,6 @@ func GetGPUInfo() GpuInfoList { index: i, } - C.vk_check_vram(*vHandles.vulkan, C.int(i), &memInfo) - if memInfo.err != nil { - slog.Info("error looking up vulkan GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - continue - } - if C.vk_device_is_supported(*vHandles.vulkan, C.int(i)) == 0 { unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ @@ -475,6 +468,13 @@ func GetGPUInfo() GpuInfoList { continue } + C.vk_check_vram(*vHandles.vulkan, C.int(i), &memInfo) + if memInfo.err != nil { + slog.Info("error looking up vulkan GPU memory", "error", C.GoString(memInfo.err)) + C.free(unsafe.Pointer(memInfo.err)) + continue + } + gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) From 936c6d6be189fcfd6e4e55eb0b52ae68d571f128 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Thu, 25 Sep 2025 08:13:24 -0700 Subject: [PATCH 126/172] win: fix CPU query buffer handling Try in a short loop until we get the size right. --- discover/cpu_windows.go | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/discover/cpu_windows.go b/discover/cpu_windows.go index ee308805e..5f516b5d1 100644 --- a/discover/cpu_windows.go +++ b/discover/cpu_windows.go @@ -99,27 +99,22 @@ func (pkg *winPackage) IsMember(target *GROUP_AFFINITY) bool { } func getLogicalProcessorInformationEx() ([]byte, error) { - buf := make([]byte, 1) + buf := make([]byte, 1024) bufSize := len(buf) - ret, _, err := GetLogicalProcessorInformationEx.Call( - uintptr(RelationAll), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&bufSize)), - ) - if ret != 0 { - return nil, fmt.Errorf("failed to determine size info ret:%d %w", ret, err) + var err error + for range 3 { + var ret uintptr + ret, _, err = GetLogicalProcessorInformationEx.Call( + uintptr(RelationAll), + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&bufSize)), + ) + if ret == 1 && bufSize <= len(buf) { + return buf, nil + } + buf = make([]byte, bufSize) } - - buf = make([]byte, bufSize) - ret, _, err = GetLogicalProcessorInformationEx.Call( - uintptr(RelationAll), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&bufSize)), - ) - if ret == 0 { - return nil, fmt.Errorf("failed to gather processor information ret:%d buflen:%d %w", ret, bufSize, err) - } - return buf, nil + return nil, fmt.Errorf("unable to determine CPU details: %w", err) } func processSystemLogicalProcessorInforationList(buf []byte) []*winPackage { From 5647ac91b297f77938cb25f02bdd308f2f9278a4 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Thu, 25 Sep 2025 09:53:22 -0700 Subject: [PATCH 127/172] test: harden integration tests for slow start If the server takes a while to start up, block tests from starting until it's online to avoid setting large timeouts in individual test cases. 
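
A minimal sketch of the readiness gate (the actual change below does this inline
in InitServerConnection; the helper name here is illustrative, and it assumes the
usual test imports "context", "testing", "time" plus the project's api package):

    // waitForServer blocks until the server answers a cheap API call or the
    // deadline expires, so individual tests can keep their own timeouts short.
    func waitForServer(ctx context.Context, t *testing.T, client *api.Client) {
        ctx, cancel := context.WithTimeout(ctx, 120*time.Second)
        defer cancel()
        if _, err := client.ListRunning(ctx); err != nil {
            t.Fatalf("server not ready: %v", err)
        }
    }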
--- integration/utils_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/integration/utils_test.go b/integration/utils_test.go index 7901fed3f..bfbae19e7 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -440,6 +440,24 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin t.Fatal(err) } } + // Make sure server is online and healthy before returning + listCtx, cancel := context.WithDeadlineCause( + ctx, + time.Now().Add(120*time.Second), + fmt.Errorf("list models took too long"), + ) + defer cancel() + models, err := client.ListRunning(listCtx) + if err != nil { + t.Fatal(err) + } + if len(models.Models) > 0 { + names := make([]string, len(models.Models)) + for i, m := range models.Models { + names[i] = m.Name + } + slog.Info("currently loaded", "models", names) + } return client, testEndpoint, func() { if os.Getenv("OLLAMA_TEST_EXISTING") == "" { From a7ddd0e2aebb7744eb2ea64d9cc0aef4949acae2 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Fri, 26 Sep 2025 22:15:58 +0200 Subject: [PATCH 128/172] gofumpt fix --- discover/gpu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index aba6dae0a..9cdd48de0 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -486,7 +486,7 @@ func GetGPUInfo() GpuInfoList { gpuInfo.DriverMinor = int(memInfo.minor) // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... - var backend = gpuInfoExistsInOtherBackends(gpuInfo) + backend := gpuInfoExistsInOtherBackends(gpuInfo) if backend != "" { unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ From f567cc59d442f88397536818b234195dc7180d7b Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 15:08:18 +0200 Subject: [PATCH 129/172] fix build --- ml/backend.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ml/backend.go b/ml/backend.go index dd51c5d59..a3dbe04e1 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -7,7 +7,6 @@ import ( "fmt" "math" "slices" - "sort" "strconv" "strings" @@ -226,9 +225,9 @@ type ScaledDotProductAttention interface { type number interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | - ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | - ~float32 | ~float64 | - ~complex64 | ~complex128 + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | + ~float32 | ~float64 | + ~complex64 | ~complex128 } func mul[T number](s ...T) T { From 294b1796885c07b579a3d6fcf5e1b6060d2fdea4 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 15:20:33 +0200 Subject: [PATCH 130/172] merge fixes --- discover/runner.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/discover/runner.go b/discover/runner.go index 15fac2f17..d173959ad 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -123,13 +123,15 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev go func(i int) { defer wg.Done() var envVar string + if devices[i].Library == "ROCm" { if runtime.GOOS != "linux" { envVar = "HIP_VISIBLE_DEVICES" } else { envVar = "ROCR_VISIBLE_DEVICES" } + } else if devices[i].Library == "CUDA" { envVar = "CUDA_VISIBLE_DEVICES" - } else if devices[i].Library == "VULKAN" { + } else if devices[i].Library == "Vulkan" { envVar = "GGML_VK_VISIBLE_DEVICES" } From c4d8c75e542324dab82b939a4e75b6e9d3de9820 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 15:27:52 +0200 Subject: [PATCH 131/172] merge fixes --- discover/runner.go | 1 - 
discover/types.go | 2 +- ml/backend/ggml/ggml.go | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/discover/runner.go b/discover/runner.go index d173959ad..e9d34ba8d 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -3,7 +3,6 @@ package discover // Runner based GPU discovery import ( - "bytes" "context" "encoding/json" "fmt" diff --git a/discover/types.go b/discover/types.go index bbd22a13d..a294f26b7 100644 --- a/discover/types.go +++ b/discover/types.go @@ -174,7 +174,7 @@ func (l GpuInfoList) FlashAttentionSupported() bool { supportsFA := gpu.Library == "cpu" || gpu.Name == "Metal" || gpu.Library == "Metal" || (gpu.Library == "CUDA" && gpu.DriverMajor >= 7) || - gpu.Library == "ROCm" + gpu.Library == "ROCm" || gpu.Library == "Vulkan" if !supportsFA { diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index ccdee03c0..dc71c8de4 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -18,7 +18,6 @@ import ( "log/slog" "maps" "os" - "path/filepath" "runtime" "slices" "strconv" From 1e46db874855b0c8928d96f288578cebfb612a0c Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 15:44:23 +0200 Subject: [PATCH 132/172] fixed build --- discover/runner.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/discover/runner.go b/discover/runner.go index e9d34ba8d..7a659367c 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -441,8 +441,6 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s if envconfig.LogLevel() == logutil.LevelTrace { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - } else { - cmd.Stderr = errBuf } // cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator))) From 75f65bcdbfd22226b5ed6874d75722984a6b78af Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 16:11:34 +0200 Subject: [PATCH 133/172] merge fixes --- llm/server.go | 240 ++++++++++++++++++++++++++++---------------------- 1 file changed, 133 insertions(+), 107 deletions(-) diff --git a/llm/server.go b/llm/server.go index a7f156490..3fb1bd905 100644 --- a/llm/server.go +++ b/llm/server.go @@ -235,29 +235,35 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a } var gpuLibs []string - newLibs := func(libDirs []string) (ret []string) { - for _, d := range libDirs { - found := false - for _, t := range gpuLibs { - if d == t { - found = true - break - } - } - if !found { - ret = append(ret, d) - } - } - return - } for _, gpu := range gpus { - gpuLibs = append(gpuLibs, newLibs(gpu.DependencyPath)...) + gpuLibs = append(gpuLibs, gpu.RunnerName()) } requested := envconfig.LLMLibrary() if availableLibs[requested] != "" { slog.Info("using requested gpu library", "requested", requested) - gpuLibs = []string{availableLibs[requested]} + gpuLibs = []string{requested} + } + + var compatible []string + for _, gpuLib := range gpuLibs { + var matchingLibs []string + for k := range availableLibs { + // exact match first + if k == gpuLib { + matchingLibs = append([]string{k}, matchingLibs...) + continue + } + + // then match the family (e.g. 
'cuda') + if strings.Split(k, "_")[0] == strings.Split(gpuLib, "_")[0] { + matchingLibs = append(matchingLibs, k) + } + } + + if len(matchingLibs) > 0 { + compatible = append(compatible, matchingLibs[0]) + } } exe, err := os.Executable() @@ -269,36 +275,40 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a exe = eval } - port := 0 - if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { - var l *net.TCPListener - if l, err = net.ListenTCP("tcp", a); err == nil { - port = l.Addr().(*net.TCPAddr).Port - l.Close() + // iterate through compatible GPU libraries such as 'cuda_v12', 'rocm', etc. + // adding each library's respective path to the LD_LIBRARY_PATH, until finally running + // without any LD_LIBRARY_PATH flags + for { + port := 0 + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + port = l.Addr().(*net.TCPAddr).Port + l.Close() + } } - } - if port == 0 { - slog.Debug("ResolveTCPAddr failed, using random port") - port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range - } - params := []string{"runner"} - if textProcessor != nil { - // New engine - // TODO - if we have failure to load scenarios, add logic to retry with the old runner - params = append(params, "--ollama-engine") - } - params = append(params, "--model", modelPath) - params = append(params, "--port", strconv.Itoa(port)) + if port == 0 { + slog.Debug("ResolveTCPAddr failed, using random port") + port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range + } + params := []string{"runner"} + if textProcessor != nil { + // New engine + // TODO - if we have failure to load scenarios, add logic to retry with the old runner + params = append(params, "--ollama-engine") + } + params = append(params, "--model", modelPath) + params = append(params, "--port", strconv.Itoa(port)) - var pathEnv string - switch runtime.GOOS { - case "windows": - pathEnv = "PATH" - case "darwin": - pathEnv = "DYLD_LIBRARY_PATH" - default: - pathEnv = "LD_LIBRARY_PATH" - } + var pathEnv string + switch runtime.GOOS { + case "windows": + pathEnv = "PATH" + case "darwin": + pathEnv = "DYLD_LIBRARY_PATH" + default: + pathEnv = "LD_LIBRARY_PATH" + } // Note: we always put our dependency paths first // since these are the exact version we compiled/linked against @@ -344,80 +354,86 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a done: make(chan error, 1), } - s.cmd.Env = os.Environ() - s.cmd.Stdout = os.Stdout - s.cmd.Stderr = s.status - s.cmd.SysProcAttr = LlamaServerSysProcAttr + s.cmd.Env = os.Environ() + s.cmd.Stdout = os.Stdout + s.cmd.Stderr = s.status + s.cmd.SysProcAttr = LlamaServerSysProcAttr - s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(gpuLibs, string(filepath.ListSeparator))) + s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator))) // Always filter down the set of GPUs in case there are any unsupported devices that might crash envWorkarounds := gpus.GetVisibleDevicesEnv() pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) - // Update or add the path variable with our adjusted version - pathNeeded := true - envWorkaroundDone := make([]bool, len(envWorkarounds)) - for i := range s.cmd.Env { - cmp := strings.SplitN(s.cmd.Env[i], "=", 2) - if strings.EqualFold(cmp[0], pathEnv) { - s.cmd.Env[i] = pathEnv + "=" + pathEnvVal - 
pathNeeded = false - } else if len(envWorkarounds) != 0 { - for j, kv := range envWorkarounds { - tmp := strings.SplitN(kv, "=", 2) - if strings.EqualFold(cmp[0], tmp[0]) { - s.cmd.Env[i] = kv - envWorkaroundDone[j] = true + // Update or add the path variable with our adjusted version + pathNeeded := true + envWorkaroundDone := make([]bool, len(envWorkarounds)) + for i := range s.cmd.Env { + cmp := strings.SplitN(s.cmd.Env[i], "=", 2) + if strings.EqualFold(cmp[0], pathEnv) { + s.cmd.Env[i] = pathEnv + "=" + pathEnvVal + pathNeeded = false + } else if len(envWorkarounds) != 0 { + for j, kv := range envWorkarounds { + tmp := strings.SplitN(kv, "=", 2) + if strings.EqualFold(cmp[0], tmp[0]) { + s.cmd.Env[i] = kv + envWorkaroundDone[j] = true + } } } } - } - if pathNeeded { - s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal) - } - for i, done := range envWorkaroundDone { - if !done { - s.cmd.Env = append(s.cmd.Env, envWorkarounds[i]) + if pathNeeded { + s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal) } - } - - slog.Info("starting runner", "cmd", s.cmd) - slog.Debug("subprocess", "", filteredEnv(s.cmd.Env)) - - if err = s.cmd.Start(); err != nil { - var msg string - if s.status != nil && s.status.LastErrMsg != "" { - msg = s.status.LastErrMsg - } - err := fmt.Errorf("error starting runner: %v %s", err, msg) - if llamaModel != nil { - llama.FreeModel(llamaModel) - } - return nil, err - } - - // reap subprocess when it exits - go func() { - err := s.cmd.Wait() - // Favor a more detailed message over the process exit status - if err != nil && s.status != nil && s.status.LastErrMsg != "" { - slog.Error("llama runner terminated", "error", err) - if strings.Contains(s.status.LastErrMsg, "unknown model") { - s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade" + for i, done := range envWorkaroundDone { + if !done { + s.cmd.Env = append(s.cmd.Env, envWorkarounds[i]) } - s.done <- errors.New(s.status.LastErrMsg) - } else { - s.done <- err } - }() - if textProcessor != nil { - return &ollamaServer{llmServer: s}, nil - } else { - return &llamaServer{llmServer: s, ggml: f}, nil + slog.Info("starting runner", "cmd", s.cmd) + slog.Debug("subprocess", "", filteredEnv(s.cmd.Env)) + + if err = s.cmd.Start(); err != nil { + var msg string + if s.status != nil && s.status.LastErrMsg != "" { + msg = s.status.LastErrMsg + } + err := fmt.Errorf("error starting runner: %v %s", err, msg) + if len(compatible) == 0 { + if llamaModel != nil { + llama.FreeModel(llamaModel) + } + return nil, err + } + + slog.Warn("unable to start runner with compatible gpu", "error", err, "compatible", compatible) + compatible = compatible[1:] + continue + } + + // reap subprocess when it exits + go func() { + err := s.cmd.Wait() + // Favor a more detailed message over the process exit status + if err != nil && s.status != nil && s.status.LastErrMsg != "" { + slog.Error("llama runner terminated", "error", err) + if strings.Contains(s.status.LastErrMsg, "unknown model") { + s.status.LastErrMsg = "this model is not supported by your version of Ollama. 
You may need to upgrade" + } + s.done <- errors.New(s.status.LastErrMsg) + } else { + s.done <- err + } + }() + + if textProcessor != nil { + return &ollamaServer{llmServer: s}, nil + } else { + return &llamaServer{llmServer: s, ggml: f}, nil + } } - } func (s *llmServer) ModelPath() string { @@ -545,9 +561,8 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi // For CPU loads we want the memory to be allocated, not FS cache if (runtime.GOOS == "windows" && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) || (runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) || - (gpus[0].Library == "vulkan" && s.options.UseMMap == nil) || (gpus[0].Library == "cpu" && s.options.UseMMap == nil) || - (gpus[0].Library == "VULKAN" && s.options.UseMMap == nil) || + (gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) || (s.options.UseMMap != nil && !*s.options.UseMMap) { s.loadRequest.UseMmap = false } @@ -1313,6 +1328,17 @@ func (s *llmServer) Pid() int { return -1 } +func (s *llmServer) GetPort() int { + return s.port +} + +func (s *llmServer) HasExited() bool { + if s.cmd != nil && s.cmd.ProcessState != nil && s.cmd.ProcessState.ExitCode() >= 0 { + return true + } + return false +} + var grammarJSON = ` root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws From 06528d66aa98e4434fc7237aa42c18b56fedecd8 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 16:22:55 +0200 Subject: [PATCH 134/172] fixing build --- server/sched.go | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/server/sched.go b/server/sched.go index 2de622ea6..5f8c61e09 100644 --- a/server/sched.go +++ b/server/sched.go @@ -623,7 +623,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool // a before and after GPU memory allocation. The returned channel // will be notified when we're done waiting, or have timed out and should // proceed anyway -func (runner *runnerRef) waitForVRAMRecovery() chan any { +func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.FilteredRunnerDiscovery) chan any { finished := make(chan any, 1) // CPU or Metal don't need checking, so no waiting required @@ -638,7 +638,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any { start := time.Now() // Establish a baseline before we unload - gpusBefore := discover.GetGPUInfo() + gpusBefore := s.getGpuFn(context.Background(), runners) var totalMemoryBefore, freeMemoryBefore uint64 for _, gpu := range gpusBefore { totalMemoryBefore += gpu.TotalMemory @@ -656,7 +656,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any { } // Query GPUs, look for free to go back up - gpusNow := discover.GetGPUInfo() + gpusNow := s.getGpuFn(context.Background(), runners) var totalMemoryNow, freeMemoryNow uint64 for _, gpu := range gpusNow { totalMemoryNow += gpu.TotalMemory @@ -699,6 +699,32 @@ func (runner *runnerRef) LogValue() slog.Value { return slog.GroupValue(attrs...) 
} +// Implements discover.RunnerDiscovery +func (runner *runnerRef) GetPort() int { + if runner.llama != nil { + return runner.llama.GetPort() + } + return -1 +} + +func (runner *runnerRef) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { + if runner.llama != nil { + return runner.llama.GetDeviceInfos(ctx) + } + return nil +} + +func (runner *runnerRef) GetActiveDeviceIDs() []ml.DeviceID { + return runner.gpus +} + +func (runner *runnerRef) HasExited() bool { + if runner.llama != nil { + return runner.llama.HasExited() + } + return true +} + type ByDurationAndName []*runnerRef func (a ByDurationAndName) Len() int { return len(a) } From b2aba4ea83e2b0cc5c7a0ecb5087f86e3ed0e521 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 16:26:03 +0200 Subject: [PATCH 135/172] fixed build --- server/sched_test.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/server/sched_test.go b/server/sched_test.go index 1341a7083..fd6309e33 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -497,8 +497,8 @@ func TestPrematureExpired(t *testing.T) { // Same model, same request scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil, nil) s := InitScheduler(ctx) - s.getGpuFn = func() discover.GpuInfoList { - g := discover.GpuInfo{Library: "metal"} + s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList { + g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}} g.TotalMemory = 24 * format.GigaByte g.FreeMemory = 12 * format.GigaByte return []discover.GpuInfo{g} @@ -783,7 +783,11 @@ func (s *mockLlm) Close() error { s.closeCalled = true return s.closeResp } -func (s *mockLlm) VRAMSize() uint64 { return s.vramSize } -func (s *mockLlm) TotalSize() uint64 { return s.totalSize } -func (s *mockLlm) VRAMByGPU(gpuid string) uint64 { return s.vramByGPU[gpuid] } -func (s *mockLlm) Pid() int { return -1 } +func (s *mockLlm) VRAMSize() uint64 { return s.vramSize } +func (s *mockLlm) TotalSize() uint64 { return s.totalSize } +func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] } +func (s *mockLlm) Pid() int { return -1 } +func (s *mockLlm) GetPort() int { return -1 } +func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil } +func (s *mockLlm) HasExited() bool { return false } +func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil } From 9ac9f3a9524eb1dfad7438a76b223ca57a742952 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 16:32:39 +0200 Subject: [PATCH 136/172] fixed formatting --- discover/gpu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index 028014ac2..5b8d3fbe7 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -72,7 +72,7 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList { } else { info.Compute = fmt.Sprintf("%d.%d", dev.ComputeMajor, dev.ComputeMinor) } - // TODO any special processing of Vulkan devices? + // TODO any special processing of Vulkan devices? 
resp = append(resp, info) } if len(resp) == 0 { From 96e562f982f68bb175e853e02834cb13791b8b44 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 16:35:04 +0200 Subject: [PATCH 137/172] fixed build --- discover/gpu.go | 1 - ml/backend.go | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/discover/gpu.go b/discover/gpu.go index 5b8d3fbe7..9102bd65b 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -146,7 +146,6 @@ func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { continue } ids = append(ids, info.ID) - } if len(ids) == 0 { return "" diff --git a/ml/backend.go b/ml/backend.go index a3dbe04e1..351942d54 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -225,9 +225,9 @@ type ScaledDotProductAttention interface { type number interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | - ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | - ~float32 | ~float64 | - ~complex64 | ~complex128 + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | + ~float32 | ~float64 | + ~complex64 | ~complex128 } func mul[T number](s ...T) T { From 163f62fcb69c48b97205f607eedc5e6c92f47a71 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 18:56:38 +0200 Subject: [PATCH 138/172] fix vulkan gpu id patch --- ...026-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch index 928e85d5e..be83b6371 100644 --- a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch +++ b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -9,10 +9,10 @@ Signed-off-by: Xiaodong Ye 1 file changed, 37 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 4070e248..671323ad 100644 +index 061cd078..adea7783 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -10194,6 +10194,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ +@@ -11588,6 +11588,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ snprintf(description, description_size, "%s", props.deviceName.data()); } @@ -42,7 +42,7 @@ index 4070e248..671323ad 100644 // backend interface #define UNUSED GGML_UNUSED -@@ -10790,6 +10813,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size +@@ -12394,6 +12417,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size ggml_vk_get_device_description(dev_idx, description, description_size); } @@ -54,16 +54,16 @@ index 4070e248..671323ad 100644 + void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { GGML_ASSERT(device < (int) vk_instance.device_indices.size()); - -@@ -10812,6 +10841,7 @@ struct ggml_backend_vk_device_context { - size_t device; - std::string name; + GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); +@@ -12481,6 +12510,7 @@ struct ggml_backend_vk_device_context { std::string description; + bool is_integrated_gpu; + std::string pci_bus_id; + std::string id; }; static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { -@@ -10824,6 +10854,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de +@@ -12493,6 +12523,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de return ctx->description.c_str(); } @@ -75,18 +75,18 @@ index 
4070e248..671323ad 100644 static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; ggml_backend_vk_get_device_memory(ctx->device, free, total); -@@ -10847,6 +10882,7 @@ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_d - static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { +@@ -12519,6 +12554,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml + props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); + props->id = ggml_backend_vk_device_get_id(dev); props->type = ggml_backend_vk_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); - props->caps = { -@@ -11265,6 +11301,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, - ctx->device = i; - ctx->name = GGML_VK_NAME + std::to_string(i); +@@ -12965,6 +13001,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->description = desc; + ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; + ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); + ctx->id = ggml_backend_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, From 93d7126ce5b11c8bd26797bfc5938287a17254e7 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 19:02:57 +0200 Subject: [PATCH 139/172] sync llama.cpp vulkan code --- .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 3325 ++++++++++++----- .../src/ggml-vulkan/vulkan-shaders/add.comp | 42 +- .../ggml-vulkan/vulkan-shaders/argmax.comp | 17 +- .../ggml-vulkan/vulkan-shaders/argsort.comp | 68 +- .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 24 +- .../vulkan-shaders/copy_to_quant.comp | 11 +- .../vulkan-shaders/dequant_funcs.comp | 136 + .../vulkan-shaders/dequant_iq2_s.comp | 2 +- .../vulkan-shaders/dequant_iq2_xxs.comp | 3 +- .../vulkan-shaders/dequant_iq3_s.comp | 11 +- .../vulkan-shaders/dequant_iq3_xxs.comp | 6 +- .../src/ggml-vulkan/vulkan-shaders/exp.comp | 21 + .../vulkan-shaders/flash_attn.comp | 23 +- .../vulkan-shaders/flash_attn_base.comp | 44 +- .../vulkan-shaders/flash_attn_cm1.comp | 59 +- .../vulkan-shaders/flash_attn_cm2.comp | 40 +- .../flash_attn_split_k_reduce.comp | 4 + .../vulkan-shaders/generic_binary_head.comp | 19 +- .../ggml-vulkan/vulkan-shaders/get_rows.comp | 29 +- .../vulkan-shaders/get_rows_quant.comp | 40 +- .../vulkan-shaders/hardsigmoid.comp | 22 + .../ggml-vulkan/vulkan-shaders/hardswish.comp | 22 + .../ggml-vulkan/vulkan-shaders/im2col.comp | 21 +- .../ggml-vulkan/vulkan-shaders/im2col_3d.comp | 126 + .../vulkan-shaders/mul_mat_vec_base.comp | 66 +- .../vulkan-shaders/mul_mat_vecq.comp | 140 + .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 738 +--- .../vulkan-shaders/mul_mm_cm2.comp | 233 +- .../vulkan-shaders/mul_mm_funcs.comp | 556 +++ .../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 17 +- .../vulkan-shaders/mul_mmq_funcs.comp | 18 +- .../ggml-vulkan/vulkan-shaders/multi_add.comp | 111 + .../vulkan-shaders/opt_step_sgd.comp | 22 + .../src/ggml-vulkan/vulkan-shaders/pad.comp | 27 +- .../vulkan-shaders/quantize_q8_1.comp | 56 +- 
.../ggml-vulkan/vulkan-shaders/rms_norm.comp | 60 +- .../vulkan-shaders/rms_norm_partials.comp | 65 + .../vulkan-shaders/soft_max_back.comp | 4 + .../src/ggml-vulkan/vulkan-shaders/sqrt.comp | 17 + .../ggml-vulkan/vulkan-shaders/sum_rows.comp | 43 +- .../vulkan-shaders/timestep_embedding.comp | 7 +- .../src/ggml-vulkan/vulkan-shaders/types.comp | 55 +- .../src/ggml-vulkan/vulkan-shaders/utils.comp | 25 + .../vulkan-shaders/vulkan-shaders-gen.cpp | 320 +- 44 files changed, 4941 insertions(+), 1754 deletions(-) create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f694f1cd8..adea7783d 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5,6 +5,14 @@ #include "ggml-cpu.h" #endif +// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- +#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 +// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE +// to avoid conflicts with applications or other libraries who might use it. +namespace vk::detail { class DispatchLoaderDynamic; } +vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher(); +#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() + #include #include @@ -102,7 +110,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } struct ggml_backend_vk_context; -#define MAX_PARAMETER_COUNT 8 +#define MAX_PARAMETER_COUNT 12 +// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT. 
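+// (With MAX_PARAMETER_COUNT set to 12 above, this evaluates to 9.)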
+#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3) struct vk_pipeline_struct { std::string name; @@ -113,10 +123,14 @@ struct vk_pipeline_struct { uint32_t parameter_count; std::array wg_denoms; uint32_t align; + // true if fields have been set by ggml_vk_create_pipeline + bool initialized {}; // set to true to request the pipeline is compiled after the dryrun bool needed {}; // set to true when the shader has been compiled bool compiled {}; + // number of registers used, extracted from pipeline executable properties + uint32_t register_count {}; }; typedef std::shared_ptr vk_pipeline; @@ -225,21 +239,6 @@ enum vk_device_architecture { NVIDIA_PRE_TURING, }; -// HSK x HSV -enum FaHeadSizes { - FA_HEAD_SIZE_64, - FA_HEAD_SIZE_80, - FA_HEAD_SIZE_96, - FA_HEAD_SIZE_112, - FA_HEAD_SIZE_128, - FA_HEAD_SIZE_192, - FA_HEAD_SIZE_192_128, - FA_HEAD_SIZE_256, - FA_HEAD_SIZE_576_512, - FA_HEAD_SIZE_UNSUPPORTED, - FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED, -}; - static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { vk::PhysicalDeviceProperties props = device.getProperties(); @@ -343,6 +342,44 @@ enum vk_conv_shapes { CONV_SHAPE_COUNT, }; +enum dmmv_wg_sizes { + DMMV_WG_SIZE_SUBGROUP, + DMMV_WG_SIZE_LARGE, + DMMV_WG_SIZE_COUNT, +}; + +enum FaCodePath { + FA_SCALAR, + FA_COOPMAT1, + FA_COOPMAT2, +}; + +struct vk_fa_pipeline_state { + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc) + : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {} + + uint32_t HSK, HSV; + bool small_rows; + FaCodePath path; + bool aligned; + bool f32acc; + + bool operator<(const vk_fa_pipeline_state &b) const { + return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) < + std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc); + } +}; + +enum shader_reduction_mode { + SHADER_REDUCTION_MODE_SHMEM, + SHADER_REDUCTION_MODE_HYBRID, + SHADER_REDUCTION_MODE_SUBGROUP, + SHADER_REDUCTION_MODE_COUNT, +}; + +static constexpr uint32_t num_argsort_pipelines = 11; +static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); + struct vk_device_struct { std::recursive_mutex mutex; @@ -366,10 +403,20 @@ struct vk_device_struct { bool uma; bool prefer_host_memory; bool float_controls_rte_fp16; - bool subgroup_add; + bool subgroup_arithmetic; bool subgroup_shuffle; + bool subgroup_ballot; + bool subgroup_clustered; + bool multi_add; + bool shader_int64; + bool buffer_device_address; + + bool add_rms_fusion; + uint32_t partials_binding_alignment; bool integer_dot_product; + // 0: default, 1: force mmvq, -1: disable mmvq + int32_t mmvq_mode; bool subgroup_size_control; uint32_t subgroup_min_size; @@ -394,6 +441,8 @@ struct vk_device_struct { bool coopmat2; + bool pipeline_executable_properties_support {}; + size_t idx; bool mul_mat_l[GGML_TYPE_COUNT]; @@ -427,12 +476,15 @@ struct vk_device_struct { vk_pipeline pipeline_matmul_split_k_reduce; vk_pipeline pipeline_quantize_q8_1; + vk_pipeline pipeline_quantize_q8_1_x4; vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; - vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; - vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline 
pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; @@ -448,6 +500,12 @@ struct vk_device_struct { vk_pipeline pipeline_mul_norepeat[2][2][2]; vk_pipeline pipeline_div[2][2][2]; vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_add_rms[2][2][2]; + vk_pipeline pipeline_add_rms_norepeat[2][2][2]; + + // indexed by num_additional_fused_ops == num_adds - 1 + vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS]; + vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS]; vk_pipeline pipeline_add_id_f32; @@ -455,25 +513,30 @@ struct vk_device_struct { vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; + vk_pipeline pipeline_sqrt_f32; vk_pipeline pipeline_sin_f32; vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; - vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT]; + vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_rms_norm_mul_f32; + vk_pipeline pipeline_rms_norm_partials_f32; + vk_pipeline pipeline_rms_norm_mul_partials_f32; vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_l2_norm_f32; // [src/dst 0=fp32,1=fp16] + vk_pipeline pipeline_exp[2]; vk_pipeline pipeline_gelu[2]; vk_pipeline pipeline_gelu_erf[2]; vk_pipeline pipeline_gelu_quick[2]; @@ -481,6 +544,8 @@ struct vk_device_struct { vk_pipeline pipeline_relu[2]; vk_pipeline pipeline_tanh[2]; vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_hardsigmoid[2]; + vk_pipeline pipeline_hardswish[2]; vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; @@ -499,32 +564,31 @@ struct vk_device_struct { vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; - vk_pipeline pipeline_argsort_f32; + vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; + vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline 
pipeline_conv_transpose_1d_f32; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_opt_step_sgd_f32; vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; - vk_pipeline pipeline_conv2d_dw_whcn_f32; - vk_pipeline pipeline_conv2d_dw_cwhn_f32; + vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; + vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; - // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} - vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; - - vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; - - vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; + std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; vk_pipeline pipeline_flash_attn_split_k_reduce; - std::unordered_map pipelines; + std::vector all_pipelines; std::vector> pinned_memory; @@ -535,6 +599,8 @@ struct vk_device_struct { bool disable_fusion; bool disable_host_visible_vidmem; + bool allow_sysmem_fallback; + bool disable_graph_optimize; #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; @@ -555,15 +621,15 @@ struct vk_device_struct { compute_queue.cmd_pool.destroy(device); transfer_queue.cmd_pool.destroy(device); - for (auto& pipeline : pipelines) { - if (pipeline.second.expired()) { + for (auto& pipeline : all_pipelines) { + if (pipeline.expired()) { continue; } - vk_pipeline pl = pipeline.second.lock(); + vk_pipeline pl = pipeline.lock(); ggml_vk_destroy_pipeline(device, pl); } - pipelines.clear(); + all_pipelines.clear(); device.destroyDescriptorSetLayout(dsl); @@ -591,6 +657,7 @@ struct vk_buffer_struct { vk::MemoryPropertyFlags memory_property_flags; void * ptr; size_t size = 0; + vk::DeviceAddress bda_addr {}; vk_device device; @@ -756,6 +823,57 @@ static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_ten p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + return p; // offsets are initialized later in ggml_vk_op +} + +struct vk_op_pad_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + + uint32_t lp0; uint32_t rp0; + uint32_t lp1; uint32_t rp1; + uint32_t lp2; uint32_t rp2; + uint32_t lp3; uint32_t rp3; +}; + +static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) { + int64_t ne = ggml_nelements(dst); + GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max()); + + vk_op_pad_push_constants p{}; + p.ne = (uint32_t)ne; + + size_t src0_tsize = ggml_type_size(src0->type); + p.ne00 = (uint32_t)src0->ne[0]; + p.ne01 = (uint32_t)src0->ne[1]; + p.ne02 = (uint32_t)src0->ne[2]; + p.ne03 = (uint32_t)src0->ne[3]; + p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); + p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); + p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); + p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); + + size_t dst_tsize = ggml_type_size(dst->type); + p.ne10 = 
(uint32_t)dst->ne[0]; + p.ne11 = (uint32_t)dst->ne[1]; + p.ne12 = (uint32_t)dst->ne[2]; + p.ne13 = (uint32_t)dst->ne[3]; + p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); + p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); + p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); + p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + + p.lp0 = dst->op_params[0]; + p.rp0 = dst->op_params[1]; + p.lp1 = dst->op_params[2]; + p.rp1 = dst->op_params[3]; + p.lp2 = dst->op_params[4]; + p.rp2 = dst->op_params[5]; + p.lp3 = dst->op_params[6]; + p.rp3 = dst->op_params[7]; + return p; // fastdiv values and offsets are initialized later in ggml_vk_op } @@ -800,6 +918,19 @@ struct vk_op_binary_push_constants { float param1; float param2; int32_t param3; }; +struct vk_op_multi_add_push_constants { + // shape for dst + uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; + + // strides for srcs+dst + uint32_t nb[MAX_PARAMETER_COUNT][4]; + + uint32_t rms_partials; +}; +// update multi_add.comp if this changes +static_assert(MAX_PARAMETER_COUNT == 12); +static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); + struct vk_op_add_id_push_constants { uint32_t ne0; uint32_t ne1; @@ -855,11 +986,11 @@ struct vk_op_soft_max_push_constants { struct vk_op_argsort_push_constants { uint32_t ncols; - uint32_t ncols_pad; int32_t order; }; struct vk_op_im2col_push_constants { + uint64_t dst_addr; uint32_t batch_offset; uint32_t offset_delta; uint32_t IC; uint32_t IW; uint32_t IH; @@ -872,6 +1003,38 @@ struct vk_op_im2col_push_constants { int32_t d0; int32_t d1; }; +struct vk_op_im2col_3d_push_constants { + uint64_t dst_addr; + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +}; + struct vk_op_timestep_embedding_push_constants { uint32_t nb1; uint32_t dim; @@ -964,6 +1127,56 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); } +struct vk_op_conv_transpose_2d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1 + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t s0mp; uint32_t s0L; + uint32_t s1mp; uint32_t s1L; +}; + +template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) { + // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1 + init_fastdiv_values(p.KW, p.KWmp, p.KWL); + init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); + init_fastdiv_values(p.s0, p.s0mp, p.s0L); + init_fastdiv_values(p.s1, p.s1mp, p.s1L); +} + struct vk_op_conv2d_dw_push_constants { 
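The `init_fastdiv_values` / `init_pushconst_fastdiv` pairs above precompute a multiplier/shift pair so the shaders can divide by runtime-constant sizes (KW, KW*KH, OW, s0, ...) without an integer division. A self-contained host-side sketch of that multiply-and-shift scheme, checked against plain division (helper names are illustrative; the shaders evaluate the same multiplier/shift on the GPU side):

    #include <cassert>
    #include <cstdint>

    // Precompute (mp, L) so that n / d can be computed as (mulhi(n, mp) + n) >> L.
    static void make_fastdiv(uint32_t d, uint32_t &mp, uint32_t &L) {
        L = 0;
        while (L < 32 && (uint64_t{1} << L) < d) {
            L++;
        }
        mp = (uint32_t)(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    }

    // Host-side reference of the shader evaluation; the add is done in 64 bits
    // here purely to keep the check simple.
    static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
        uint64_t hi = ((uint64_t)n * mp) >> 32;  // mulhi(n, mp)
        return (uint32_t)((hi + n) >> L);
    }

    int main() {
        for (uint32_t d : {3u, 7u, 28u, 640u}) {
            uint32_t mp, L;
            make_fastdiv(d, mp, L);
            for (uint32_t n : {0u, 1u, d - 1, d, 12345u, 0xffffffffu}) {
                assert(fastdiv(n, mp, L) == n / d);
            }
        }
    }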
uint32_t ne; uint32_t batches; @@ -990,6 +1203,39 @@ struct vk_op_upscale_push_constants { float sf0; float sf1; float sf2; float sf3; }; +struct vk_op_sum_rows_push_constants +{ + uint32_t n_cols; + uint32_t ne01, ne02; + uint32_t nb01, nb02, nb03; + uint32_t nb11, nb12, nb13; + float weight; + uint32_t misalign_offsets; + uint32_t ne0_12mp, ne0_12L; + uint32_t ne0_1mp, ne0_1L; +}; + +static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) { + uint32_t type_size = (uint32_t)ggml_type_size(src->type); + vk_op_sum_rows_push_constants p = {}; + p.n_cols = (uint32_t)n_cols; + p.ne01 = (uint32_t)src->ne[1]; + p.ne02 = (uint32_t)src->ne[2]; + p.nb01 = (uint32_t)src->nb[1] / type_size; + p.nb02 = (uint32_t)src->nb[2] / type_size; + p.nb03 = (uint32_t)src->nb[3] / type_size; + p.nb11 = (uint32_t)dst->nb[1] / type_size; + p.nb12 = (uint32_t)dst->nb[2] / type_size; + p.nb13 = (uint32_t)dst->nb[3] / type_size; + p.weight = 1.0f; + return p; +} + +template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) { + init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L); + init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L); +} + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -999,6 +1245,14 @@ struct vk_staging_memcpy { size_t n; }; +struct vk_staging_memset { + vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {} + + void * dst; + uint32_t val; + size_t n; +}; + struct vk_context_struct { vk_submission * s; std::vector seqs; @@ -1007,6 +1261,7 @@ struct vk_context_struct { std::vector in_memcpys; std::vector out_memcpys; + std::vector memsets; vk_command_pool * p {}; }; @@ -1045,8 +1300,6 @@ static std::string format_size(size_t size) { return oss.str(); } -static std::mutex log_mutex; - class vk_memory_logger { public: vk_memory_logger(): total_device(0), total_host(0) {} @@ -1110,20 +1363,26 @@ class vk_perf_logger { return; } if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { - const uint64_t m = node->src[0]->ne[1]; - const uint64_t n = node->src[1]->ne[1]; - const uint64_t k = node->src[1]->ne[0]; - std::string name = ggml_op_name(node->op); - if (n == 1) { - name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k); - } else { - name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); + const uint64_t m = node->src[0]->ne[1]; + const uint64_t n = node->ne[1]; + const uint64_t k = node->src[1]->ne[0]; + const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3]; + std::string name = ggml_op_name(node->op); + if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) || + (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) { + name += "_VEC"; + } + name += " "; + name += ggml_type_name(node->src[0]->type); + name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); + if (batch > 1) { + name += " batch=" + std::to_string(batch); } timings[name].push_back(time); - flops[name].push_back(m * n * (k + (k - 1))); + flops[name].push_back(m * n * (k + (k - 1)) * batch); return; } - if (node->op == GGML_OP_CONV_2D) { + if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) { std::string name = ggml_op_name(node->op); ggml_tensor * knl = node->src[0]; uint64_t OW = node->ne[0]; @@ -1132,7 +1391,7 @@ class vk_perf_logger { uint64_t Cout 
= node->ne[2]; uint64_t KW = knl->ne[0]; uint64_t KH = knl->ne[1]; - uint64_t Cin = knl->ne[2]; + uint64_t Cin = node->src[1]->ne[2]; // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ uint64_t size_M = Cout; uint64_t size_K = Cin * KW * KH; @@ -1144,6 +1403,12 @@ class vk_perf_logger { timings[name].push_back(time); return; } + if (node->op == GGML_OP_RMS_NORM) { + std::string name = ggml_op_name(node->op); + name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")"; + timings[name].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -1158,10 +1423,25 @@ struct ggml_backend_vk_context { size_t semaphore_idx, event_idx; ggml_vk_garbage_collector gc; - size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; - vk_buffer prealloc_x, prealloc_y, prealloc_split_k; + size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset; + vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials; vk::Fence fence, almost_ready_fence; bool almost_ready_fence_pending {}; + // Set before op_add and unset after op_rms_norm to indicate that the add should + // write partial sums to accumulate the square of the vector components + bool do_add_rms_partials; + + // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. + vk_pipeline_struct * prealloc_y_last_pipeline_used {}; + const ggml_tensor * prealloc_y_last_tensor_used {}; + + // Track which nodes have been used since the last sync, and whether they were written to + std::vector unsynced_nodes_written; + std::vector unsynced_nodes_read; + // Track which prealloc buffers have pending reads that need to be synchronized. + // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set), + // and set to true after the buffer contents are consumed. + bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; vk_buffer buffer_pool[MAX_VK_BUFFERS]; @@ -1209,6 +1489,8 @@ struct ggml_backend_vk_buffer_context { }; #ifdef GGML_VULKAN_MEMORY_DEBUG +static std::mutex log_mutex; + void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { std::lock_guard guard(log_mutex); vk_buffer buf = buf_ref.lock(); @@ -1253,6 +1535,7 @@ struct vk_instance_t { PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {}; std::vector device_indices; + std::vector device_supports_membudget; vk_device devices[GGML_VK_MAX_DEVICES]; }; @@ -1370,7 +1653,9 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } vk::ComputePipelineCreateInfo compute_pipeline_create_info( - vk::PipelineCreateFlags{}, + device->pipeline_executable_properties_support ? + vk::PipelineCreateFlagBits::eCaptureStatisticsKHR : + vk::PipelineCreateFlags{}, pipeline_shader_create_info, pipeline->layout); @@ -1399,9 +1684,23 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast(duoni)); } + if (device->pipeline_executable_properties_support) { + vk::PipelineExecutableInfoKHR executableInfo; + executableInfo.pipeline = pipeline->pipeline; + + auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo); + for (auto & s : statistics) { + // "Register Count" is reported by NVIDIA drivers. 
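For reference, the same statistic can be read back through the raw C API; a minimal sketch, assuming the pipeline was created with VK_PIPELINE_CREATE_CAPTURE_STATISTICS_BIT_KHR and that VK_KHR_pipeline_executable_properties was enabled on the device (`query_register_count` is an illustrative helper and only executable 0 is inspected):

    #include <vulkan/vulkan.h>
    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Returns the "Register Count" statistic of executable 0, or 0 if the
    // driver does not report one.
    static uint64_t query_register_count(VkDevice device, VkPipeline pipeline) {
        auto get_stats = (PFN_vkGetPipelineExecutableStatisticsKHR)
            vkGetDeviceProcAddr(device, "vkGetPipelineExecutableStatisticsKHR");
        if (!get_stats) {
            return 0;  // extension not available
        }

        VkPipelineExecutableInfoKHR exec_info{};
        exec_info.sType           = VK_STRUCTURE_TYPE_PIPELINE_EXECUTABLE_INFO_KHR;
        exec_info.pipeline        = pipeline;
        exec_info.executableIndex = 0;

        uint32_t count = 0;
        get_stats(device, &exec_info, &count, nullptr);
        if (count == 0) {
            return 0;
        }

        std::vector<VkPipelineExecutableStatisticKHR> stats(count);
        for (auto &s : stats) {
            s.sType = VK_STRUCTURE_TYPE_PIPELINE_EXECUTABLE_STATISTIC_KHR;
        }
        get_stats(device, &exec_info, &count, stats.data());

        for (const auto &s : stats) {
            if (s.format == VK_PIPELINE_EXECUTABLE_STATISTIC_FORMAT_UINT64_KHR &&
                std::strcmp(s.name, "Register Count") == 0) {
                return s.value.u64;
            }
        }
        return 0;
    }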
+ if (strcmp(s.name, "Register Count") == 0) { + VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers"); + pipeline->register_count = (uint32_t)s.value.u64; + } + } + } + { std::lock_guard guard(device->mutex); - device->pipelines.insert({ pipeline->name, pipeline }); + device->all_pipelines.push_back(pipeline); } { @@ -1705,8 +2004,8 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr return UINT32_MAX; } -static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { - VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")"); +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { + VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); if (size > device->max_memory_allocation_size) { throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); } @@ -1718,10 +2017,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor return buf; } + vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst; + vk::MemoryAllocateFlags mem_flags {}; + if (device->buffer_device_address) { + usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress; + mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress; + } + vk::BufferCreateInfo buffer_create_info{ vk::BufferCreateFlags(), size, - vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst, + usage_flags, vk::SharingMode::eExclusive, 0, nullptr, @@ -1733,42 +2039,36 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); - uint32_t memory_type_index = UINT32_MAX; + const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags }; - memory_type_index = find_properties(&mem_props, &mem_req, req_flags); - buf->memory_property_flags = req_flags; + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; - if (memory_type_index == UINT32_MAX && fallback_flags) { - memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); - buf->memory_property_flags = fallback_flags; + uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags); + + if (memory_type_index == UINT32_MAX) { + continue; + } + buf->memory_property_flags = req_flags; + + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info }); + break; + } catch (const vk::SystemError& e) { + // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end()) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } } - if (memory_type_index == UINT32_MAX) { + if (!buf->device_memory) { device->device.destroyBuffer(buf->buffer); throw vk::OutOfDeviceMemoryError("No suitable memory type found"); } - try { - buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); - } catch (const vk::SystemError& e) 
{ - if (buf->memory_property_flags != fallback_flags) { - // Try again with fallback flags - memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); - buf->memory_property_flags = fallback_flags; - - try { - buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); - } - catch (const vk::SystemError& e) { - device->device.destroyBuffer(buf->buffer); - throw e; - } - } else { - // Out of Host/Device memory, clean up buffer - device->device.destroyBuffer(buf->buffer); - throw e; - } - } buf->ptr = nullptr; if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { @@ -1780,6 +2080,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor buf->device = device; buf->size = size; + if (device->buffer_device_address) { + const vk::BufferDeviceAddressInfo addressInfo(buf->buffer); + buf->bda_addr = device->device.getBufferAddress(addressInfo); + } + #ifdef GGML_VULKAN_MEMORY_DEBUG device->memory_logger->log_allocation(buf, size); #endif @@ -1789,7 +2094,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { try { - return ggml_vk_create_buffer(device, size, req_flags, fallback_flags); + return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags}); } catch (const vk::SystemError& e) { std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl; std::cerr << "ggml_vulkan: " << e.what() << std::endl; @@ -1801,15 +2106,29 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { vk_buffer buf; try { if (device->prefer_host_memory) { - buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal}); } else if (device->uma) { // Fall back to host memory type - buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); } else if (device->disable_host_visible_vidmem) { - buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal); + if (device->allow_sysmem_fallback) { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + } else { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + } } else { // use rebar if available, otherwise fallback to device only visible memory - buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); + if (device->allow_sysmem_fallback) { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | 
vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + } else { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal}); + } } } catch (const vk::SystemError& e) { std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; @@ -1838,14 +2157,18 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { return { buf, 0, VK_WHOLE_SIZE }; } -static void ggml_vk_sync_buffers(vk_context& ctx) { +static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) { VK_LOG_DEBUG("ggml_vk_sync_buffers()"); - const bool transfer_queue = ctx->p->q->transfer_only; + const bool transfer_queue = subctx->p->q->transfer_only; - ctx->s->buffer.pipelineBarrier( - ctx->p->q->stage_flags, - ctx->p->q->stage_flags, + if (ctx) { + ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; + } + + subctx->s->buffer.pipelineBarrier( + subctx->p->q->stage_flags, + subctx->p->q->stage_flags, {}, { { { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, @@ -1872,47 +2195,12 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ); } -enum FaCodePath { - FA_SCALAR, - FA_COOPMAT1, - FA_COOPMAT2, -}; - -static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) { - if (hsk != 192 && hsk != 576 && hsk != hsv) { - return FA_HEAD_SIZE_UNSUPPORTED; - } - switch (hsk) { - case 64: return FA_HEAD_SIZE_64; - case 80: return FA_HEAD_SIZE_80; - case 96: return FA_HEAD_SIZE_96; - case 112: return FA_HEAD_SIZE_112; - case 128: return FA_HEAD_SIZE_128; - case 192: - if (hsv == 192) { - return FA_HEAD_SIZE_192; - } else if (hsv == 128) { - return FA_HEAD_SIZE_192_128; - } else { - return FA_HEAD_SIZE_UNSUPPORTED; - } - case 256: return FA_HEAD_SIZE_256; - case 576: - if (hsv == 512) { - return FA_HEAD_SIZE_576_512; - } else { - return FA_HEAD_SIZE_UNSUPPORTED; - } - default: return FA_HEAD_SIZE_UNSUPPORTED; - } -} - // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { - if (hsv >= 512) { + if (hsv >= 192) { return 2; } else { return 8; @@ -1942,7 +2230,13 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if (small_rows) { return {scalar_flash_attention_num_small_rows, 64}; } else { - return {get_fa_scalar_num_large_rows(hsv), 32}; + if ((hsv | hsk) & 8) { + // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter + // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. 
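The ordered fallback lists passed to ggml_vk_create_buffer / ggml_vk_create_buffer_device above boil down to "take the first memory type that satisfies both the buffer's memoryTypeBits and the requested property set"; a minimal sketch of that selection over the raw C structures (`pick_memory_type` is an illustrative name, not part of the patch):

    #include <vulkan/vulkan.h>
    #include <cstdint>
    #include <initializer_list>

    static uint32_t pick_memory_type(const VkPhysicalDeviceMemoryProperties &props,
                                     const VkMemoryRequirements &req,
                                     std::initializer_list<VkMemoryPropertyFlags> candidates) {
        for (VkMemoryPropertyFlags wanted : candidates) {
            for (uint32_t i = 0; i < props.memoryTypeCount; ++i) {
                const bool allowed = (req.memoryTypeBits & (1u << i)) != 0;
                const bool matches = (props.memoryTypes[i].propertyFlags & wanted) == wanted;
                if (allowed && matches) {
                    return i;  // first (most preferred) candidate that fits wins
                }
            }
        }
        return UINT32_MAX;  // caller reports the allocation failure
    }

    // e.g. prefer device-local + host-visible (ReBAR), then plain device-local,
    // then host-visible system memory:
    //   pick_memory_type(props, req, {
    //       VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT | VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
    //       VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
    //       VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT });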
+ return {get_fa_scalar_num_large_rows(hsv), 64}; + } else { + return {get_fa_scalar_num_large_rows(hsv), 32}; + } } } @@ -1960,8 +2254,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 } // small cols to reduce register count - if (ggml_is_quantized(type) || hsk >= 256) { - if (hsk >= 512) { + if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) { + if (hsk >= 512 || hsv >= 512) { return {32, 32}; } else { return {64, 32}; @@ -1970,6 +2264,10 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 return {64, 64}; } +static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) { + return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1]; +} + static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { uint32_t lut_size = 0; @@ -2008,10 +2306,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t warps = warptile[0] / warptile[10]; const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; - const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0; + const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; + const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0; - const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; + const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " @@ -2095,8 +2394,17 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); + const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? 
device->subgroup_min_size : device->subgroup_size; + const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u); + const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u); + const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u); + + const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) || + (device->subgroup_size_control && device->subgroup_max_size >= 16); + // mulmat std::vector l_warptile, m_warptile, s_warptile, + l_warptile_id, m_warptile_id, s_warptile_id, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, @@ -2133,9 +2441,9 @@ static void ggml_vk_load_shaders(vk_device& device) { s_mmq_wg_denoms_k = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 128, 16, 0 }; - m_warptile_mmqid = { 256, 128, 64, 16, 0 }; - s_warptile_mmqid = { 256, 128, 64, 16, 0 }; + l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size }; + m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; + s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; @@ -2167,9 +2475,18 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; + m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; + s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; + + l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 }; + m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; + s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + // chip specific tuning if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; @@ -2195,14 +2512,14 @@ static void ggml_vk_load_shaders(vk_device& device) { } // Disable mul_mat_id if not enough shared memory is available - if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) { + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) { device->mul_mat_id_s[i] = false; device->mul_mat_id_m[i] = false; device->mul_mat_id_l[i] = false; - } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) { + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) { device->mul_mat_id_m[i] = false; device->mul_mat_id_l[i] = false; - } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) { + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) { device->mul_mat_id_l[i] = false; } } @@ -2225,7 +2542,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } 
std::vector> compiles; - auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { @@ -2235,11 +2552,14 @@ static void ggml_vk_load_shaders(vk_device& device) { if (!pipeline) { pipeline = std::make_shared(); + } + if (!pipeline->initialized) { pipeline->name = name; pipeline->parameter_count = parameter_count; pipeline->push_constant_size = push_constant_size; pipeline->wg_denoms = wg_denoms; pipeline->align = align; + pipeline->initialized = true; } if (!pipeline->needed || pipeline->compiled) { @@ -2259,6 +2579,14 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; + auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint, + uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, + uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint, + parameter_count, push_constant_size, wg_denoms, specialization_constants, + align, disable_robustness, require_full_subgroups, required_subgroup_size); + }; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; }; @@ -2285,26 +2613,30 @@ static void ggml_vk_load_shaders(vk_device& device) { return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; }; -#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ - #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512) + for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ + uint32_t HSK = fa.first.HSK; \ + uint32_t HSV = fa.first.HSV; \ + bool small_rows = fa.first.small_rows; \ + FaCodePath path = fa.first.path; \ + bool aligned = fa.first.aligned; \ + bool f32acc = fa.first.f32acc; \ + if (path == FAPATH) { \ + if (aligned) { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } \ + } else { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ + } \ + } \ + } \ + } CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) @@ -2327,7 +2659,6 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) } #endif -#undef CREATE_FA2 #undef CREATE_FA #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -2374,32 +2705,34 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + GGML_ASSERT(device->subgroup_ballot); + + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) if (device->coopmat_bf16_support) { - CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) } #endif - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - 
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], 
matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 } else @@ -2486,79 +2819,56 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + GGML_ASSERT(device->subgroup_ballot); + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) if (device->coopmat_bf16_support) { - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); } #endif - if (device->coopmat_acc_f16_support) { - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - 
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - } else { - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, 
matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - } + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], 
matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM2 #undef CREATE_MM } else #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->fp16) { // Create 6 variants, {s,m,l}x{unaligned,aligned} -#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC 
## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
 if (device->mul_mat ## ID ## _l[TYPE]) { \
@@ -2575,38 +2885,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
 } \
 // Create 2 variants, {f16,f32} accumulator
-#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+ CREATE_MM(TYPE, PIPELINE_NAME .
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], 
matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, 
vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -2618,51 +2928,77 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 
mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + } else { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, 
_id); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_0, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + } #undef CREATE_MM2 #undef CREATE_MMQ #undef CREATE_MM } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} -#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, 
#NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
 if (device->mul_mat ## ID ## _l[TYPE]) \
@@ -2672,34 +3008,34 @@ static void
ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_K, 
pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, 
matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -2711,33 +3047,59 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, 
warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + } else { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_F16, 
pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, 
warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ4_NL, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + } } // reusing CREATE_MM from the fp32 path if ((device->coopmat2 || device->coopmat_support) @@ -2754,8 +3116,8 @@ static void ggml_vk_load_shaders(vk_device& device) { m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); } #undef CREATE_MM @@ -2773,54 +3135,90 @@ static void ggml_vk_load_shaders(vk_device& device) { rm_stdq = 2; uint32_t rm_iq = 2 * rm_kq; - for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], 
"mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN; + // Ensure a subgroup size >= 16 is available + const bool use_subgroups16 = use_subgroups && subgroup_min_size_16; - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, 
{device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - 
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size; + const uint32_t subgroup_size16 = std::max(subgroup_size, 16u); + + const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0; + const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0; + + for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) { + const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4); + const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4); + + const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP : + (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID : + SHADER_REDUCTION_MODE_SHMEM; + + const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP : + (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? 
SHADER_REDUCTION_MODE_HYBRID : + SHADER_REDUCTION_MODE_SHMEM; + + for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + 
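// [Editor's sketch, not part of the patch] The reduc/reduc16 values chosen above decide which
// precompiled shader variant each dequant mul_mat_vec pipeline in this loop uses. Going by the
// names in this hunk, the workgroup-size variants are DMMV_WG_SIZE_SUBGROUP (one subgroup per
// workgroup) and DMMV_WG_SIZE_LARGE (a 4x-subgroup workgroup), and the reduction falls back to
// shared memory when subgroup arithmetic is unavailable or disabled (AMD GCN). A minimal
// standalone restatement of that selection, assuming the enum declarations (which are outside
// this hunk) look roughly like this:
#include <cstdint>
enum shader_reduction_mode { SHADER_REDUCTION_MODE_SUBGROUP, SHADER_REDUCTION_MODE_HYBRID, SHADER_REDUCTION_MODE_SHMEM };
enum dmmv_wg_size { DMMV_WG_SIZE_SUBGROUP, DMMV_WG_SIZE_LARGE, DMMV_WG_SIZE_COUNT };
static shader_reduction_mode pick_dmmv_reduction(bool use_subgroups, uint32_t w) {
    if (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) return SHADER_REDUCTION_MODE_SUBGROUP; // workgroup == one subgroup
    if (use_subgroups && w == DMMV_WG_SIZE_LARGE)    return SHADER_REDUCTION_MODE_HYBRID;   // presumably subgroup + shared-memory reduce
    return SHADER_REDUCTION_MODE_SHMEM;                                                     // shared-memory-only fallback
}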
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32", arr_dmmv_iq1_s_f32_f32_len[reduc16], arr_dmmv_iq1_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32", arr_dmmv_iq1_m_f32_f32_len[reduc16], arr_dmmv_iq1_m_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32", arr_dmmv_iq2_xs_f32_f32_len[reduc16], arr_dmmv_iq2_xs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32", arr_dmmv_iq2_s_f32_f32_len[reduc16], arr_dmmv_iq2_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32", arr_dmmv_iq3_s_f32_f32_len[reduc16], arr_dmmv_iq3_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + 
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], 
"mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32", arr_dmmv_iq1_s_f16_f32_len[reduc16], arr_dmmv_iq1_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32", arr_dmmv_iq1_m_f16_f32_len[reduc16], arr_dmmv_iq1_m_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32", arr_dmmv_iq2_xs_f16_f32_len[reduc16], arr_dmmv_iq2_xs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32", 
arr_dmmv_iq2_s_f16_f32_len[reduc16], arr_dmmv_iq2_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32", arr_dmmv_iq3_s_f16_f32_len[reduc16], arr_dmmv_iq3_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; + const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? 
subgroup_size_int : (subgroup_size_int * 4); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + } +#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT + } } ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -2879,6 +3277,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k", get_rows_q2_k_len, get_rows_q2_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k", get_rows_q3_k_len, get_rows_q3_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k", get_rows_q4_k_len, get_rows_q4_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k", get_rows_q5_k_len, 
get_rows_q5_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k", get_rows_q6_k_len, get_rows_q6_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2898,6 +3301,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32", get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2911,21 +3319,32 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, 
device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + + if (device->subgroup_clustered && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + } for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { - if (device->subgroup_add && device->subgroup_require_full_support) { - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { + ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); } else { - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); + ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); } } ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", 
group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2934,12 +3353,16 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); 
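// [Editor's sketch, not part of the patch] The four rms_norm pipelines added above are built
// from only two shaders, rms_norm_f32 and rms_norm_partials_f32, each compiled twice with the
// second specialization constant set to 0 (plain) or 1 (the "mul" fused variant). A dispatch-side
// selection consistent with that layout might look like the hypothetical helper below; the real
// dispatch code is not part of this hunk.
static vk_pipeline pick_rms_norm_pipeline(vk_device& device, bool fuse_mul, bool use_partials) {
    if (use_partials) {
        return fuse_mul ? device->pipeline_rms_norm_mul_partials_f32
                        : device->pipeline_rms_norm_partials_f32;
    }
    return fuse_mul ? device->pipeline_rms_norm_mul_f32
                    : device->pipeline_rms_norm_f32;
}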
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -2957,27 +3380,26 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); } +#define SET_ROWS(itype, rte) \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, 
sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + SET_ROWS(_i32, _rte) + SET_ROWS(_i64, _rte) } else { - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - 
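// [Editor's note, not part of the patch] For reference, the first statement of the
// SET_ROWS(itype, rte) macro above, expanded by hand for the SET_ROWS(_i32, _rte) invocation:
// the string literal and token pasting select the i32-index, round-to-nearest-even shader
// variant, and the other invocations differ only in the pasted _i32/_i64 and _rte suffixes.
//
//   ggml_vk_create_pipeline(device, device->pipeline_set_rows_i32[GGML_TYPE_F32], "set_rows_f32_i32",
//                           set_rows_f32_i32_rte_len, set_rows_f32_i32_rte_data, "main", 3,
//                           sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);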
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + SET_ROWS(_i32, ) + SET_ROWS(_i64, ) } +#undef SET_ROWS + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); @@ -2995,22 +3417,31 @@ static void ggml_vk_load_shaders(vk_device& device) { }; bool rte = device->float_controls_rte_fp16; -#define CREATE_BINARY(name, namemod, spec) \ +#define CREATE_BINARY(name, namemod, spec, bindings) \ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ + ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ - "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); + "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); - CREATE_BINARY(add, , {0}) - CREATE_BINARY(add, _norepeat, {1}) - CREATE_BINARY(sub, , {0}) - CREATE_BINARY(sub, _norepeat, {1}) - CREATE_BINARY(mul, , {0}) - CREATE_BINARY(mul, _norepeat, {1}) - CREATE_BINARY(div, , {0}) - CREATE_BINARY(div, _norepeat, {1}) + CREATE_BINARY(add, , {0}, 4) + CREATE_BINARY(add, _norepeat, {1}, 4) + CREATE_BINARY(sub, , {0}, 3) + CREATE_BINARY(sub, _norepeat, {1}, 3) + CREATE_BINARY(mul, , {0}, 3) + CREATE_BINARY(mul, _norepeat, {1}, 3) + CREATE_BINARY(div, , {0}, 3) + CREATE_BINARY(div, _norepeat, {1}, 3) + CREATE_BINARY(add_rms, , {0}, 4) + CREATE_BINARY(add_rms, _norepeat, {1}, 4) #undef CREATE_BINARY + if (device->multi_add) { + for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) { + ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); + ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); + } + } + ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -3026,12 +3457,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, 
"sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -3049,8 +3481,21 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(relu) CREATE_UNARY(tanh) CREATE_UNARY(sigmoid) + CREATE_UNARY(hardsigmoid) + CREATE_UNARY(hardswish) #undef CREATE_UNARY +#define CREATE_UNARY_RTE(name) \ + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + } + CREATE_UNARY_RTE(exp) +#undef CREATE_UNARY_RTE + #define CREATE_GLU(name) \ if (device->float_controls_rte_fp16) { \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ @@ -3077,7 +3522,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { 
device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -3096,19 +3541,30 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } - ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); + for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { + ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); +#define IM2COL(bda) \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, 
sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + } + if (device->shader_int64 && device->buffer_device_address) { + IM2COL(_bda) } else { - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + IM2COL() } ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); @@ -3123,7 +3579,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - // conv2d + ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + // conv2d, conv_transpose_2d for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { uint32_t conv2d_WG_SIZE = 256; uint32_t conv2d_BS_K = 128; @@ -3198,35 +3656,36 @@ static void ggml_vk_load_shaders(vk_device& device) { std::array wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 }; std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; +#define CREATE_CONV(name, type_suffix, spv_suffix) \ + ggml_vk_create_pipeline( \ + device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \ + name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ + sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); +#define CREATE_CONVS(spv_suffix) \ + CREATE_CONV(conv2d, _f32, spv_suffix) \ + CREATE_CONV(conv2d, _f16_f32, spv_suffix) \ + if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \ + CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \ + CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \ + } #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + CREATE_CONVS(_cm2) } else #endif if (conv2d_UNROLL) { - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, 
use_collectives); - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + CREATE_CONVS(_unroll) } else { - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + CREATE_CONVS( ) } +#undef CREATE_CONV +#undef CREATE_CONVS } ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); for (auto &c : compiles) { c.wait(); @@ -3271,6 +3730,12 @@ static vk_device ggml_vk_get_device(size_t idx) { const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM"); device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr; + const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK"); + device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr; + + const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE"); + device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; @@ -3278,6 +3743,7 @@ static vk_device ggml_vk_get_device(size_t idx) { bool amd_shader_core_properties2 = false; bool pipeline_robustness = false; bool coopmat2_support = false; + bool pipeline_executable_properties_support = false; device->coopmat_support = false; device->integer_dot_product = false; bool bfloat16_support = false; @@ -3320,6 +3786,8 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) { + pipeline_executable_properties_support = true; } } @@ -3409,11 +3877,21 @@ static vk_device ggml_vk_get_device(size_t idx) { } device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; - device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && - (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); - + device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & 
vk::SubgroupFeatureFlagBits::eArithmetic); +#ifdef __APPLE__ + // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846) + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_arithmetic = false; + } +#endif device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); + device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered); + + device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot); const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; @@ -3536,8 +4014,18 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_integer_dot_product"); } + VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {}; + pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR; + if (pipeline_executable_properties_support) { + last_struct->pNext = (VkBaseOutStructure *)&pep_features; + last_struct = (VkBaseOutStructure *)&pep_features; + device_extensions.push_back("VK_KHR_pipeline_executable_properties"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); + device->pipeline_executable_properties_support = pipeline_executable_properties_support; + device->fp16 = device->fp16 && vk12_features.shaderFloat16; #if defined(VK_KHR_shader_bfloat16) @@ -3548,6 +4036,15 @@ static vk_device ggml_vk_get_device(size_t idx) { device->pipeline_robustness = pl_robustness_features.pipelineRobustness; + device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && + device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) && + vk12_features.runtimeDescriptorArray && + device->vendor_id != VK_VENDOR_ID_INTEL && + getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr; + + device->shader_int64 = device_features2.features.shaderInt64; + device->buffer_device_address = vk12_features.bufferDeviceAddress; + if (device->subgroup_size_control) { device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; @@ -3558,9 +4055,7 @@ static vk_device ggml_vk_get_device(size_t idx) { (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) && subgroup_size_control_features.subgroupSizeControl; - if (device->subgroup_size_control) { - device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; - } + device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; #if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; @@ -3861,6 +4356,19 @@ static vk_device ggml_vk_get_device(size_t idx) { device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; + device->add_rms_fusion = !device->disable_fusion && + device->subgroup_arithmetic && + device->vendor_id != VK_VENDOR_ID_INTEL; + device->partials_binding_alignment = + std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment); + + device->mmvq_mode = 0; + if (getenv("GGML_VK_DISABLE_MMVQ")) { + device->mmvq_mode = -1; + } else if 
(getenv("GGML_VK_FORCE_MMVQ")) { + device->mmvq_mode = 1; + } + return device; } @@ -4025,10 +4533,16 @@ static void ggml_vk_print_gpu_info(size_t idx) { } } -static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); +static bool ggml_vk_instance_validation_ext_available(); static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); - static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); +static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev); + +static vk::detail::DispatchLoaderDynamic ggml_vk_default_dispatcher_instance; + +vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher() { + return ggml_vk_default_dispatcher_instance; +} static void ggml_vk_instance_init() { if (vk_instance_initialized) { @@ -4036,17 +4550,20 @@ static void ggml_vk_instance_init() { } VK_LOG_DEBUG("ggml_vk_instance_init()"); + // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- + ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr); + uint32_t api_version = vk::enumerateInstanceVersion(); if (api_version < VK_API_VERSION_1_2) { std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl; - GGML_ABORT("fatal error"); + throw vk::SystemError(vk::Result::eErrorFeatureNotPresent, "Vulkan 1.2 required"); } vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); - const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); + const bool validation_ext = ggml_vk_instance_validation_ext_available(); #ifdef __APPLE__ const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); #endif @@ -4099,15 +4616,19 @@ static void ggml_vk_instance_init() { vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT"); vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT"); vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT"); - } vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; + // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- + VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); if (devices_env != nullptr) { - size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + size_t num_available_devices = devices.size(); std::string devices(devices_env); std::replace(devices.begin(), devices.end(), ',', ' '); @@ -4122,7 +4643,6 @@ static void ggml_vk_instance_init() { vk_instance.device_indices.push_back(tmp); } } else { - std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); // If no vulkan devices are found, return early if (devices.empty()) { GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); @@ -4138,7 +4658,7 @@ static void ggml_vk_instance_init() { new_driver.pNext = &new_id; 
devices[i].getProperties2(&new_props); - if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) { + if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) { // Check if there are two physical devices corresponding to the same GPU auto old_device = std::find_if( vk_instance.device_indices.begin(), @@ -4208,7 +4728,7 @@ static void ggml_vk_instance_init() { } } - // If no dedicated GPUs found, fall back to the first non-CPU device. + // If no GPUs found, fall back to the first non-CPU device. // If only CPU devices are available, return without devices. if (vk_instance.device_indices.empty()) { for (size_t i = 0; i < devices.size(); i++) { @@ -4227,6 +4747,19 @@ static void ggml_vk_instance_init() { GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { + vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]]; + std::vector extensionprops = vkdev.enumerateDeviceExtensionProperties(); + + bool membudget_supported = false; + for (const auto & ext : extensionprops) { + if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) { + membudget_supported = true; + break; + } + } + + vk_instance.device_supports_membudget.push_back(membudget_supported); + ggml_vk_print_gpu_info(i); } } @@ -4371,11 +4904,24 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; } -static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) { VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); - GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1); GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); + if (b_type == GGML_TYPE_Q8_1) { + switch (a_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + break; + default: + return nullptr; + } + } + switch (a_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: @@ -4405,7 +4951,31 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * return nullptr; } - return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1]; + // heuristic to choose workgroup size + uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + // Prefer larger workgroups when M is small, to spread the work out more + // and keep more SMs busy. + // q6_k seems to prefer small workgroup size even for "medium" values of M. 
+ if (a_type == GGML_TYPE_Q6_K) { + if (m < 4096 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } else { + if (m <= 8192 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } + } + + if (b_type == GGML_TYPE_Q8_1) { + if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + } + return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1]; + } + + return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1]; } static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { @@ -4460,11 +5030,21 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co return nullptr; } - return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; + // XXX TODO 'prec' is not actually allowed in mul_mat_id. + bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/; + bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr; + bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr; + + if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) { + return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc; + } else { + GGML_ASSERT(support_fp32acc); + return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; + } } static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { - VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()"); GGML_ASSERT(b_type == GGML_TYPE_F32); switch (a_type) { @@ -4567,8 +5147,8 @@ static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_ static void * ggml_vk_host_malloc(vk_device& device, size_t size) { VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); vk_buffer buf = ggml_vk_create_buffer(device, size, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) { fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n", @@ -4732,6 +5312,14 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect } } +static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector* memsets = nullptr) { + if (memsets == nullptr) { + memset(dst, val, size); + } else { + memsets->emplace_back(dst, val, size); + } +} + static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) { if (device->sync_staging == nullptr || device->sync_staging->size < size) { VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")"); @@ -4799,7 +5387,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } } - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); 
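// deferred_memcpy()/deferred_memset() above record host-side writes instead of performing
// them right away; the queued entries are replayed in one spot next to command submission
// (see the loop over subctx->memsets further down). A stripped-down sketch of that
// record-then-replay shape, with illustrative names and <vector>/<cstring> assumed:
struct pending_memset_sketch { void * dst; uint32_t val; size_t n; };

static void record_memset_sketch(std::vector<pending_memset_sketch> * queue, void * dst, uint32_t val, size_t n) {
    if (queue == nullptr) {
        memset(dst, val, n);                 // no queue supplied: write immediately, like the fallback above
    } else {
        queue->push_back({dst, val, n});     // deferred: replayed by the submit path
    }
}

static void replay_memsets_sketch(std::vector<pending_memset_sketch> & queue) {
    for (const auto & m : queue) {
        memset(m.dst, m.val, m.n);
    }
    queue.clear();
}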
subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); return; } @@ -4814,7 +5402,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); VkBufferCopy buf_copy{ 0, offset, copy_size }; - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); for (uint64_t i3 = 0; i3 < ne3; i3++) { @@ -4868,7 +5456,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz } } - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); return; } @@ -4889,7 +5477,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz offset, copy_size}; - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); if (width == spitch) { @@ -4927,6 +5515,10 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * memcpy(cpy.dst, cpy.src, cpy.n); } + for (auto& mset : subctx->memsets) { + memset(mset.dst, mset.val, mset.n); + } + ggml_vk_submit(subctx, dst->device->fence); VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); dst->device->device.resetFences({ dst->device->fence }); @@ -4969,7 +5561,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size if (buf != nullptr) { // Memory is pinned, use as staging buffer - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); return; @@ -4986,7 +5578,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size vk_buffer& staging_buffer = src->device->sync_staging; - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); @@ -5066,12 +5658,25 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")"); + if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && + dst->device->uma) { + deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets); + return; + } + + // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); } static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); + if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && + dst->device->uma) { + memset((uint8_t*)dst->ptr + offset, c, size); + return; + } + std::lock_guard guard(dst->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); @@ -5084,8 +5689,12 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz ggml_vk_queue_command_pools_cleanup(dst->device); } -static uint32_t 
ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) { - VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); +static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")"); + + if (disable_split_k) { + return 1; + } uint32_t split_k = 1; if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) { @@ -5176,13 +5785,16 @@ static void ggml_vk_matmul( uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t padded_n) { VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); - ggml_vk_sync_buffers(subctx); if (split_k == 1) { const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); return; } + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + GGML_ASSERT(batch_stride_d == m * n); // Round the split size up to a multiple of 256 (k-quant alignment) @@ -5192,9 +5804,10 @@ static void ggml_vk_matmul( const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; // Make sure enough workgroups get assigned for split k to work ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); + ctx->prealloc_split_k_need_sync = true; } static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { @@ -5239,7 +5852,6 @@ static void ggml_vk_matmul_id( "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: 
" << nbi1 << ", ne11: " << ne11 << ")"); - ggml_vk_sync_buffers(subctx); const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, nei0, nei1, nbi1, ne11, padded_n }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as }); @@ -5292,6 +5904,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f32_bf16; } } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_i32; + } else { + return ctx->device->pipeline_cpy_f32_i32; + } + } + if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_i32_f32; + } else { + return ctx->device->pipeline_cpy_i32_f32; + } + } if (src->type == GGML_TYPE_F32) { switch (to) { case GGML_TYPE_Q4_0: @@ -5370,30 +5996,30 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; init_pushconst_fastdiv(pc); - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); } -static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { +static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) { switch(type) { case GGML_TYPE_Q8_1: - return ctx->device->pipeline_quantize_q8_1; + return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1; default: std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl; GGML_ABORT("fatal error"); } } -static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) { +static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) { VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")"); - vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + vk_pipeline pipeline = use_x4_blocks ? 
ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true) : ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false); - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); + ggml_vk_sync_buffers(ctx, subctx); } -static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; @@ -5411,8 +6037,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t ne12 = src1->ne[2]; const uint64_t ne13 = src1->ne[3]; - const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; + const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type); + const uint32_t stride_batch_d = stride_d*ne21; const uint64_t r2 = ne12 / ne02; const uint64_t r3 = ne13 / ne03; @@ -5481,7 +6108,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const int y_ne = padded_n * ne10; const int d_ne = ne11 * ne01; - const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); + const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); @@ -5507,12 +6134,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT if (quantize_y) { - to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); } if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; - const uint64_t y_sz_upd = y_sz * ne12 * ne13; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } const uint64_t split_k_size = split_k > 1 ? 
d_sz * ne12 * ne13 * split_k : 0; if ( (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || @@ -5578,25 +6208,47 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); } else if (quantize_y) { d_Y = ctx->prealloc_y; - GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)); + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; GGML_ASSERT(qy_sz == y_sz); } + if (x_non_contig || qx_needs_dequant) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (x_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_sync_buffers(ctx, subctx); } if (y_non_contig) { - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } } if (quantize_y) { - ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13); + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } } uint32_t stride_batch_x = ne00*ne01; @@ -5610,15 +6262,75 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + + // No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange. 
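// Two sizing details used in this function, spelled out: the q8_1_x4 staging size is
// rounded up with CEIL_DIV(sz, 144) * 144, and 144 bytes corresponds to four 36-byte q8_1
// blocks packed together, so the buffer always covers whole packed groups (a reading of
// the shader layout, not stated in the patch). Separately, the destination binding below
// is clamped to maxStorageBufferRange instead of using VK_WHOLE_SIZE; a minimal form of
// that clamp (assumes <algorithm>):
static VkDeviceSize clamp_storage_range_sketch(VkDeviceSize buf_size, VkDeviceSize offset, VkDeviceSize max_range) {
    // largest range that can legally be bound at `offset` as a single storage-buffer binding
    return std::min(buf_size - offset, max_range);
}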
+ VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange}); + // compute ggml_vk_matmul( ctx, subctx, pipeline, - { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, - { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, + { d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, - ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, + ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d, split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n ); // NOLINT + + if (x_non_contig || qx_needs_dequant) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig || quantize_y) { + ctx->prealloc_y_need_sync = true; + } +} + +// Device tuning +static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) { + if (device->mmvq_mode == 1) { + return true; + } else if (device->mmvq_mode == -1) { + return false; + } + + // MMVQ is generally good for batches + if (n > 1) { + return true; + } + + switch (device->vendor_id) { + case VK_VENDOR_ID_NVIDIA: + switch (src0_type) { + case GGML_TYPE_Q8_0: + return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; + default: + return true; + } + case VK_VENDOR_ID_AMD: + switch (src0_type) { + case GGML_TYPE_Q8_0: + return device->architecture == vk_device_architecture::AMD_GCN; + default: + return true; + } + case VK_VENDOR_ID_INTEL: + switch (src0_type) { + // From tests on A770 Linux, may need more tuning + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_1: + return false; + default: + return true; + } + default: + return true; + } + + GGML_UNUSED(m); + GGML_UNUSED(k); } static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5675,22 +6387,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; - - const bool qx_needs_dequant = x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; - - // Not implemented - GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - - const uint64_t x_ne = ne01 * ne00; - const uint64_t y_ne = ne11 * ne10; - const uint64_t d_ne = ne11 * ne01; - - const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); - const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); - const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; - const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; - const uint64_t d_sz = sizeof(float) * d_ne; + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type); vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; @@ -5702,14 +6399,47 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11); + + // Check for mmq first + vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (dmmv == nullptr) { + // Fall back to f16 dequant mul mat + dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00); + quantize_y = false; + } + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); + } + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(dmmv != nullptr); + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne11 * ne01; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + const uint64_t d_sz = sizeof(float) * d_ne; + if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; - const uint64_t y_sz_upd = y_sz * ne12 * ne13; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } if ( (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { @@ -5718,7 +6448,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ctx->prealloc_size_x = x_sz_upd; } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { ctx->prealloc_size_y = y_sz_upd; } @@ -5729,6 +6459,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); return; } @@ -5759,6 +6492,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (qy_needs_dequant) { d_Y = ctx->prealloc_y; + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -5766,12 +6502,35 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (x_non_contig) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } } // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride @@ -5797,16 +6556,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& groups_x = CEIL_DIV(groups_x, groups_z); } + // TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + // compute const 
vk_mat_vec_push_constants pc = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, stride_batch_x, stride_batch_y, stride_batch_d, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz_total }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + + if (x_non_contig) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig || quantize_y) { + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5893,7 +6664,6 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c workgroups_z /= gqa_ratio; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z }); } @@ -5980,14 +6750,40 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con // compute const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } -static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); - if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && + + // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases + // where the M dimension is very large. + // Split_k doesn't work with M splitting. + const size_t nbytes = ggml_nbytes(src0); + const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange; + if (needs_split) { + // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets) + const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]); + uint32_t m_offset = 0; + while (m_offset < dst->ne[0]) { + const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset)); + ggml_tensor dst2 = *dst; + ggml_tensor src02 = *src0; + + dst2.view_src = dst->view_src ? 
dst->view_src : dst; + src02.view_src = src0->view_src ? src0->view_src : src0; + + dst2.view_offs += m_offset * dst->nb[0]; + src02.view_offs += m_offset * src0->nb[1]; + dst2.ne[0] = cur_M_size; + src02.ne[1] = cur_M_size; + + ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true, dryrun); + + m_offset += cur_M_size; + } + } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && // detect 0213 permutation, and batch size of 1 src0->nb[0] <= src0->nb[2] && src0->nb[2] <= src0->nb[1] && @@ -6007,7 +6803,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); } else { - ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); + ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false, dryrun); } } @@ -6031,7 +6827,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - GGML_ASSERT(nei0 * nei1 <= 4096); const uint32_t nbi1 = ids->nb[1]; const uint32_t nbi2 = ids->nb[2]; @@ -6192,16 +6987,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& GGML_ASSERT(qy_sz == y_sz); } + if (x_non_contig || qx_needs_dequant) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (x_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_sync_buffers(ctx, subctx); } if (y_non_contig) { - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } } uint32_t stride_batch_x = ne00*ne01; @@ -6224,6 +7033,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& stride_batch_x, stride_batch_y, ne20*ne21, n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n ); // NOLINT + + if (x_non_contig || qx_needs_dequant) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { @@ -6383,13 +7199,27 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte GGML_ASSERT(qy_sz == y_sz); } + if (x_non_contig) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if 
(x_non_contig) { GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } } uint32_t stride_batch_y = ne10*ne11; @@ -6414,11 +7244,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), (uint32_t)nei0, (uint32_t)ne11, }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, pc, { groups_x, (uint32_t)nei0, groups_z }); + + if (x_non_contig) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { @@ -6426,30 +7262,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); } else { - // Split based on number of ids, to fit in shared memory - const uint32_t nei0 = (uint32_t)src2->ne[0]; - const uint32_t nei1 = (uint32_t)src2->ne[1]; - - GGML_ASSERT(nei0 <= 4096); - const uint32_t split_size = std::min(nei1, 4096u / nei0); - - ggml_tensor src1_copy = *src1; - ggml_tensor src2_copy = *src2; - ggml_tensor dst_copy = *dst; - - for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) { - const uint32_t n_tokens = std::min(split_size, nei1 - token_start); - - src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2]; - src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1]; - dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2]; - - src1_copy.ne[2] = n_tokens; - src2_copy.ne[1] = n_tokens; - dst_copy.ne[2] = n_tokens; - - ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun); - } + ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); } } @@ -6482,18 +7295,21 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t Br = coopmat1_flash_attention_num_large_rows; const uint32_t Bc = scalar_flash_attention_Bc; + const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); + const uint32_t acctype = f32acc ? 
4 : 2; const uint32_t f16vec4 = 8; const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpshv4 = wg_size * 4 * acctype; - const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4; + const uint32_t qstride = hsk_pad / 4 + 2; + const uint32_t Qf = Br * qstride * f16vec4; const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const uint32_t kshstride = hsk / 4 + 2; + const uint32_t kshstride = hsk_pad / 4 + 2; const uint32_t ksh = Bc * kshstride * f16vec4; const uint32_t slope = Br * sizeof(float); @@ -6604,7 +7420,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= N; } - vk_pipeline *pipelines; bool small_rows = N <= get_fa_num_small_rows(path); // coopmat1 does not actually support "small rows" (it needs 16 rows). @@ -6624,37 +7439,36 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx small_rows = true; } - bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; - - FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]); - - switch (path) { - case FA_SCALAR: - pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0]; - break; - case FA_COOPMAT1: - pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0]; - break; - case FA_COOPMAT2: - pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0]; - break; - default: - GGML_ASSERT(0); - } - assert(pipelines); - const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); - bool aligned = (KV % pipelines[1]->align) == 0 && + uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); + bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned. + if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { + aligned = false; + } // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0); - vk_pipeline pipeline = pipelines[aligned]; + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc); + + vk_pipeline pipeline = nullptr; + + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + auto it = pipelines.find(fa_pipeline_state); + if (it != pipelines.end()) { + pipeline = it->second; + } else { + pipelines[fa_pipeline_state] = pipeline = std::make_shared(); + } + assert(pipeline); uint32_t split_kv = KV; @@ -6670,7 +7484,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. 
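// A minimal, self-contained sketch of the recomputation described in the comment
// above, assuming CEIL_DIV and ROUNDUP_POW2 carry their usual meanings here
// (integer ceiling division and rounding up to a multiple of a power of two).
// The helpers below are local stand-ins for illustration, not the backend's macros.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t x, uint32_t y) { return (x + y - 1) / y; }
static uint32_t roundup_pow2(uint32_t x, uint32_t pot) { return (x + pot - 1) & ~(pot - 1); }

int main() {
    // Example values: KV = 4096, an initial split_k of 3, shader alignment of 256.
    uint32_t KV = 4096, split_k = 3, alignment = 256;
    // 4096 / 3 = 1365, rounded up to the next multiple of 256 -> 1536.
    uint32_t split_kv = roundup_pow2(std::max(1u, KV / split_k), alignment);
    // split_k is then recomputed so the chunks cover all of KV: ceil(4096 / 1536) = 3.
    split_k = ceil_div(KV, split_kv);
    std::printf("split_kv=%u split_k=%u\n", split_kv, split_k); // split_kv=1536 split_k=3
    return 0;
}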
- split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align); + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); split_k = CEIL_DIV(KV, split_kv); workgroups_x = split_k; } @@ -6794,9 +7608,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx mask_n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; - ggml_vk_sync_buffers(subctx); - if (split_k > 1) { + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, @@ -6812,7 +7628,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // cancel out the divide by wg_denoms[0]. pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, { @@ -6821,6 +7637,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + ctx->prealloc_split_k_need_sync = true; } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { @@ -6863,7 +7680,34 @@ static std::array ggml_vk_get_conv_elements(const ggml_tensor *dst) return elements; } -static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { +static std::array ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + // src0 - kernel: [KW, KH, Cout, Cin] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins - 1) * s - 2 * p + (ks - 1) * d + 1; + }; + // parallelize in {OW/BS_K, OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[2]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1); + int64_t NPQ = N * OW * OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + std::array elements = { static_cast(Cout), static_cast(NPQ), 1 }; + return elements; +} + +static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) { switch (op) { case GGML_OP_GET_ROWS: GGML_ASSERT(src1->type == GGML_TYPE_I32); @@ -6891,8 +7735,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const switch (op) { case GGML_OP_ADD: { - auto pipelines = ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; - return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + if (ctx->num_additional_fused_ops > 0) { + if (ctx->do_add_rms_partials) { + return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops]; + } else { + return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; + } + } + if (ctx->do_add_rms_partials) { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } else { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } } case GGML_OP_SUB: { @@ -6952,6 +7808,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_sqr_f32; } return nullptr; + case GGML_OP_SQRT: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sqrt_f32; + } + return nullptr; case GGML_OP_SIN: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_sin_f32; @@ -6992,7 +7853,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_DUP: return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); case GGML_OP_SET_ROWS: - return ctx->device->pipeline_set_rows[dst->type]; + if (src1->type == GGML_TYPE_I64) { + return ctx->device->pipeline_set_rows_i64[dst->type]; + } else { + return ctx->device->pipeline_set_rows_i32[dst->type]; + } case GGML_OP_SILU_BACK: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_silu_back_f32; @@ -7010,7 +7875,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_RMS_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; + if (ctx->do_add_rms_partials) { + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32; + } else { + return ctx->num_additional_fused_ops > 0 ? 
ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; + } } return nullptr; case GGML_OP_RMS_NORM_BACK: @@ -7031,6 +7900,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_EXP: + return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SILU: return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: @@ -7045,6 +7916,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SIGMOID: return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_HARDSIGMOID: + return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_HARDSWISH: + return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16]; default: break; } @@ -7135,11 +8010,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } case GGML_OP_ARGSORT: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { - return ctx->device->pipeline_argsort_f32; + uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); + return ctx->device->pipeline_argsort_f32[idx]; } return nullptr; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_sum_rows_f32; } @@ -7162,6 +8039,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_im2col_f32_f16; } return nullptr; + case GGML_OP_IM2COL_3D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_3d_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_3d_f32_f16; + } + return nullptr; case GGML_OP_TIMESTEP_EMBEDDING: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_timestep_embedding_f32; @@ -7192,15 +8077,23 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_opt_step_adamw_f32; } return nullptr; + case GGML_OP_OPT_STEP_SGD: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_opt_step_sgd_f32; + } + return nullptr; case GGML_OP_LEAKY_RELU: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_leaky_relu_f32; } return nullptr; case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - auto elements = ggml_vk_get_conv_elements(dst); + std::array elements; + if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst); + else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst); vk_conv_shapes shape; uint32_t tiles[CONV_SHAPE_COUNT]; @@ -7220,10 +8113,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const shape = CONV_SHAPE_64x32; } - if (src0->type == GGML_TYPE_F32) { - return ctx->device->pipeline_conv2d_f32[shape]; - } else if (src0->type == GGML_TYPE_F16) { - return ctx->device->pipeline_conv2d_f16_f32[shape]; + if (op == GGML_OP_CONV_2D) { + if (src0->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv2d_f32[shape]; + } else if (src0->type == 
GGML_TYPE_F16) { + return ctx->device->pipeline_conv2d_f16_f32[shape]; + } + } else if (op == GGML_OP_CONV_TRANSPOSE_2D) { + if (src0->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv_transpose_2d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape]; + } } } return nullptr; @@ -7234,6 +8135,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } else if (ggml_is_contiguous_channels(src1)) { return ctx->device->pipeline_conv2d_dw_cwhn_f32; } + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f16_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32; + } } return nullptr; default: @@ -7255,6 +8162,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SQR: + case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: @@ -7265,7 +8173,11 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_RMS_NORM: case GGML_OP_CONV_2D_DW: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_SET_ROWS: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: return true; default: return false; @@ -7300,6 +8212,36 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src2); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src0); + GGML_UNUSED(src2); +} + template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); @@ -7450,10 +8392,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); if 
(op_supports_incontiguous) { - x_sz = ggml_nbytes(src0); - y_sz = use_src1 ? ggml_nbytes(src1) : 0; - z_sz = use_src2 ? ggml_nbytes(src2) : 0; - d_sz = ggml_nbytes(dst); + x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); + y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; + d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); if (x_buf_offset + x_sz >= d_X->size) { x_sz = VK_WHOLE_SIZE; @@ -7481,6 +8423,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: { const uint32_t nr = ggml_nrows(src0); @@ -7493,7 +8436,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } break; case GGML_OP_RMS_NORM: - elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + if (ctx->do_add_rms_partials) { + // Run one element per thread, 128 threads per workgroup + elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 }; + } else { + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + } break; case GGML_OP_SUM: @@ -7512,6 +8460,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co break; case GGML_OP_GET_ROWS: elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); break; case GGML_OP_ARGSORT: elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; @@ -7532,6 +8482,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { OW * KW * KH, OH, batch * IC }; } break; + case GGML_OP_IM2COL_3D: + { + const uint32_t IC = ((const uint32_t *)(dst->op_params))[9]; + + const uint32_t N = ne13 / IC; + + const uint32_t KD = ne02; + const uint32_t KH = ne01; + const uint32_t KW = ne00; + + const uint32_t OD = ned3 / N; + const uint32_t OH = ned2; + const uint32_t OW = ned1; + + const uint32_t IC_KD_KH_KW = IC*KD*KH*KW; + const uint32_t N_OD_OH = N*OD*OH; + + elements = { IC_KD_KH_KW, OW, N_OD_OH }; + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + } break; case GGML_OP_TIMESTEP_EMBEDDING: { const uint32_t dim = dst->op_params[0]; @@ -7554,12 +8524,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co { elements = ggml_vk_get_conv_elements(dst); } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + elements = ggml_vk_get_conv_transpose_2d_elements(dst); + } break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: case GGML_OP_MUL: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: @@ -7641,7 +8616,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - if (op == GGML_OP_GLU) { + if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { + vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X; + size_t a_buf_offset = ctx->do_add_rms_partials ? 
ctx->prealloc_size_add_rms_partials_offset : 0; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { vk_subbuffer{ d_X, x_buf_offset, x_sz }, + vk_subbuffer{ d_Y, y_buf_offset, y_sz }, + vk_subbuffer{ d_D, d_buf_offset, d_sz }, + vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE }, + }, pc, elements); + } else if (op == GGML_OP_GLU) { // Empty src1 is possible in glu, but the shader needs a buffer vk_subbuffer subbuf_y; if (use_src1) { @@ -7650,7 +8634,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co subbuf_y = { d_X, 0, x_sz }; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_SOFT_MAX) { // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer @@ -7668,7 +8651,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co subbuf_z = { d_X, 0, x_sz }; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { // Empty src2 is possible in rope, but the shader needs a buffer @@ -7679,26 +8661,27 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co subbuf_z = { d_X, 0, x_sz }; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); - } else if (op == GGML_OP_IM2COL) { + } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { + if (ctx->device->shader_int64 && ctx->device->buffer_device_address) { + // buffer device address path doesn't use dst buffer + d_sz = 1; + } // im2col uses only src1 and dst buffers - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_COUNT_EQUAL) { - ggml_vk_sync_buffers(subctx); // count_equal assumes that destination buffer is initialized with zeroes ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_OPT_STEP_SGD) { + // OPT_STEP_SGD works on src0, it does not need dst + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements); } else if (use_src2) { - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (use_src1) { - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else { - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, 
pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } } @@ -7738,6 +8721,116 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) { + const ggml_tensor *first_node = cgraph->nodes[node_idx]; + const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + + // Make a list of all the tensors used by the op. + // Last element of the list is the dest tensor. + const ggml_tensor *tensors[MAX_PARAMETER_COUNT]; + uint32_t num_srcs = ctx->num_additional_fused_ops + 2; + uint32_t num_tensors = num_srcs + 1; + GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT); + + tensors[0] = first_node->src[0]; + tensors[1] = first_node->src[1]; + for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) { + // check whether the previous result is src[0] or src[1] + if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) { + tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1]; + } else { + tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0]; + } + } + tensors[num_srcs] = dst; + + vk_op_multi_add_push_constants pc; + pc.ne20 = (uint32_t)dst->ne[0]; + pc.ne21 = (uint32_t)dst->ne[1]; + pc.ne22 = (uint32_t)dst->ne[2]; + pc.ne23 = (uint32_t)dst->ne[3]; + + for (uint32_t i = 0; i < num_tensors; ++i) { + const ggml_tensor *t = tensors[i]; + pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float); + pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float); + pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float); + pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float); + } + pc.rms_partials = ctx->do_add_rms_partials; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op); + + if (pipeline == nullptr) { + std::cerr << "ggml_vulkan: Error: Missing multi_add"; + GGML_ABORT("fatal error"); + } + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT]; + vk_buffer buf[MAX_PARAMETER_COUNT]; + size_t offset[MAX_PARAMETER_COUNT]; + bool uma[MAX_PARAMETER_COUNT]; + + for (uint32_t i = 0; i < num_tensors; ++i) { + buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context; + buf[i] = nullptr; + offset[i] = 0; + uma[i] = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]); + uma[i] = buf[i] != nullptr; + } + if (!uma[i]) { + buf[i] = buf_ctx[i]->dev_buffer; + offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs; + } + GGML_ASSERT(buf[i] != nullptr); + } + // If any remaining descriptors are unused, just point them at src[0] + for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) { + buf[i] = buf[0]; + offset[i] = 0; + } + if (ctx->do_add_rms_partials) { + buf[num_tensors] = ctx->prealloc_add_rms_partials; + offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset; + } + + std::array elements; + + uint32_t ne = ggml_nelements(dst); + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + static_assert(MAX_PARAMETER_COUNT == 12); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE 
}, + vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE }, + }, pc, elements); +} + static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); @@ -7749,7 +8842,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - 0.0f, 0.0f, 0, + 0.0f, 0.0f, ctx->do_add_rms_partials, }, dryrun); } @@ -7837,8 +8930,6 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; } - ggml_vk_sync_buffers(subctx); - vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 }; bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false }; @@ -7976,8 +9067,6 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context; ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context; - ggml_vk_sync_buffers(subctx); - vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr; size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0; bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false; @@ -8044,6 +9133,12 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su ); } +static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const size_t n = ggml_nelements(dst->src[0]); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { int * op_params = (int *)dst->op_params; @@ -8096,6 +9191,10 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun); } +static void 
ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun); +} + static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun); } @@ -8113,7 +9212,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con } static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); } @@ -8201,19 +9300,39 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } +static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { + const uint32_t ne = (uint32_t)node->ne[0]; + const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0]; + const uint32_t num_partials = CEIL_DIV(ne, denom); + return num_partials; +} + +static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) { + const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node); + const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment); + return num_bytes; +} + static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); + uint32_t param3 = ctx->do_add_rms_partials ? 
ggml_vk_rms_num_partials(ctx, dst) : 0; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f, 0, + op_params[0], 0.0f, (int32_t)param3, }, dryrun); + + if (ctx->do_add_rms_partials) { + ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0); + ctx->do_add_rms_partials = false; + } } static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -8304,7 +9423,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { @@ -8335,7 +9454,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, - sections[0], sections[1], sections[2], sections[3], backprop + { sections[0], sections[1], sections[2], sections[3] }, backprop }, dryrun); } @@ -8344,30 +9463,30 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c uint32_t ncols = src0->ne[0]; - uint32_t ncols_pad = 1; - while (ncols_pad < ncols) { - ncols_pad *= 2; - } - - GGML_ASSERT(ncols_pad <= 1024); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { ncols, - ncols_pad, op_params[0], }, dryrun); } static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun); } static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, 
vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun); +} + +static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + p.weight = 1.0f / (float)src0->ne[0]; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun); } static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun); } static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -8399,7 +9518,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t pelements = OW * KW * KH; + const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + const vk_buffer d_buf = d_buf_ctx->dev_buffer; + + const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { + dst_addr, batch_offset, offset_delta, IC, IW, IH, OW, OH, KW, KH, pelements, @@ -8408,6 +9533,72 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co }, dryrun); } +static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + const vk_buffer d_buf = d_buf_ctx->dev_buffer; + + const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs; + + vk_op_im2col_3d_push_constants pc {}; + + pc.dst_addr = dst_addr; + pc.nb10 = nb10 / ggml_type_size(src1->type); + pc.nb11 = nb11 / 
ggml_type_size(src1->type); + pc.nb12 = nb12 / ggml_type_size(src1->type); + pc.nb13 = nb13 / ggml_type_size(src1->type); + pc.s0 = s0; + pc.s1 = s1; + pc.s2 = s2; + pc.p0 = p0; + pc.p1 = p1; + pc.p2 = p2; + pc.d0 = d0; + pc.d1 = d1; + pc.d2 = d2; + pc.IW = IW; + pc.IH = IH; + pc.ID = ID; + pc.IC = IC; + pc.KW = KW; + pc.OH = OH; + pc.KD_KH_KW = KD*KH*KW; + pc.KH_KW = KH*KW; + pc.IC_KD_KH_KW = IC*KD*KH*KW; + pc.N_OD_OH = N*OD*OH; + pc.OD_OH = OD*OH; + pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); +} + static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t dim = dst->op_params[0]; const uint32_t max_period = dst->op_params[1]; @@ -8526,6 +9717,55 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); } +static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv_transpose_2d_push_constants p{}; + p.Cout = static_cast(ne02); + p.Cin = static_cast(ne03); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[0]); + p.p0 = 0; + p.p1 = 0; + p.d0 = 1; + p.d1 = 1; + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne12); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun); +} + static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { vk_op_conv2d_dw_push_constants p{}; p.ne = ggml_nelements(dst); @@ -8707,7 +9947,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t if (ctx->prealloc_split_k != nullptr) { ggml_vk_destroy_buffer(ctx->prealloc_split_k); } - ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal}); } } @@ -8717,9 +9957,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_pipeline_allocate_descriptor_sets(ctx); - vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, 
vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne); Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne); @@ -8945,8 +10185,8 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); float * x = (float *) malloc(x_sz); void * qx = malloc(qx_sz); - vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal}); float * x_ref = (float *) malloc(x_sz); ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16); @@ -9051,8 +10291,8 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ // float * x = (float *) malloc(x_sz); // block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz); // block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz); -// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); -// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); +// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); +// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); // // for (size_t i = 0; i < ne; i++) { // x[i] = rand() / (float)RAND_MAX; @@ -9199,10 +10439,10 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, float * x = (float *) malloc(x_sz); float * y = (float *) malloc(y_sz); void * qx = malloc(qx_sz); - vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); float * d = (float *) malloc(d_sz); float * d_chk = (float *) malloc(d_sz); @@ -9229,7 +10469,7 @@ 
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, if (ctx->prealloc_split_k != nullptr) { ggml_vk_destroy_buffer(ctx->prealloc_split_k); } - ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal}); } } if (mmq) { @@ -9491,6 +10731,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); } + if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")"); + // Resize buffer + if (ctx->prealloc_add_rms_partials != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials); + } + ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials); + } } static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); @@ -9506,10 +10754,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); ctx->semaphore_idx = 0; - const ggml_tensor * src0 = node->src[0]; - const ggml_tensor * src1 = node->src[1]; - const ggml_tensor * src2 = node->src[2]; - const ggml_tensor * src3 = node->src[3]; + ggml_tensor * src0 = node->src[0]; + ggml_tensor * src1 = node->src[1]; + ggml_tensor * src2 = node->src[2]; + ggml_tensor * src3 = node->src[3]; switch (node->op) { // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor @@ -9521,6 +10769,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr return false; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: @@ -9528,6 +10777,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: break; default: return false; @@ -9546,10 +10797,23 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr return false; } break; + case GGML_OP_ADD: + { + int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops; + if (next_node_idx < cgraph->n_nodes && + cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM && + cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] && + ggml_nrows(cgraph->nodes[next_node_idx]) == 1 && + ctx->device->add_rms_fusion) { + if (dryrun) { + ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); + } + ctx->do_add_rms_partials = true; + } + } break; case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: - case GGML_OP_ADD: case GGML_OP_ADD_ID: case GGML_OP_ACC: case GGML_OP_SUB: @@ -9559,6 +10823,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SQRT: case 
GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: @@ -9584,24 +10849,27 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: break; default: std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; GGML_ABORT("fatal error"); - return false; } vk_context compute_ctx; @@ -9628,6 +10896,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: @@ -9652,20 +10921,27 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: + case GGML_OP_OPT_STEP_SGD: { // These operations all go through ggml_vk_op_f32, so short-circuit and // do the only thing needed for the dryrun. vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (node->op == GGML_OP_RMS_NORM) { + ctx->do_add_rms_partials = false; + } return false; } default: @@ -9673,6 +10949,80 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } + if (!dryrun) { + // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers + // to synchronize them. This handles most "normal" synchronization when computing the graph, and when + // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers + // outside of this logic. When a node uses one of the prealloc buffers for something like + // dequantization or split_k, additional synchronization is needed between those passes. + bool need_sync = false; + + // Check whether "node" requires synchronization. The node requires synchronization if it + // overlaps in memory with another unsynchronized node and at least one of them is a write. + // Destination nodes are checked against both the written/read lists. Source nodes are only + // checked against the written list. Two nodes overlap in memory if they come from the same + // buffer and the tensor or view ranges overlap. 
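// The hazard rule spelled out in the comment above reduces to an interval-overlap
// test within one vk_buffer plus the usual read/write rules: a barrier is needed for
// write-after-write, write-after-read and read-after-write, but not read-after-read.
// A minimal sketch of that predicate, using hypothetical plain ranges in place of
// ggml tensors and vk_buffer handles; the names below are illustrative only.
#include <cstdint>
#include <vector>

struct node_range { const void *buffer; uint64_t base; uint64_t size; };

// Non-empty half-open ranges in the same buffer overlap iff each starts before the
// other ends (equivalent to the two-sided containment check in the lambda below).
static bool ranges_overlap(const node_range &a, const node_range &b) {
    return a.buffer == b.buffer &&
           a.base < b.base + b.size &&
           b.base < a.base + a.size;
}

// The destination (written) range is checked against both the previously written and
// previously read lists; source (read) ranges are checked only against the written list.
static bool node_needs_sync(const node_range &dst, const std::vector<node_range> &srcs,
                            const std::vector<node_range> &unsynced_written,
                            const std::vector<node_range> &unsynced_read) {
    for (const node_range &w : unsynced_written) {
        if (ranges_overlap(dst, w)) return true;
        for (const node_range &s : srcs) {
            if (ranges_overlap(s, w)) return true;
        }
    }
    for (const node_range &r : unsynced_read) {
        if (ranges_overlap(dst, r)) return true;
    }
    return false;
}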
+ auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector &unsynced_nodes) -> bool { + if (unsynced_nodes.size() == 0) { + return false; + } + auto n_base = vk_tensor_offset(node) + node->view_offs; + auto n_size = ggml_nbytes(node); + ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context; + vk_buffer a_buf = a_buf_ctx->dev_buffer; + for (auto &other : unsynced_nodes) { + ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context; + vk_buffer o_buf = o_buf_ctx->dev_buffer; + if (a_buf == o_buf) { + auto o_base = vk_tensor_offset(other) + other->view_offs; + auto o_size = ggml_nbytes(other); + + if ((o_base <= n_base && n_base < o_base + o_size) || + (n_base <= o_base && o_base < n_base + n_size)) { + return true; + } + } + } + return false; + }; + + // For all fused ops, check if the destination node or any of the source + // nodes require synchronization. + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; + } + if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } + } + } + if (need_sync) { + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ggml_vk_sync_buffers(ctx, compute_ctx); + } + // Add the last fused node and all fused source nodes to the unsynchronized list. + const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + ctx->unsynced_nodes_written.push_back(last_node); + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; + } + ctx->unsynced_nodes_read.push_back(cur_node->src[j]); + } + } + } + switch (node->op) { case GGML_OP_REPEAT: ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); @@ -9691,8 +11041,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_ADD: - ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); - + if (ctx->num_additional_fused_ops) { + ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx, dryrun); + } else { + ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); + } break; case GGML_OP_SUB: ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9725,6 +11078,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SQR: ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_SQRT: + ggml_vk_sqrt(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_SIN: ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun); @@ -9788,6 +11145,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: @@ -9795,6 +11153,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: 
ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); break; default: @@ -9846,6 +11206,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_MEAN: + ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_ARGMAX: ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun); @@ -9858,6 +11222,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_IM2COL_3D: + ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_TIMESTEP_EMBEDDING: ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); @@ -9874,6 +11242,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONV_2D: ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_CONV_TRANSPOSE_2D: + ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_CONV_2D_DW: ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9910,6 +11282,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_OPT_STEP_ADAMW: ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); + break; + + case GGML_OP_OPT_STEP_SGD: + ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun); + break; default: return false; @@ -9971,6 +11348,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: @@ -9999,13 +11377,16 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ -10013,11 +11394,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: buf = tensor->buffer; - break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: @@ -10025,6 +11407,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: buf = tensor->buffer; break; default: @@ -10080,6 +11464,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * memcpy(cpy.dst, cpy.src, cpy.n); } + for (auto& mset : subctx->memsets) { + memset(mset.dst, mset.val, mset.n); + } + if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) { ggml_vk_submit(subctx, ctx->almost_ready_fence); ctx->almost_ready_fence_pending = true; @@ -10102,6 +11490,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * } subctx->in_memcpys.clear(); subctx->out_memcpys.clear(); + 
subctx->memsets.clear(); } return true; @@ -10114,6 +11503,11 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_pool_free(ctx, buffer); } ctx->gc.temp_buffers.clear(); + ctx->prealloc_y_last_pipeline_used = {}; + + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); @@ -10149,6 +11543,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_destroy_buffer(ctx->prealloc_x); ggml_vk_destroy_buffer(ctx->prealloc_y); ggml_vk_destroy_buffer(ctx->prealloc_split_k); + ctx->prealloc_y_last_pipeline_used = nullptr; for (auto& buffer : ctx->buffer_pool) { ggml_vk_destroy_buffer(buffer); @@ -10594,6 +11989,58 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st return true; } +static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { + + const ggml_tensor *first_node = cgraph->nodes[node_idx]; + if (first_node->op != GGML_OP_ADD) { + return 0; + } + + if (!ctx->device->multi_add) { + return 0; + } + + int32_t num_adds = 1; + while (node_idx + num_adds < cgraph->n_nodes && + cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD && + num_adds < MAX_FUSED_ADDS) { + num_adds++; + } + + // The shader currently requires same shapes (but different strides are allowed), + // everything f32, and no misalignment + for (int32_t i = 0; i < num_adds; ++i) { + const ggml_tensor *next_node = cgraph->nodes[node_idx + i]; + if (!ggml_are_same_shape(first_node, next_node->src[0]) || + !ggml_are_same_shape(first_node, next_node->src[1]) || + next_node->type != GGML_TYPE_F32 || + next_node->src[0]->type != GGML_TYPE_F32 || + next_node->src[1]->type != GGML_TYPE_F32 || + get_misalign_bytes(ctx, next_node) || + get_misalign_bytes(ctx, next_node->src[0]) || + get_misalign_bytes(ctx, next_node->src[1])) { + num_adds = i; + } + } + + // Verify we can fuse these + ggml_op adds[MAX_FUSED_ADDS]; + for (int32_t i = 0; i < num_adds; ++i) { + adds[i] = GGML_OP_ADD; + } + + // decrease num_adds if they can't all be fused + while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) { + num_adds--; + } + + // a single add is not "fused", so just return zero + if (num_adds == 1) { + return 0; + } + return num_adds; +} + static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; @@ -10605,18 +12052,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast(&dul)); } + ctx->prealloc_size_add_rms_partials = 0; + ctx->prealloc_size_add_rms_partials_offset = 0; + ctx->do_add_rms_partials = false; + uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { - ctx->num_additional_fused_ops = 1; + if (!ctx->device->disable_fusion) { + uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); + if (num_adds) { + ctx->num_additional_fused_ops = num_adds - 1; + } else if (ggml_vk_can_fuse(cgraph, i, { 
GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } } ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); - } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) { + } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D || cgraph->nodes[i]->op == GGML_OP_CONV_TRANSPOSE_2D) { // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode. auto CRS_size = - cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2]; + cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[1]->ne[2]; auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3]; total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type); } @@ -10665,6 +12121,22 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0); } + ctx->prealloc_y_last_pipeline_used = nullptr; + ctx->prealloc_y_last_tensor_used = nullptr; + + if (ctx->prealloc_size_add_rms_partials) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + // initialize partial sums to zero. + ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials); + ggml_vk_sync_buffers(ctx, compute_ctx); + } + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB // (and scaled down based on model size, so smaller models submit earlier). @@ -10683,8 +12155,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } - if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { - ctx->num_additional_fused_ops = 1; + if (!ctx->device->disable_fusion) { + uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); + if (num_adds) { + ctx->num_additional_fused_ops = num_adds - 1; + } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } } // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) @@ -10762,6 +12239,131 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg UNUSED(backend); } +// Sort the graph for improved parallelism. 
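// A distilled sketch of the reordering strategy implemented below (hypothetical
// names, simplified: no view-node or fusion special cases), assuming the input
// node list is already topologically ordered so every dependency of a node has
// a smaller index:

#include <algorithm>
#include <vector>

// depends_on(j, c) is assumed to report whether node j consumes the output of node c.
template <typename DependsOn>
std::vector<int> reorder_for_parallelism(int n_nodes, int window, DependsOn depends_on) {
    std::vector<int> order;
    std::vector<bool> used(n_nodes, false);
    int first_unused = 0;
    while (first_unused < n_nodes) {
        // Always take the next unscheduled node ...
        order.push_back(first_unused);
        used[first_unused] = true;
        // ... then pull forward nearby nodes whose inputs are all scheduled,
        // so independent work ends up adjacent in the submission order.
        for (int j = first_unused + 1; j < std::min(first_unused + window, n_nodes); ++j) {
            if (used[j]) {
                continue;
            }
            bool ready = true;
            for (int c = first_unused; c < j && ready; ++c) {
                if (!used[c] && depends_on(j, c)) {
                    ready = false;
                }
            }
            if (ready) {
                order.push_back(j);
                used[j] = true;
            }
        }
        while (first_unused < n_nodes && used[first_unused]) {
            first_unused++;
        }
    }
    return order;
}

// e.g. reorder_for_parallelism(graph_size, /*window=*/20, deps) mirrors the
// NUM_TO_CHECK = 20 lookahead used by the function below.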
+static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph) +{ + VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + if (ctx->device->disable_graph_optimize) { + return; + } + + auto const &is_empty = [](ggml_tensor * node) -> bool { + return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor *src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + // This function tries to reorder the graph to allow nodes to run in parallel. + // This helps with small batches, but for large batches its a slowdown, probably + // due to cache contention. So only reorder if the majority of nodes have few rows. + int num_small_nodes = 0; + int num_counted_nodes = 0; + for (int i = 0; i < graph->n_nodes; ++i) { + if (!is_empty(graph->nodes[i]) && + graph->nodes[i]->op != GGML_OP_SET_ROWS) { + if (ggml_nrows(graph->nodes[i]) <= 8) { + num_small_nodes++; + } + num_counted_nodes++; + } + } + if (num_small_nodes < num_counted_nodes / 2) { + return; + } + + std::vector new_order; + std::vector used(graph->n_nodes, false); + int first_unused = 0; + while (first_unused < graph->n_nodes) { + std::vector current_set; + + // First, grab the next unused node. + current_set.push_back(first_unused); + + // Loop through the next N nodes. Grab any that don't depend on other nodes that + // haven't already been run. Nodes that have already been run have used[i] set + // to true. Allow nodes that depend on the previous node if it's a fusion pattern + // that we support (e.g. RMS_NORM + MUL). + // This first pass only grabs "real" (non-view nodes). Second pass grabs view nodes. + // The goal is to not interleave real and view nodes in a way that breaks fusion. + const int NUM_TO_CHECK = 20; + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + // Second pass grabs view nodes. + // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add). + if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) { + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (!is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end(); + // skip views whose srcs haven't been processed. 
+ if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !c_in_current_set) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + } + + // Push the current set into new_order + for (auto c : current_set) { + new_order.push_back(graph->nodes[c]); + used[c] = true; + } + while (first_unused < graph->n_nodes && used[first_unused]) { + first_unused++; + } + } + // Replace the graph with the new order. + for (int i = 0; i < graph->n_nodes; ++i) { + graph->nodes[i] = new_order[i]; + } +} + // TODO: enable async and synchronize static ggml_backend_i ggml_backend_vk_interface = { /* .get_name = */ ggml_backend_vk_name, @@ -10777,6 +12379,7 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .graph_compute = */ ggml_backend_vk_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_optimize = */ ggml_vk_graph_optimize, }; static ggml_guid_t ggml_backend_vk_guid() { @@ -10820,100 +12423,96 @@ std::string ggml_backend_vk_get_device_id(int device) { return ggml_vk_get_device_id(dev_idx); } +void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); + + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; + vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; + vk::PhysicalDeviceMemoryProperties2 memprops = {}; + bool membudget_supported = vk_instance.device_supports_membudget[device]; + + if (membudget_supported) { + memprops.pNext = &budgetprops; + } + vkdev.getMemoryProperties2(&memprops); + + for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) { + const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i]; + + if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total = heap.size; + + if (membudget_supported && i < budgetprops.heapUsage.size()) { + *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i]; + } else { + *free = heap.size; + } + break; + } + } +} + +static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) { + GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; + + vk::PhysicalDeviceProperties2 props = {}; + device.getProperties2(&props); + + return props.properties.deviceType; +} + +static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { + GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; + + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool ext_support = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) { + ext_support = true; + break; + } + } + + if (!ext_support) { + return ""; + } + + vk::PhysicalDeviceProperties2 props = {}; + vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {}; + + props.pNext = &pci_bus_info; + + device.getProperties2(&props); + + const uint32_t pci_domain = pci_bus_info.pciDomain; + const uint32_t pci_bus = pci_bus_info.pciBus; + const uint32_t pci_device = pci_bus_info.pciDevice; + const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is 
between 0 and 7, prevent printf overflow warning + + char pci_bus_id[16] = {}; + snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); + + return std::string(pci_bus_id); +} + ////////////////////////// struct ggml_backend_vk_device_context { size_t device; std::string name; std::string description; + bool is_integrated_gpu; + std::string pci_bus_id; std::string id; - std::string uuid; - std::string dev_idx; - int major; - int minor; - int driver_major; - int driver_minor; - int integrated; - int pci_bus_id; - int pci_device_id; - int pci_domain_id; }; -void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { - GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); - - vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; - - vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); - vk::PhysicalDeviceProperties2 props2; - vkdev.getProperties2(&props2); - - // Use vendor specific management libraries for best VRAM reporting if available - switch (props2.properties.vendorID) { - case VK_VENDOR_ID_AMD: - if (ggml_hip_mgmt_init() == 0) { - int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); - if (status == 0) { - GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); - ggml_hip_mgmt_release(); - return; - } - ggml_hip_mgmt_release(); - } - break; - case VK_VENDOR_ID_NVIDIA: - if (ggml_nvml_init() == 0) { - int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); - if (status == 0) { - GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); - ggml_nvml_release(); - return; - } - ggml_nvml_release(); - } - break; - } - // else fallback to memory budget if supported - - *total = 0; - *free = 0; - vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; - vk::PhysicalDeviceMemoryProperties2 memprops2; - memprops2.pNext = &mem_budget_props; - vkdev.getMemoryProperties2(&memprops2); - for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { - if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { - *total += memprops2.memoryProperties.memoryHeaps[i].size; - } else if (ctx->integrated) { - // Include shared memory on iGPUs - *total += memprops2.memoryProperties.memoryHeaps[i].size; - } - } - for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { - if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { - *free += mem_budget_props.heapBudget[i]; - } else if (ctx->integrated) { - *free += mem_budget_props.heapBudget[i]; - } - } - if (*total > 0 && *free > 0) { - return; - } else if (*total > 0) { - *free = *total; - return; - } - - // else just report the physical memory - for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { - if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { - *total = heap.size; - *free = heap.size; - break; - } - } -} - - static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; return ctx->name.c_str(); @@ -10931,7 +12530,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, 
size_t * total) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; - ggml_backend_vk_get_device_memory(ctx, free, total); + ggml_backend_vk_get_device_memory(ctx->device, free, total); } static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { @@ -10945,16 +12544,19 @@ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(gg } static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { - UNUSED(dev); - return GGML_BACKEND_DEVICE_TYPE_GPU; + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + + return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU; } -#define GGML_VULKAN_NAME "VULKAN" static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); - // props->id = ggml_backend_vk_device_get_id(dev); + props->id = ggml_backend_vk_device_get_id(dev); props->type = ggml_backend_vk_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { /* .async = */ false, @@ -10962,19 +12564,6 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml /* .buffer_from_host_ptr = */ false, /* .events = */ false, }; - - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - // Use the unfiltered ID so round-trip through env var works - props->id = ctx->dev_idx.c_str(); - props->compute_major = ctx->major; - props->compute_minor = ctx->minor; - props->driver_major = ctx->driver_major; - props->driver_minor = ctx->driver_minor; - props->integrated = ctx->integrated; - props->pci_bus_id = ctx->pci_bus_id; - props->pci_device_id = ctx->pci_device_id; - props->pci_domain_id = ctx->pci_domain_id; - props->library = GGML_VULKAN_NAME; } static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { @@ -10987,6 +12576,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: @@ -10994,6 +12584,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && @@ -11001,7 +12593,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } - break; case GGML_OP_GLU: switch (ggml_get_glu_op(op)) { case GGML_GLU_OP_GEGLU: @@ -11017,7 +12608,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } - break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { @@ -11081,14 +12671,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return true; - } break; + 
} case GGML_OP_FLASH_ATTN_EXT: { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; auto device = ggml_vk_get_device(ctx->device); bool coopmat2 = device->coopmat2; - FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]); - if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) { + uint32_t HSK = op->src[1]->ne[0]; + uint32_t HSV = op->src[2]->ne[0]; + if ((HSK % 8) != 0 || (HSV % 8) != 0) { return false; } if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) { @@ -11157,6 +12748,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: @@ -11171,7 +12767,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } - } break; + } case GGML_OP_SET_ROWS: { switch (op->type) { @@ -11188,7 +12784,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } - } break; + } case GGML_OP_CONT: case GGML_OP_CPY: case GGML_OP_DUP: @@ -11231,6 +12827,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; } + if ( + (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) || + (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) + ) { + return true; + } + // We can handle copying from a type to the same type if it's // contiguous (memcpy). We use f16 or f32 shaders to do the copy, // so the type/block size must be a multiple of 4. @@ -11240,7 +12843,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; } return false; - } break; + } case GGML_OP_REPEAT: return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); case GGML_OP_REPEAT_BACK: @@ -11271,10 +12874,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM_BACK: case GGML_OP_SQR: + case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: + case GGML_OP_LEAKY_RELU: + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ARGSORT: + return op->ne[0] <= max_argsort_cols; case GGML_OP_UPSCALE: case GGML_OP_ACC: case GGML_OP_CONCAT: @@ -11284,35 +12893,40 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: - case GGML_OP_ARGSORT: + return true; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_2D_DW: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: - case GGML_OP_LEAKY_RELU: - case GGML_OP_OPT_STEP_ADAMW: return true; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: { // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; const vk_device& device 
= ggml_vk_get_device(ctx->device); - bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE; + if (op->op == GGML_OP_CONV_TRANSPOSE_2D && + device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) { + return false; + } // Channel-contiguous format is not supported yet. return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && - ggml_is_contiguous(op)) && !is_Apple; + ggml_is_contiguous(op)); } default: return false; @@ -11378,8 +12992,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { - std::vector vk_devices = vk_instance.instance.enumeratePhysicalDevices(); - for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; char desc[256]; @@ -11387,51 +12999,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->device = i; ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; + ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; + ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); ctx->id = ggml_backend_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, /* .context = */ ctx, }); - - // Gather additional information about the device - int dev_idx = vk_instance.device_indices[i]; - ctx->dev_idx = std::to_string(dev_idx); - vk::PhysicalDeviceProperties props1; - vk_devices[dev_idx].getProperties(&props1); - vk::PhysicalDeviceProperties2 props2; - vk::PhysicalDeviceIDProperties device_id_props; - vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props; - vk::PhysicalDeviceDriverProperties driver_props; - props2.pNext = &device_id_props; - device_id_props.pNext = &pci_bus_props; - pci_bus_props.pNext = &driver_props; - vk_devices[dev_idx].getProperties2(&props2); - std::ostringstream oss; - oss << std::hex << std::setfill('0'); - oss << "GPU-"; - int byteIdx = 0; - for (int i = 0; i < 16; ++i, ++byteIdx) { - oss << std::setw(2) << static_cast(device_id_props.deviceUUID[i]); - if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) { - oss << '-'; - } - } - ctx->uuid = oss.str(); - ctx->pci_bus_id = pci_bus_props.pciBus; - ctx->pci_device_id = pci_bus_props.pciDevice; - ctx->pci_domain_id = pci_bus_props.pciDomain; - ctx->id = std::to_string(i); - if (props1.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) { - ctx->integrated = 1; - } else { - ctx->integrated = 0; - } - ctx->major = 0; - ctx->minor = 0; - // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string - ctx->driver_major = 0; - ctx->driver_minor = 0; } initialized = true; } @@ -11460,39 +13035,43 @@ ggml_backend_reg_t ggml_backend_vk_reg() { } catch (const vk::SystemError& e) { VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what()); return nullptr; + } catch (const std::exception &e) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: " << e.what()); + return nullptr; + } catch (...) 
{ + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: unknown exception during Vulkan init"); + return nullptr; } } // Extension availability -static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions) { +static bool ggml_vk_instance_validation_ext_available() { #ifdef GGML_VULKAN_VALIDATE - bool portability_enumeration_ext = false; - // Check for portability enumeration extension for MoltenVK support - for (const auto& properties : instance_extensions) { - if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { - return true; + // Check if validation layer provides the extension + const std::string layer_name = "VK_LAYER_KHRONOS_validation"; + for (const auto& layer : vk::enumerateInstanceLayerProperties()) { + if (layer_name == layer.layerName.data()) { + for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) { + if (strcmp("VK_EXT_validation_features", ext.extensionName.data()) == 0) { + return true; + } + } } } - if (!portability_enumeration_ext) { - std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; - } + + std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl; #endif return false; - - UNUSED(instance_extensions); } static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions) { #ifdef __APPLE__ - bool portability_enumeration_ext = false; // Check for portability enumeration extension for MoltenVK support for (const auto& properties : instance_extensions) { if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { return true; } } - if (!portability_enumeration_ext) { - std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; - } + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." 
<< std::endl; #endif return false; @@ -11515,6 +13094,20 @@ static bool ggml_vk_instance_debug_utils_ext_available( UNUSED(instance_extensions); } +static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) { + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + vkGetPhysicalDeviceFeatures2(vkdev, &device_features2); + + return vk11_features.storageBuffer16BitAccess; +} + static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { switch (props.vendorID) { case VK_VENDOR_ID_INTEL: @@ -11749,12 +13342,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->op == GGML_OP_CONCAT) { tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); + tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SQRT) { + tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SIN) { tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COS) { @@ -11763,7 +13358,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const float * params = (const float *)tensor->op_params; tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_PAD) { - tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); + tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3], + tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]); } else if (tensor->op == GGML_OP_REPEAT) { tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); } else if (tensor->op == GGML_OP_REPEAT_BACK) { @@ -11825,6 +13421,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } } else if (tensor->op == GGML_OP_UNARY) { switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_EXP: + tensor_clone = ggml_exp(ggml_ctx, src_clone[0]); + break; case GGML_UNARY_OP_SILU: tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); break; @@ -11846,6 +13445,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_SIGMOID: tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_HARDSIGMOID: + tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_HARDSWISH: + tensor_clone = 
ggml_hardswish(ggml_ctx, src_clone[0]); + break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); @@ -11856,6 +13461,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else { tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]); } + ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2)); + ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3)); } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { if (src1 == nullptr) { tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); @@ -11882,6 +13489,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_MEAN) { + tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_ARGMAX) { tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COUNT_EQUAL) { @@ -11896,6 +13505,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const bool is_2D = tensor->op_params[6] == 1; tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_IM2COL_3D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s2 = tensor->op_params[2]; + const int32_t p0 = tensor->op_params[3]; + const int32_t p1 = tensor->op_params[4]; + const int32_t p2 = tensor->op_params[5]; + const int32_t d0 = tensor->op_params[6]; + const int32_t d1 = tensor->op_params[7]; + const int32_t d2 = tensor->op_params[8]; + const int32_t IC = tensor->op_params[9]; + + tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type); } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; @@ -11923,6 +13545,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const int32_t d0 = tensor->op_params[4]; const int32_t d1 = tensor->op_params[5]; tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) { + const int32_t s = tensor->op_params[0]; + tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s); } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); @@ -11936,6 +13561,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * src_clone[0]->flags = src0->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], src_clone[4]); + } else if (tensor->op == GGML_OP_OPT_STEP_SGD) { + src_clone[0]->flags = src0->flags; + tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2]); + } else if (tensor->op == GGML_OP_ADD_ID) { + tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); } else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; @@ -11973,11 +13604,9 @@ static void 
ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { return; } - bool fused_rms_norm_mul = false; if (ctx->num_additional_fused_ops == 1 && tensor->op == GGML_OP_RMS_NORM && cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { - fused_rms_norm_mul = true; tensor = cgraph->nodes[tensor_idx + 1]; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp index 2b4085c4f..00cf2dd62 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -1,20 +1,34 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require +#if ADD_RMS +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif #include "types.comp" #include "generic_binary_head.comp" const uint num_threads = 256; +layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];}; + layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; +#if ADD_RMS +// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant +shared FLOAT_TYPE sumsh[num_threads]; +#endif + void main() { uint idx = get_idx(); + uint orig_idx = idx; // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation const uint num_iter = 2; + FLOAT_TYPE sum_sq = 0; + [[unroll]] for (uint i = 0; i < num_iter; ++i) { if (idx >= p.ne) { continue; @@ -22,8 +36,34 @@ void main() { uint i00, i01, i02, i03; get_indices(idx, i00, i01, i02, i03); - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]); + sum_sq += sum*sum; + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); idx += num_threads; } + +#if ADD_RMS + if (p.param3 != 0) { + // reduce the sum within each subgroup, then across subgroups + const uint NumSubgroups = num_threads / gl_SubgroupSize; + sum_sq = subgroupAdd(sum_sq); + if (gl_SubgroupInvocationID == 0) { + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { + if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) { + sum_sq += sumsh[gl_SubgroupID + s]; + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + } + + if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { + partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; + } + } +#endif } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp index eaf4da341..a1d4c240d 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -5,6 +5,8 @@ #extension GL_EXT_control_flow_attributes : enable +#define FLT_MAX 3.402823466e+38F + layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -19,19 +21,26 @@ void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint col = gl_LocalInvocationID.x; - if (col >= p.KX) { + if (row >= p.KY) { return; } - A_TYPE amax = 
data_a[row*p.KX + col]; - tmp[col] = col; + + A_TYPE amax = -FLT_MAX; + uint acol = col; + + if (col < p.KX) { + amax = data_a[row*p.KX + col]; + } for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) { A_TYPE val = data_a[row*p.KX + i]; if (val > amax) { amax = val; - tmp[col] = i; + acol = i; } } + + tmp[col] = acol; tmpmax[col] = amax; barrier(); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index d4fa45b1e..dc53a401e 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -1,22 +1,24 @@ #version 450 +#extension GL_EXT_control_flow_attributes : enable #include "types.comp" -#define BLOCK_SIZE 1024 +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; #define ASC 0 -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; - uint ncols_pad; uint order; } p; shared int dst_row[BLOCK_SIZE]; +shared A_TYPE a_sh[BLOCK_SIZE]; void swap(uint idx0, uint idx1) { int tmp = dst_row[idx0]; @@ -24,7 +26,7 @@ void swap(uint idx0, uint idx1) { dst_row[idx1] = tmp; } -void main() { +void argsort(bool needs_bounds_check) { // bitonic sort const int col = int(gl_LocalInvocationID.x); const uint row = gl_WorkGroupID.y; @@ -32,38 +34,46 @@ void main() { const uint row_offset = row * p.ncols; // initialize indices - if (col < p.ncols_pad) { - dst_row[col] = col; - } + dst_row[col] = col; + a_sh[col] = data_a[row_offset + col]; barrier(); - for (uint k = 2; k <= p.ncols_pad; k *= 2) { - for (uint j = k / 2; j > 0; j /= 2) { - const uint ixj = col ^ j; - if (col < p.ncols_pad && ixj > col) { - if ((col & k) == 0) { - if (dst_row[col] >= p.ncols || - (dst_row[ixj] < p.ncols && (p.order == ASC ? - data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] : - data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]])) - ) { - swap(col, ixj); - } - } else { - if (dst_row[ixj] >= p.ncols || - (dst_row[col] < p.ncols && (p.order == ASC ? - data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] : - data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]])) - ) { - swap(col, ixj); - } - } + uint num_outer_loop_iters = BLOCK_SIZE_LOG2; + [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { + uint num_inner_loop_iters = outer_idx + 1; + [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { + const int ixj = int(col ^ j); + + int idx_0 = (col & k) == 0 ? col : ixj; + int idx_1 = (col & k) == 0 ? ixj : col; + + int sh_idx_0 = dst_row[idx_0]; + int sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? 
sh_idx_1 >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) { + swap(idx_0, idx_1); } + barrier(); } } if (col < p.ncols) { - data_d[row_offset + col] = dst_row[col]; + if (p.order == ASC) { + data_d[row_offset + col] = dst_row[col]; + } else { + data_d[row_offset + p.ncols - col - 1] = dst_row[col]; + } + } +} + +void main() { + if (p.ncols == BLOCK_SIZE) { + argsort(false); + } else { + argsort(true); } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 86bafba4a..44a64ddc8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -16,7 +16,7 @@ // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j layout(binding = 0) readonly buffer A { A_TYPE knl_data[]; -}; // src0 - kernel: [KW, KH, Cin, Cout] +}; // src0 - kernel: [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d layout(binding = 1) readonly buffer B { B_TYPE src_data[]; @@ -66,6 +66,10 @@ layout(push_constant) uniform parameter { uint32_t KWKHmp; uint32_t KWKHL; uint32_t OWmp; uint32_t OWL; uint32_t OWOHmp; uint32_t OWOHL; +#ifdef TRANSPOSE + uint32_t s0mp; uint32_t s0L; + uint32_t s1mp; uint32_t s1L; +#endif } p; @@ -225,7 +229,11 @@ void main() { uint32_t B_ly = r_offset + Ar; uint32_t B_lx = Ac; uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ +#ifdef TRANSPOSE + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); +#else uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); +#endif float val = knl_data[knl_idx]; if (K_idx >= K || CRS_idx_a >= CRS) { val = 0.0; @@ -267,12 +275,24 @@ void main() { KW_idx_b = CRS_remainder - KH_idx_b * p.KW; #endif +#ifdef TRANSPOSE + uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1; + uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0; + uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L); + uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L); +#else uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1; uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0; +#endif uint32_t src_idx = min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); float val = src_data[src_idx]; - if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) { + if (CRS_idx_b >= CRS || NPQ_idx >= NPQ + || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. 
(idx >= 0x80000000 for such case) +#ifdef TRANSPOSE + || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0) +#endif + ) { val = 0.0; } Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index 27d6b7464..bc2e1f2df 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -15,8 +15,15 @@ layout (binding = 0) readonly buffer S {float data_s[];}; #if defined(SET_ROWS) #include "generic_binary_head.comp" -layout (binding = 1) readonly buffer C {uvec2 data_i[];}; +layout (binding = 1) readonly buffer C {B_TYPE data_i[];}; layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; + +#if B_SIZE == 64 +#define DATA_I_SWIZZLE .x +#else +#define DATA_I_SWIZZLE +#endif + #else #include "generic_unary_head.comp" layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; @@ -259,7 +266,7 @@ void main() { uint i11 = fastmod(i02, p.ne11); uint i10 = i01; - uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x; + uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE; uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset(); uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset(); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index d3127fbd9..73fef4fa6 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -478,3 +478,139 @@ vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); } #endif + +#if defined(DATA_A_Q2_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]); + const uint scales = data_a[a_offset + ib].scales[scalesi]; + const vec2 d = vec2(data_a[a_offset + ib].d); + + return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q3_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); + const float dl = float(data_a[a_offset + ib].d) * float(us - 32); + + return vec2(dl * float(int8_t((data_a[a_offset + ib].qs[qsi ] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi ] & m) != 0) ? 0 : 4)), + dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 
0 : 4))); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q4_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[a_offset + ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF), m), + fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q5_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[a_offset + ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), + fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi + 1] & hm) != 0 ? 
16 : 0), m)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q6_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[a_offset + ib].d) * float(data_a[a_offset + ib].scales[is]); + + return vec2(dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32), + dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp index 48f6b65bc..127c7b642 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -29,7 +29,7 @@ void main() { uint qs = data_a[ib].qs[4 * ib32 + l]; const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l]; qs |= (qh << (8 - 2 * l)) & 0x300; - const uvec2 grid = iq2s_grid[qs & 511]; + const uvec2 grid = iq2s_grid[qs]; const u8vec4 grid0 = unpack8(grid.x); const u8vec4 grid1 = unpack8(grid.y); data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0)); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp index e370690bc..0ae9acd02 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -33,7 +33,8 @@ void main() { [[unroll]] for (uint l = 0; l < 4; ++l) { const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit - const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]]; + const uint qs = data_a[ib].qs[8 * is + l]; + const uvec2 grid = iq2xxs_grid[qs]; const u8vec4 grid0 = unpack8(grid.x); const u8vec4 grid1 = unpack8(grid.y); data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp index c3f4bca5d..e4f42be94 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -22,15 +22,16 @@ void main() { const uint b_idx = 256 * ib + 32 * is; const float d = float(data_a[ib].d); - const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf)); + const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf)); // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes. 
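// Aside (illustrative host-side sketch, not part of the patch): the grid lookup
// below builds a 9-bit index per group of 4 values -- the low 8 bits come from
// one qs byte and bit l of the per-group qh byte supplies bit 8. The shader
// expression (qh << (8 - l)) & 256 is equivalent to the following:

#include <cstdint>

static uint32_t iq3s_grid_index(uint8_t qs, uint8_t qh, uint32_t l) {
    // Select bit l of qh and place it at bit position 8, above the qs byte.
    return (uint32_t) qs | ((((uint32_t) qh >> l) & 1u) << 8);
}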
uint qh = data_a[ib].qh[is]; [[unroll]] for (uint l = 0; l < 8; ++l) { - uint qs = data_a[ib].qs[8 * is + l]; - uint gidx = qs | ((qh << (8 - l)) & 256); - uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1)); - u8vec4 grid = unpack8(iq3s_grid[gidx]); + const uint iqs = 8 * is + l; + const uint qs = data_a[ib].qs[iqs]; + const uint gidx = qs | ((qh << (8 - l)) & 256); + const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1)); + const u8vec4 grid = unpack8(iq3s_grid[gidx]); data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0)); data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0)); data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0)); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp index a92b82961..19c7fdeef 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -35,8 +35,10 @@ void main() { const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); // Restore parity bit. const uint sign8 = sign7 | (bitCount(sign7) << 7); - const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]); - const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]); + const uint qs0 = data_a[ib].qs[8 * is + 2 * l]; + const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1]; + const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]); + const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]); data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? 
-1.0 : 1.0)); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp new file mode 100644 index 000000000..a3941372a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "rte.comp" +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(exp(float(data_a[i]))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index d40848e15..43b906e5e 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -117,6 +117,9 @@ void main() { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); @@ -155,7 +158,11 @@ void main() { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + if (!KV_bounds_check || j * Bc + c < KV) { + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + } else { + masksh[c][r] = float(0); + } } } barrier(); @@ -172,8 +179,11 @@ void main() { float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - rowmaxf[r] = Sf[r][0]; + rowmaxf[r] = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); } Moldf[r] = Mf[r]; @@ -190,6 +200,9 @@ void main() { // Compute sum across row of P rowsumf[r] = 0.0; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } rowsumf[r] += Pf[r][c]; } @@ -203,6 +216,9 @@ void main() { } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); @@ -334,6 +350,9 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] *= Lfrcp[r]; +#if defined(ACC_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX)); +#endif } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index b57c9dcfc..9b1f153bf 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -9,6 +9,12 @@ layout (constant_id = 4) const uint32_t HSV = 32; layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; +// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths +const uint32_t HSK_pad = (HSK + 15) & ~15; +const uint32_t HSV_pad = (HSV + 15) & ~15; + +const bool KV_bounds_check = Clamp != 0; + layout (push_constant) uniform parameter { uint32_t N; uint32_t KV; @@ -61,30 +67,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; #if defined(A_TYPE_PACKED16) #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 -layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; +layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } } #endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + if (binding_idx == BINDING_IDX_K) { + const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } else { + const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } } #endif diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 230e815f2..ddb1246e0 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -46,14 +46,14 @@ const uint32_t MatBc = 16; shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; -const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4 +const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; // Avoid padding for hsk==256 to make it fit in 48KB shmem. const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br; shared ACC_TYPE sfsh[Bc * sfshstride]; -const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4 +const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 ksh[Bc * kshstride]; shared float slope[Br]; @@ -74,6 +74,21 @@ void main() { #define tile_row(r) (row_tid * rows_per_thread + (r)) + // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK). + if ((HSK % 16) != 0) { + [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { + if (i + tid < Br * qstride) { + Qf[i + tid] = f16vec4(0); + } + } + [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) { + if (i + tid < Bc * kshstride) { + ksh[i + tid] = f16vec4(0); + } + } + barrier(); + } + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { @@ -137,28 +152,31 @@ void main() { uint32_t d = (idx + tid) % (HSK / 4); uint32_t c = (idx + tid) / (HSK / 4); if (c < Bc && d < HSK / 4) { + f16vec4 K_Tf = f16vec4(0); + if (!KV_bounds_check || j * Bc + c < KV) { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); #else - f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif + } ksh[c * kshstride + d] = K_Tf; } } barrier(); - // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br + // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // This is written transposed in order to allow for N being 8 if implementations need it coopmat SfMat = coopmat(0); coopmat KMat; coopmat QMat; - for (uint32_t d = 0; d < HSK / 16; ++d) { + for (uint32_t d = 0; d < HSK_pad / 16; ++d) { coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; @@ -187,7 +205,9 @@ void main() { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); + if (!KV_bounds_check || j * Bc + c < KV) { + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); + } } } barrier(); @@ 
-195,8 +215,11 @@ void main() { float eMf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride]; + float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); } float Moldf = Mf[r]; @@ -210,7 +233,7 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; } } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { @@ -218,6 +241,9 @@ void main() { } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } float Pf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); @@ -233,7 +259,7 @@ void main() { vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf); + Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf); } } } @@ -288,7 +314,7 @@ void main() { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; tmpshv4[tid] = Of[r][d]; barrier(); @@ -357,7 +383,10 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] *= float16_t(Lfrcp[r]); + Of[r][d] *= ACC_TYPE(Lfrcp[r]); +#if defined(ACC_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); +#endif } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index b0564ca0b..ab647e9bc 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -104,16 +104,16 @@ void main() { tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); - coopmat Q; - coopmat Qf16; + coopmat Q; + coopmat Qf16; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; - coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK)); + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); - Qf16 = coopmat(Q); + Qf16 = coopmat(Q); Qf16 *= float16_t(p.scale); - coopmat O = coopmat(0); + coopmat O = coopmat(0); coopmat L, M; @@ -140,10 +140,10 @@ void main() { coopmat S = coopmat(0); - coopmat K_T; + coopmat K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC); + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); S = coopMatMulAdd(Qf16, K_T, S); if (p.logit_softcap != 0.0f) { @@ -208,31 +208,31 @@ void 
main() { rowsum = coopmat(0.0); rowsum = coopMatMulAdd(P_A, One, rowsum); - coopmat V; + coopmat V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC); + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC); L = eM*L + rowsum; // This is the "diagonal" matrix in the paper, but since we do componentwise // multiply rather than matrix multiply it has the diagonal element smeared // across the row - coopmat eMdiag; + coopmat eMdiag; // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); // multiply with fp16 accumulation, then add to O. - coopmat PV = coopmat(0); + coopmat PV = coopmat(0); PV = coopMatMulAdd(P_A, V, PV); - O = eMdiag * O + coopmat(PV); + O = eMdiag * O + coopmat(PV); } // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - coopmat O_D = coopmat(O); + coopmat O_D = coopmat(O); uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); @@ -243,16 +243,16 @@ void main() { return; } - coopmat Ldiag; + coopmat Ldiag; // resize L by using smear/reduce coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { - coopmat S; + coopmat S; coopMatPerElementNV(S, S, perElemOpGetSink, iq2); - coopmat Mr; + coopmat Mr; // resize M by using smear/reduce coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce); @@ -283,9 +283,13 @@ void main() { O = Ldiag*O; +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; - coopmat O_D = coopmat(O); + coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); } else { @@ -295,6 +299,6 @@ void main() { // permute dimensions tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); - coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute); + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute); } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 76ef4b6df..06e83822f 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -111,6 +111,10 @@ void main() { } } O *= L; + + const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); + O = clamp(O, -FLT_MAX, FLT_MAX); + data_d[iq3 * D * N + D * n + d] = O; } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp index 4b4316cf3..750e78575 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp @@ -2,6 +2,7 @@ #extension GL_EXT_control_flow_attributes : require #include "rte.comp" +#include 
"utils.comp" layout (push_constant) uniform parameter { @@ -28,25 +29,9 @@ uint get_aoffset() { return p.misalign_offsets >> 16; } uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; } uint get_doffset() { return p.misalign_offsets & 0xFF; } -// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1 -uint fastmod(uint a, uint b) { - if ((b & (b-1)) == 0) { - return a & (b-1); - } - return a % b; -} - -uint fastdiv(uint a, uint b) { - return (a < b) ? 0 : (a / b); -} void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) { - i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00)); - const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; - i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00)); - const uint i02_offset = i02*p.ne01*p.ne00; - i01 = (idx - i03_offset - i02_offset) / p.ne00; - i00 = idx - i03_offset - i02_offset - i01*p.ne00; + get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03); } uint src0_idx(uint i00, uint i01, uint i02, uint i03) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index ee6b86a18..7ef75cd7a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -7,27 +7,36 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; void main() { const uint i00 = gl_GlobalInvocationID.x; - const uint i10 = gl_GlobalInvocationID.y; - const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; - const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; if (i00 >= p.ne00) { return; } - const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + uint gid_z = gl_GlobalInvocationID.z; + while (gid_z < p.ne11 * p.ne12) { + uint gid_y = gl_GlobalInvocationID.y; + while (gid_y < p.ne10) { + const uint i10 = gid_y; + const uint i11 = gid_z / p.ne12; + const uint i12 = gid_z % p.ne12; - const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; - const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; #if defined(DATA_A_BF16) - FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); + FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); #else - FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); + FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); #endif #ifndef OPTIMIZATION_ERROR_WORKAROUND - data_d[d_offset + i00] = D_TYPE(v); + data_d[d_offset + i00] = D_TYPE(v); #else - data_d[d_offset + i00] = D_TYPE(v); + data_d[d_offset + i00] = D_TYPE(v); #endif + gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z; + } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index cfd645a38..339f905fc 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -10,9 +10,6 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; void main() { const uint i00 = (gl_GlobalInvocationID.x)*2; - const uint i10 = gl_GlobalInvocationID.y; - const uint i11 = 
(gl_GlobalInvocationID.z)/p.ne12; - const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -22,20 +19,33 @@ void main() { return; } - const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + uint gid_z = gl_GlobalInvocationID.z; + while (gid_z < p.ne11 * p.ne12) { + uint gid_y = gl_GlobalInvocationID.y; + while (gid_y < p.ne10) { + const uint i10 = gid_y; + const uint i11 = gid_z / p.ne12; + const uint i12 = gid_z % p.ne12; - const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; - const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; - const uint ib = a_offset + i00/QUANT_K; // block index - const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index - const uint iybs = i00 - i00%QUANT_K; // dst block start index - const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; - vec2 v = dequantize(ib, iqs, 0); - const vec2 dm = get_dm(ib, 0); - v = v * dm.x + dm.y; + const uint ib = a_offset + i00/QUANT_K; // block index + const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index + const uint iybs = i00 - i00%QUANT_K; // dst block start index + const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; - data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); - data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); + vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + v = v * dm.x + dm.y; + + data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); + data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); + + gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z; + } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp new file mode 100644 index 000000000..1da252cc6 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp new file mode 100644 index 000000000..3afc58827 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f))); 
+} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index fdbcf7eba..f0f19a019 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -5,8 +5,11 @@ #include "rte.comp" +#include "types.comp" + layout (push_constant) uniform parameter { + BDA_STORAGE_T dst_addr; uint batch_offset; uint offset_delta; uint IC; uint IW; uint IH; @@ -19,8 +22,6 @@ layout (push_constant) uniform parameter int d0; int d1; } p; -#include "types.comp" - layout(constant_id = 0) const uint BLOCK_SIZE = 32; const uint NUM_ITER = 512 / BLOCK_SIZE; @@ -30,6 +31,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +#if BDA +layout (buffer_reference) buffer D_ptr {D_TYPE d;}; +#endif + void main() { const uint gidx = gl_GlobalInvocationID.x; @@ -38,7 +43,7 @@ void main() { const uint ic = gl_GlobalInvocationID.z % p.IC; const uint src_base = ic * p.offset_delta + batch * p.batch_offset; - const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH); + const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); const int oh_s1 = int(oh) * p.s1; const uint ksize = p.OW * p.KH; @@ -50,7 +55,7 @@ void main() { uint current_ix = rem % p.OW; A_TYPE values[NUM_ITER]; - uint offset_dst[NUM_ITER]; + BDA_OFFSET_T offset_dst[NUM_ITER]; [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { values[idx] = A_TYPE(0); } @@ -66,7 +71,7 @@ void main() { const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; const uint iih = oh_s1 + current_ky * p.d1 - p.p1; - offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx; + offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx; if ((iih < p.IH) && (iiw < p.IW)) { values[idx] = data_a[src_base + iih * p.IW + iiw]; @@ -89,7 +94,11 @@ void main() { continue; } +#if BDA + D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]); + dst_addr.d = D_TYPE(values[idx]); +#else data_d[offset_dst[idx]] = D_TYPE(values[idx]); +#endif } - } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp new file mode 100644 index 000000000..9faa636ac --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -0,0 +1,126 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "rte.comp" + +#include "types.comp" + +layout (push_constant) uniform parameter +{ + BDA_STORAGE_T dst_addr; + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return 
p.misalign_offsets & 0xFFFF; } + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +#if BDA +layout (buffer_reference) buffer D_ptr {D_TYPE d;}; +#endif + +void main() { + const uint32_t i = gl_GlobalInvocationID.x; + + uint32_t nb10 = p.nb10; + uint32_t nb11 = p.nb11; + uint32_t nb12 = p.nb12; + uint32_t nb13 = p.nb13; + uint32_t s0 = p.s0; + uint32_t s1 = p.s1; + uint32_t s2 = p.s2; + uint32_t p0 = p.p0; + uint32_t p1 = p.p1; + uint32_t p2 = p.p2; + uint32_t d0 = p.d0; + uint32_t d1 = p.d1; + uint32_t d2 = p.d2; + uint32_t IW = p.IW; + uint32_t IH = p.IH; + uint32_t ID = p.ID; + uint32_t IC = p.IC; + uint32_t KW = p.KW; + uint32_t OH = p.OH; + uint32_t KD_KH_KW = p.KD_KH_KW; + uint32_t KH_KW = p.KH_KW; + uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW; + uint32_t N_OD_OH = p.N_OD_OH; + uint32_t OD_OH = p.OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW; + + if (i >= IC_KD_KH_KW) { + return; + } + + const uint32_t iic = i / KD_KH_KW; + const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const uint32_t ikw = i % KW; + + const uint32_t iow = gl_GlobalInvocationID.y; + for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) { + const uint32_t in_ = iz / OD_OH; + const uint32_t iod = (iz - in_*OD_OH) / OH; + const uint32_t ioh = iz % OH; + + const uint32_t iiw = iow * s0 + ikw * d0 - p0; + const uint32_t iih = ioh * s1 + ikh * d1 - p1; + const uint32_t iid = iod * s2 + ikd * d2 - p2; + + const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10; +#if BDA + D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst); + if (iih >= IH || iiw >= IW || iid >= ID) { + dst_addr.d = D_TYPE(0.0f); + } else { + dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]); + } +#else + if (iih >= IH || iiw >= IW || iid >= ID) { + data_d[offset_dst + get_doffset()] = D_TYPE(0.0f); + } else { + data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]); + } +#endif + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp index 903753c7e..f761391ea 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -2,16 +2,30 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_8bit_storage : require +#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#endif + #ifdef MUL_MAT_ID #define EXPERT_COUNT 8 #endif #include "types.comp" +#ifndef MMQ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#else +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#endif + layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +#ifdef B_TYPE_VEC2 layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 
data_b_v2[];}; +#endif +#ifdef B_TYPE_VEC4 layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; +#endif layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID @@ -88,9 +102,57 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32; layout (constant_id = 1) const uint NUM_ROWS = 1; layout (constant_id = 2) const uint NUM_COLS = 1; +#ifdef USE_SUBGROUP_ADD_NO_SHMEM +void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = subgroupAdd(temp[j][n]); + } + } + + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); + } + } + } +} +#else shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; -void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { +void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + // subgroupAdd is probably faster on devices that support it, + // particularly when the workgroup has more than one subgroup +#if USE_SUBGROUP_ADD + // sum up partial sums within a subgroup + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = subgroupAdd(temp[j][n]); + } + } + + // Go through shared memory to sum partials across subgroups + if (gl_SubgroupInvocationID == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][gl_SubgroupID] = temp[j][n]; + } + } + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = FLOAT_TYPE(0); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + temp[j][n] += tmpsh[j][n][s]; + } + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); + } + } + } +#else // sum up partial sums and write back result [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint n = 0; n < num_rows; ++n) { @@ -115,4 +177,6 @@ void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32 } } } +#endif } +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp new file mode 100644 index 000000000..8fb314fa0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -0,0 +1,140 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_integer_dot_product : require + +#define MMQ +#define B_TYPE block_q8_1_x4 + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#define K_PER_ITER 8 + +#include "mul_mmq_funcs.comp" + +uint a_offset, b_offset, d_offset; + +int32_t cache_b_qs[2]; +vec2 cache_b_ds; + +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + tid*K_PER_ITER; + + // 
Preload data_b block + const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset; + const uint b_qs_idx = tid % 4; + const uint b_block_idx_outer = b_block_idx / 4; + const uint b_block_idx_inner = b_block_idx % 4; + cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]); + +#if QUANT_R == 2 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4]; +#else + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1]; +#endif + + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint a_block_idx = (ibi + col)/QUANT_K + a_offset; + ibi += p.ncols; + + int32_t q_sum = 0; +#if QUANT_R == 2 + const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx); + q_sum += dotPacked4x8EXT(data_a_qs.x, + cache_b_qs[0]); + q_sum += dotPacked4x8EXT(data_a_qs.y, + cache_b_qs[1]); +#else + int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[0]); + data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[1]); +#endif + +#if QUANT_AUXF == 1 + temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4); +#else + temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4); +#endif + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; + + get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; + b_offset /= QUANT_K_Q8_1; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = FLOAT_TYPE(0.0f); + } + } + + uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); + if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + num_iters++; + } + int unroll_count = 4; + uint unrolled_iters = num_iters & ~(unroll_count - 1); + + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + + unroll_count = 2; + unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 8c5114a79..3cb24412d 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -17,6 +17,9 @@ #ifdef COOPMAT #extension GL_KHR_cooperative_matrix : enable #extension GL_KHR_memory_scope_semantics : enable +#endif + +#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS) #extension GL_KHR_shader_subgroup_basic : enable #extension GL_KHR_shader_subgroup_ballot : enable #endif @@ -34,6 +37,18 @@ #define LOAD_VEC_B 1 #endif +// Load 2 values at once without affecting index calculations through LOAD_VEC +#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED) +#define LOAD_VEC_BATCH_A 2 +#else +#define LOAD_VEC_BATCH_A 1 +#endif +#if !defined(ALIGNED) +#define LOAD_VEC_BATCH_B 2 +#else +#define LOAD_VEC_BATCH_B 1 +#endif + #if !defined(TO_FLOAT_TYPE) #define TO_FLOAT_TYPE FLOAT_TYPE #endif @@ -95,28 +110,93 @@ layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat layout (constant_id = 10) const uint WARP = 32; #ifdef COOPMAT -#define SHMEM_STRIDE (BK + 8) +#define SHMEM_STRIDE (BK / 2 + 4) #else -#define SHMEM_STRIDE (BK + 1) +#define SHMEM_STRIDE (BK / 2 + 1) #endif -shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; -shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; - -#ifdef MUL_MAT_ID -shared u16vec2 row_ids[4096]; -uint _ne1; -#ifdef COOPMAT -shared uint _ne1_sh; -#endif -#endif // MUL_MAT_ID +shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; #define NUM_WARPS (BLOCK_SIZE / WARP) +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[BN]; +uint _ne1; + +#ifdef MUL_MAT_ID_USE_SUBGROUPS +shared uvec4 ballots_sh[NUM_WARPS]; + +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? 
data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} +#endif // MUL_MAT_ID_USE_SUBGROUPS +#endif // MUL_MAT_ID + #ifdef COOPMAT shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif +#include "mul_mm_funcs.comp" + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -168,60 +248,29 @@ void main() { const uint warp_r = warp_i % (BM / WM); const uint warp_c = warp_i / (BM / WM); - const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); - const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); - const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); - const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); - const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK; - const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK; #ifdef MUL_MAT_ID -#ifdef COOPMAT - // Spread the search across all elements in the first subgroup - if (gl_SubgroupID == 0) { - _ne1 = 0; - uint num_elements = p.nei1 * p.nei0; - - uint ids[16]; - uint iter = 0; - - for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { - // prefetch up to 16 elements - if (iter == 0) { - [[unroll]] for (uint k = 0; k < 16; ++k) { - uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - ids[k] = in_range ? 
data_ids[ii1*p.nbi1 + ii0] : 0; - } - } - uint i = j + gl_SubgroupInvocationID; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - uint id = ids[iter++]; - uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - uint idx = subgroupBallotExclusiveBitCount(ballot); - if (in_range && id == expert_idx) { - row_ids[_ne1 + idx] = u16vec2(ii0, ii1); - } - _ne1 += subgroupBallotBitCount(ballot); - iter &= 15; - } - _ne1_sh = _ne1; +#ifdef MUL_MAT_ID_USE_SUBGROUPS + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); } - - barrier(); - - _ne1 = _ne1_sh; #else _ne1 = 0; - for (uint ii1 = 0; ii1 < p.nei1; ii1++) { - for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { - row_ids[_ne1] = u16vec2(ii0, ii1); + if (_ne1 >= ic * BN) { + row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1); + } _ne1++; } } @@ -265,8 +314,8 @@ void main() { } #else ACC_TYPE sums[WMITER * TM * WNITER * TN]; - FLOAT_TYPE cache_a[WMITER * TM]; - FLOAT_TYPE cache_b[TN]; + FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; + FLOAT_TYPE_VEC2 cache_b[TN]; [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { sums[i] = ACC_TYPE(0.0f); @@ -275,538 +324,13 @@ void main() { for (uint block = start_k; block < end_k; block += BK) { [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { - -#if defined(DATA_A_F32) || defined(DATA_A_F16) -#if LOAD_VEC_A == 8 - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); - buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); - buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); - buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); - buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); -#elif LOAD_VEC_A == 4 - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); -#else - if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { - buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); - } else { - buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); - } -#endif -#elif defined(DATA_A_BF16) -#if LOAD_VEC_A == 4 - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x); - buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y); - buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z); - buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w); -#else - if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { - buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); - } else { - buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0)); - } -#endif -#elif defined(DATA_A_Q4_0) 
- const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; - - const uint ib = idx / 4; - const uint iqs = idx & 0x03; - - const float d = float(data_a_packed16[ib].d); - const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); - const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; - const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; - - buf_a[buf_idx ] = FLOAT_TYPE(v0.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); - buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); - buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); - buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); - buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); - buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); - buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); -#elif defined(DATA_A_Q4_1) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; - - const uint ib = idx / 4; - const uint iqs = idx & 0x03; - - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); - const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; - const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; - - buf_a[buf_idx ] = FLOAT_TYPE(v0.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); - buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); - buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); - buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); - buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); - buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); - buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); -#elif defined(DATA_A_Q5_0) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - - const uint ib = idx / 8; - const uint iqs = idx & 0x07; - - const float d = float(data_a_packed16[ib].d); - const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); - const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); - const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); - - const uint vui = uint(data_a_packed16[ib].qs[iqs]); - const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); -#elif defined(DATA_A_Q5_1) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - - const uint ib = idx / 8; - const uint iqs = idx & 0x07; - - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint uint_qh = data_a_packed16[ib].qh; - const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); - const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); - - const uint vui = uint(data_a_packed16[ib].qs[iqs]); - const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 
17] = FLOAT_TYPE(v.w); -#elif defined(DATA_A_Q8_0) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 8; - const uint iqs = idx & 0x07; - - const float d = float(data_a_packed16[ib].d); - const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; - const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); -#elif defined(DATA_A_Q2_K) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 - const uint scalesi = iqs / 8; // 0..15 - const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - - const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); - const uint scales = data_a[ib].scales[scalesi]; - const vec2 d = vec2(data_a[ib].d); - - const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); -#elif defined(DATA_A_Q3_K) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 64; // 0,1 - const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 - const uint hmi = (iqs % 16) * 2; // 0,2,4..30 - const uint j = (iqs % 64) / 4; // 0..3 - const uint is = iqs / 8; // 0..15 - const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 - const uint qsshift = halfsplit * 2; // 0,2,4,6 - const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 - - const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) - | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); - const float dl = float(data_a[ib].d) * float(us - 32); - - buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); - buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); -#elif defined(DATA_A_Q4_K) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 32; // 0,1,2,3 - const uint b = (iqs % 32) / 16; // 0,1 - const uint is = 2 * n + b; // 0..7 - const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - - const vec2 loadd = vec2(data_a[ib].d); - - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 
0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; - - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); - - const float d = loadd.x * sc; - const float m = -loadd.y * mbyte; - - buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); - buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); -#elif defined(DATA_A_Q5_K) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 32; // 0,1,2,3 - const uint b = (iqs % 32) / 16; // 0,1 - const uint is = 2 * n + b; // 0..7 - const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - const uint qhi = (iqs % 16) * 2; // 0,2,4..30 - - const uint8_t hm = uint8_t(1 << (iqs / 16)); - - const vec2 loadd = vec2(data_a[ib].d); - - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; - - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); - - const float d = loadd.x * sc; - const float m = -loadd.y * mbyte; - - buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); - buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 
16 : 0), m)); -#elif defined(DATA_A_Q6_K) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 64; // 0,1 - const uint b = (iqs % 64) / 32; // 0,1 - const uint is_b = (iqs % 16) / 8; // 0,1 - const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - const uint is = 8 * n + qhshift + is_b; // 0..15 - const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 - const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 - - const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); - - buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); - buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); -#elif defined(DATA_A_IQ1_S) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 32; // 8 values per idx - const uint ib32 = (idx % 32) / 4; // 0..7 - const uint ib8 = idx % 32; - - const float d = float(data_a[ib].d); - const uint qh = data_a[ib].qh[ib32]; - const uint qs = data_a[ib].qs[ib8]; - const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); - const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; - const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); - - [[unroll]] for (int k = 0; k < 8; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); - } -#elif defined(DATA_A_IQ1_M) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 32; // 8 values per idx - const uint ib8 = idx % 32; - const uint ib16 = ib8 / 2; - - const uint16_t[4] scales = data_a[ib].scales; - const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; - const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); - const uint sc = scales[ib8 / 8]; - const uint qs = data_a[ib].qs[ib8]; - const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1)); - const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); - const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA; - const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); - - [[unroll]] for (int k = 0; k < 8; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); - } -#elif defined(DATA_A_IQ2_XXS) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 32; // 8 values per idx - const uint ib32 = (idx % 32) / 4; // 0..7 - const uint ib8 = idx % 4; - - const float d = float(data_a[ib].d); - const uint qs = data_a[ib].qs[8 * ib32 + ib8]; - const uint signs = pack32(u8vec4( - data_a[ib].qs[8*ib32 + 4], - data_a[ib].qs[8*ib32 + 5], - data_a[ib].qs[8*ib32 + 6], - data_a[ib].qs[8*ib32 + 7] - )); - const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); - const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); - const uint sign = sign7 | (bitCount(sign7) << 7); - const uvec2 grid = iq2xxs_grid[qs]; - const vec4 grid0 = vec4(unpack8(grid.x)); - const vec4 grid1 = vec4(unpack8(grid.y)); - - buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); - buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); - buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); - buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); - buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); -#elif defined(DATA_A_IQ2_XS) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 32; // 8 values per idx - const uint ib32 = (idx % 32) / 4; // 0..7 - const uint ib8 = idx % 4; // 0..3 - - const float d = float(data_a[ib].d); - const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); - const uint qs = data_a[ib].qs[4 * ib32 + ib8]; - const uint sign7 = qs >> 9; - const uint sign = sign7 | (bitCount(sign7) << 7); - const uvec2 grid = iq2xs_grid[qs & 511]; - const vec4 grid0 = vec4(unpack8(grid.x)); - const vec4 grid1 = vec4(unpack8(grid.y)); - - buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); - buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); - buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); - buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); - buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? 
-grid1.w : grid1.w); -#elif defined(DATA_A_IQ2_S) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 32; // 8 values per idx - const uint ib8 = idx % 32; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 - - const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const uint qs = data_a[ib].qs[ib8]; - const uint qh = data_a[ib].qh[ib32]; - const uint qhshift = 2 * (ib8 % 4); - const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; - - const float d = float(data_a[ib].d); - const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); - const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; - const vec4 grid0 = vec4(unpack8(grid.x)); - const vec4 grid1 = vec4(unpack8(grid.y)); - - buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); - buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); - buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); - buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); - buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); -#elif defined(DATA_A_IQ3_XXS) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 64; // 4 values per idx - const uint iqs = idx % 64; // 0..63 - const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values - - const float d = float(data_a[ib].d); - const uint qs = data_a[ib].qs[iqs]; - const uint signs = pack32(u8vec4( - data_a[ib].qs[is+0], - data_a[ib].qs[is+1], - data_a[ib].qs[is+2], - data_a[ib].qs[is+3] - )); - const float db = d * 0.5 * (0.5 + (signs >> 28)); - const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); - const uint grid = iq3xxs_grid[qs]; - const vec4 v = db * vec4(unpack8(grid)); - - buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); - buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); -#elif defined(DATA_A_IQ3_S) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 64; // 4 values per idx - const uint iqs = idx % 64; // 0..63 - const uint iqh = iqs / 8; - - const float d = float(data_a[ib].d); - const uint qs = data_a[ib].qs[iqs]; - const uint qh = data_a[ib].qh[iqh]; - const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); - const uint scale = data_a[ib].scales[iqs / 16]; - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); - const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; - const vec4 v = db * vec4(unpack8(grid)); - - buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? 
-v.z : v.z); - buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); -#elif defined(DATA_A_IQ4_XS) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint iq = 16 * ib32 + 2 * (idx % 8); - - const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; - const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; - const uint qshift = (idx & 8) >> 1; - u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); - qs = (qs >> qshift) & uint8_t(0xF); - - const float d = float(data_a[ib].d); - const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); -#elif defined(DATA_A_IQ4_NL) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - - const uint ib = idx / 8; - const uint iqs = idx & 0x07; - - const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); - const uint vui = uint(data_a_packed16[ib].qs[iqs]); - - buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d; - buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; - buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; - buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; -#elif defined(DATA_A_MXFP4) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - - const uint ib = idx / 8; - const uint iqs = (idx & 0x07) * 2; - - const float d = e8m0_to_fp32(data_a[ib].e); - const uint vui = uint(data_a[ib].qs[iqs]); - const uint vui2 = uint(data_a[ib].qs[iqs+1]); - - buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d); - buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d); - buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d); - buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d); -#endif + load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k); } [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { -#if LOAD_VEC_B == 8 -#ifdef MUL_MAT_ID - const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; - const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +#if !defined(MUL_MAT_ID) + load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k); #else - const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; -#endif - const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); - buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); - buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); - buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); - buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); -#elif LOAD_VEC_B == 4 -#ifdef MUL_MAT_ID - const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; - const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; -#else - const uint idx = pos_b + (loadc_b + l) * 
p.stride_b / LOAD_VEC_B + loadr_b; -#endif - const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x); - buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y); - buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z); - buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w); -#elif !MUL_MAT_ID - if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); - } else { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); - } -#else - const uint row_i = ic * BN + loadc_b + l; - if (row_i < _ne1) { - const u16vec2 row_idx = row_ids[row_i]; - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); - } else { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); - } + load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k); #endif } @@ -819,17 +343,17 @@ void main() { [[unroll]] for (uint i = 0; i < BK; i += TK) { [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { // Load from shared into cache - coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); } } } #else - [[unroll]] for (uint i = 0; i < BK; i++) { + [[unroll]] for (uint i = 0; i < BK / 2; i++) { // Load from shared into cache [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint j = 0; j < TM; j++) { @@ -845,7 +369,7 @@ void main() { [[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) { const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]); + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx])); } } } @@ -856,6 +380,20 @@ void main() { barrier(); } +#if defined(ACC_TYPE_MAX) +#ifdef COOPMAT + [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) { + [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) { + sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + } + } +#else + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + } +#endif +#endif + const uint dr = ir * BM + warp_r * WM; const uint dc = ic * BN + warp_c * WN; @@ -873,9 +411,11 @@ void main() { const uint row_i = dc + cm_col * TN + col + store_c; if (row_i >= _ne1) break; - const u16vec2 row_idx = row_ids[row_i]; + const u16vec2 row_idx = row_ids[row_i - ic * BN]; - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + if (dr + 
cm_row * TM + store_r < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } } } } @@ -921,11 +461,13 @@ void main() { const uint row_i = dc_warp + cc; if (row_i >= _ne1) break; - const u16vec2 row_idx = row_ids[row_i]; + const u16vec2 row_idx = row_ids[row_i - ic * BN]; #endif // MUL_MAT_ID [[unroll]] for (uint cr = 0; cr < TM; cr++) { #ifdef MUL_MAT_ID - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } #else if (dr_warp + cr < p.M && dc_warp + cc < p.N) { data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 29e4b5c9c..0e3065e01 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -19,6 +19,7 @@ #endif #include "types.comp" +#include "utils.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -92,14 +93,15 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; -shared u16vec4 row_ids[4096]; +shared u16vec4 row_ids[BN]; layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { B_TYPE b[]; }; uint _ne1; -shared uint _ne1_sh; +layout (constant_id = 5) const uint subgroup_size = 32; +shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size]; B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { @@ -109,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i return B_TYPE(0.0); } - const u16vec4 row_idx = row_ids[row_i]; + const u16vec4 row_idx = row_ids[row_i & (BN - 1)]; B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; return ret; @@ -121,13 +123,74 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem uint dc = ic * BN + c; if (dr < p.M && dc < _ne1) { - uint row_i = dc; + uint row_i = c; const u16vec4 row_idx = row_ids[row_i]; data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; } return elem; } +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? 
data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} #endif void main() { @@ -157,45 +220,12 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - // Spread the search across all elements in the first subgroup - if (gl_SubgroupID == 0) { - _ne1 = 0; - uint num_elements = p.nei1 * p.nei0; - - uint ids[16]; - uint iter = 0; - - for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { - // prefetch up to 16 elements - if (iter == 0) { - [[unroll]] for (uint k = 0; k < 16; ++k) { - uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; - } - } - uint i = j + gl_SubgroupInvocationID; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - uint id = ids[iter++]; - uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - uint idx = subgroupBallotExclusiveBitCount(ballot); - if (in_range && id == expert_idx) { - row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); - } - _ne1 += subgroupBallotBitCount(ballot); - iter &= 15; - } - _ne1_sh = _ne1; + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); } - barrier(); - - _ne1 = _ne1_sh; - // Workgroup has no work if (ic * BN >= _ne1) return; #endif @@ -235,7 +265,6 @@ void main() { tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); #if QUANT_K > 1 tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); @@ -251,6 +280,8 @@ void main() { tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k); + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); #if !defined(MUL_MAT_ID) @@ -319,6 +350,10 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); block_k += BK; } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + coopmat mat_d = coopmat(sum); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, 
ir * BM, BM), tensorViewTranspose); @@ -358,6 +393,10 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); block_k += BK; } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + coopmat mat_d = coopmat(sum); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); @@ -398,6 +437,10 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); block_k += BK; } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + coopmat mat_d = coopmat(sum); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); @@ -414,18 +457,111 @@ void main() { tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); - coopmat sum; - sum = coopmat(0.0); - uint k_iters = (end_k - start_k + BK - 1) / BK; fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false); + store_scales(tid); + +#ifdef MUL_MAT_ID + if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) { + coopmat sum; + sum = coopmat(0.0); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + return; + } + if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) { + coopmat sum; + sum = coopmat(0.0); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, 
pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + return; + } +#endif + coopmat sum; + sum = coopmat(0.0); [[dont_unroll]] for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { - store_scales(tid); - if (block_k + BK < end_k) { + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); } @@ -455,6 +591,9 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); } } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif // Convert from ACC_TYPE to D_TYPE coopmat mat_d; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp new file mode 100644 index 000000000..0ebfbd646 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp @@ -0,0 +1,556 @@ +void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) { +#if defined(DATA_A_F32) || defined(DATA_A_F16) +#if LOAD_VEC_A == 8 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]); + buf_a[buf_idx ] = aa[0].xy; + buf_a[buf_idx + 1] = aa[0].zw; + buf_a[buf_idx + 2] = aa[1].xy; + buf_a[buf_idx + 3] = aa[1].zw; +#elif LOAD_VEC_A == 4 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; +#else // LOAD_VEC_BATCH_A == 2 + const uint idx = pos_a + col * p.stride_a + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_m < p.M && block + row * 2 + 1 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], + data_a[idx + 1]); + } else if (idx_m < p.M && block + row * 2 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f); + } else { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +#elif defined(DATA_A_BF16) +#if LOAD_VEC_A == 4 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx])); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; +#else // LOAD_VEC_BATCH_A == 2 + const uint idx = pos_a + col * p.stride_a + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_m < p.M && block + row * 2 + 1 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), + TO_FLOAT_TYPE(data_a[idx + 1])); + } else if (idx_m < p.M && block + row * 2 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + } else { + 
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +#elif defined(DATA_A_Q4_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; + const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy); + buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); +#elif defined(DATA_A_Q4_1) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); + buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); + buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy); + buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); +#elif defined(DATA_A_Q5_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); +#elif defined(DATA_A_Q5_1) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint uint_qh = data_a_packed16[ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); +#elif defined(DATA_A_Q8_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; + const 
vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); +#elif defined(DATA_A_Q2_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const uint scales = data_a[ib].scales[scalesi]; + const vec2 d = vec2(data_a[ib].d); + + const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); +#elif defined(DATA_A_Q3_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); + const float dl = float(data_a[ib].d) * float(us - 32); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)), + dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); +#elif defined(DATA_A_Q4_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 
0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m), + fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +#elif defined(DATA_A_Q5_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), + fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 
16 : 0), m)); +#elif defined(DATA_A_Q6_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32), + dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ1_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; + + const float d = float(data_a[ib].d); + const uint qh = data_a[ib].qh[ib32]; + const uint qs = data_a[ib].qs[ib8]; + const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + + [[unroll]] for (int k = 0; k < 4; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + } +#elif defined(DATA_A_IQ1_M) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; + const uint ib16 = ib8 / 2; + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + const uint sc = scales[ib8 / 8]; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + + [[unroll]] for (int k = 0; k < 4; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + } +#elif defined(DATA_A_IQ2_XXS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[8 * ib32 + ib8]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[8*ib32 + 4], + data_a[ib].qs[8*ib32 + 5], + data_a[ib].qs[8*ib32 + 6], + data_a[ib].qs[8*ib32 + 7] + )); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); + const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); +#elif defined(DATA_A_IQ2_XS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 + + const float d = float(data_a[ib].d); + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uint qs = data_a[ib].qs[4 * ib32 + ib8]; + const uint sign7 = qs >> 9; + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? 
-grid1.w : grid1.w); +#elif defined(DATA_A_IQ2_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 + + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; + + const float d = float(data_a[ib].d); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); +#elif defined(DATA_A_IQ3_XXS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 + const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[is+0], + data_a[ib].qs[is+1], + data_a[ib].qs[is+2], + data_a[ib].qs[is+3] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); +#elif defined(DATA_A_IQ3_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 + const uint iqh = iqs / 8; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint qh = data_a[ib].qh[iqh]; + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); + const uint scale = data_a[ib].scales[iqs / 16]; + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? 
-v.w : v.w); +#elif defined(DATA_A_IQ4_XS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint iq = 16 * ib32 + 2 * (idx % 8); + + const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 8) >> 1; + u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float d = float(data_a[ib].d); + const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); +#elif defined(DATA_A_IQ4_NL) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + + buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF], + kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); + buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], + kvalues_iq4nl[vui >> 12]); +#elif defined(DATA_A_MXFP4) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = (idx & 0x07) * 2; + + const float d = e8m0_to_fp32(data_a[ib].e); + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); +#endif +} + +#if !defined(MUL_MAT_ID) +void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) { +#if LOAD_VEC_B == 8 + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; +#elif LOAD_VEC_B == 4 + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; +#if defined(DATA_B_BF16) + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); +#else + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; +#else // LOAD_VEC_BATCH_B == 2 + const uint idx = pos_b + col * p.stride_b + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_n < p.N && block + row * 2 + 1 < end_k) { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); + } else if (idx_n < p.N && block + row * 2 < end_k) { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + } else { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +} +#else +void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) { +#if LOAD_VEC_B == 8 + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const 
u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; +#elif LOAD_VEC_B == 4 + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; +#if defined(DATA_B_BF16) + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); +#else + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; +#else // LOAD_VEC_BATCH_B == 2 + const uint row_i = ic * BN + col; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (row_i < _ne1 && block + row * 2 + 1 < end_k) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); + } else if (row_i < _ne1 && block + row * 2 < end_k) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + } else { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 83de90eb7..f36add62a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; #if defined(A_TYPE_PACKED32) layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif -layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];}; +layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID @@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; #endif #define LOAD_VEC_A (4 * QUANT_R) -#define LOAD_VEC_B 4 +#define LOAD_VEC_B 16 #ifdef MUL_MAT_ID shared u16vec2 row_ids[4096]; @@ -270,15 +270,22 @@ void main() { const uint iqs = idx & 0x7; #else const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; + const uint iqs = loadr_b; #endif const uint buf_ib = loadc_b + l; if (iqs == 0) { - buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds); + buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); } - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs]; + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w; } barrier(); @@ -349,7 +356,7 @@ void main() { cache_b_qs[cc * (BK / 4) + idx_k]); } - sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]); + sums[sums_idx] += mul_q8_1(q_sum, 
cache_a_dm[cache_a_idx], cache_b_ds[cc], 1); } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp index 34e8db977..cdfb230f4 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp @@ -16,8 +16,8 @@ i32vec2 repack(uint ib, uint iqs) { (vui >> 4) & 0x0F0F0F0F); } -ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y)); +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); } #endif @@ -29,8 +29,8 @@ i32vec2 repack(uint ib, uint iqs) { (vui >> 4) & 0x0F0F0F0F); } -ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); } #endif @@ -50,8 +50,8 @@ i32vec2 repack(uint ib, uint iqs) { return i32vec2(v0, v1); } -ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y)); +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); } #endif @@ -69,8 +69,8 @@ i32vec2 repack(uint ib, uint iqs) { return i32vec2(v0, v1); } -ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); } #endif @@ -81,7 +81,7 @@ int32_t repack(uint ib, uint iqs) { data_a[ib].qs[iqs * 2 + 1])); } -ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(float(q_sum) * da * dsb.x); } #endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp new file mode 100644 index 000000000..854a2ad81 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -0,0 +1,111 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_nonuniform_qualifier : enable +#extension GL_EXT_control_flow_attributes : require +#if ADD_RMS +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#include "rte.comp" +#include "types.comp" +#include "utils.comp" + +layout (push_constant) uniform parameter2 +{ + // shape for dst + uint ne20; uint ne21; uint ne22; uint ne23; + + // strides for srcs+dst + uint nb[12][4]; + + uint rms_partials; +} p; + +// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498 +// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[]; +// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[]; +layout (binding = 0) buffer A {A_TYPE data_a[];} a[]; +layout (binding = 0) buffer D {D_TYPE data_d[];} d[]; + +layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[]; + +layout(constant_id = 0) const 
uint num_srcs = 2; + +uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0]; +} + +uint dst_idx(uint i00, uint i01, uint i02, uint i03) { + uint nb20 = p.nb[num_srcs][0]; + uint nb21 = p.nb[num_srcs][1]; + uint nb22 = p.nb[num_srcs][2]; + uint nb23 = p.nb[num_srcs][3]; + return i03*nb23 + i02*nb22 + i01*nb21 + i00*nb20; +} + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +#if ADD_RMS +// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant +shared FLOAT_TYPE sumsh[num_threads]; +#endif + +void main() { + uint idx = get_idx(); + uint orig_idx = idx; + + uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23; + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + FLOAT_TYPE sum_sq = 0; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23); + + FLOAT_TYPE sum = FLOAT_TYPE(0); + [[unroll]] for (uint s = 0; s < num_srcs; ++s) { + sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]); + } + sum_sq += sum*sum; + d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); + + idx += num_threads; + } + +#if ADD_RMS + if (p.rms_partials != 0) { + // reduce the sum within each subgroup, then across subgroups + const uint NumSubgroups = num_threads / gl_SubgroupSize; + sum_sq = subgroupAdd(sum_sq); + if (gl_SubgroupInvocationID == 0) { + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { + if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) { + sum_sq += sumsh[gl_SubgroupID + s]; + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + } + + if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { + partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; + } + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp new file mode 100644 index 000000000..6426dedee --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) buffer X {A_TYPE data_x[];}; +layout (binding = 1) readonly buffer G {A_TYPE data_grad[];}; +layout (binding = 2) readonly buffer P {float data_params[2];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float alpha = data_params[0]; + const float keep = 1.f - alpha * data_params[1]; + + data_x[i] = data_x[i] * keep - alpha * data_grad[i]; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp index 450b67fc5..0d81220c7 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -1,7 +1,25 @@ #version 450 #include "types.comp" -#include "generic_unary_head.comp" + +layout (push_constant) uniform parameter +{ + 
uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + + uint lp0; uint rp0; + uint lp1; uint rp1; + uint lp2; uint rp2; + uint lp3; uint rp3; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; @@ -19,10 +37,13 @@ void main() { const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; - const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00; const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; - const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 && + i1 >= p.lp1 && i1 < p.ne11 - p.rp1 && + i2 >= p.lp2 && i2 < p.ne12 - p.rp2 && + i3 >= p.lp3 && i3 < p.ne13 - p.rp3; data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp index e2e020fec..145c9fbdc 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -3,6 +3,15 @@ #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_16bit_storage : require +#ifdef USE_SUBGROUPS +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_clustered : require + +#define INVOCATION_ID gl_SubgroupInvocationID.x +#else +#define INVOCATION_ID gl_LocalInvocationID.x +#endif + layout (push_constant) uniform parameter { uint ne; @@ -14,13 +23,19 @@ layout(constant_id = 0) const uint GROUP_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {vec4 data_a[];}; +#ifndef QBLOCK_X4 layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];}; +#else +layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];}; +#endif +#ifndef USE_SUBGROUPS shared float shmem[GROUP_SIZE]; +#endif void quantize() { const uint wgid = gl_WorkGroupID.x; - const uint tid = gl_LocalInvocationID.x; + const uint tid = INVOCATION_ID; // Each thread handles a vec4, so 8 threads handle a block const uint blocks_per_group = GROUP_SIZE / 8; @@ -30,9 +45,19 @@ void quantize() { const uint ib = wgid * blocks_per_group + block_in_wg; const uint iqs = tid % 8; +#ifndef QBLOCK_X4 if (ib >= gl_NumWorkGroups.x * blocks_per_group) { return; } +#else + const uint ibx4_outer = ib / 4; + const uint ibx4_inner = ib % 4; + + const uint required_x4_blocks = (p.ne + 127) / 128; + if (ibx4_outer >= required_x4_blocks) { + return; + } +#endif const uint a_idx = ib * 8 + iqs; @@ -40,7 +65,9 @@ void quantize() { const vec4 abs_vals = abs(vals); // Find absolute max for each block - shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); + const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); +#ifndef USE_SUBGROUPS + shmem[tid] = thread_max; 
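// Without USE_SUBGROUPS, the block maximum comes from the shared-memory tree reduction
// below: the 8 threads of a block fold their values in steps s = 4, 2, 1 and the result
// ends up at shmem[block_in_wg * 8]. With USE_SUBGROUPS the same maximum is produced by a
// single subgroupClusteredMax over clusters of 8 invocations, so no barriers are needed.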
barrier(); [[unroll]] for (uint s = 4; s > 0; s >>= 1) { if (iqs < s) { @@ -50,14 +77,28 @@ void quantize() { } const float amax = shmem[block_in_wg * 8]; +#else + const float amax = subgroupClusteredMax(thread_max, 8); +#endif + const float d = amax / 127.0; const float d_inv = d != 0.0 ? 1.0 / d : 0.0; vals = round(vals * d_inv); + +#ifndef QBLOCK_X4 data_b[ib].qs[iqs] = pack32(i8vec4(round(vals))); +#else + data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals))); +#endif + +#ifndef USE_SUBGROUPS barrier(); +#endif // Calculate the sum for each block - shmem[tid] = vals.x + vals.y + vals.z + vals.w; + const float thread_sum = vals.x + vals.y + vals.z + vals.w; +#ifndef USE_SUBGROUPS + shmem[tid] = thread_sum; barrier(); [[unroll]] for (uint s = 4; s > 0; s >>= 1) { if (iqs < s) { @@ -65,10 +106,19 @@ void quantize() { } barrier(); } +#else + const float sum = subgroupClusteredAdd(thread_sum, 8); +#endif if (iqs == 0) { +#ifndef USE_SUBGROUPS const float sum = shmem[tid]; +#endif +#ifndef QBLOCK_X4 data_b[ib].ds = f16vec2(vec2(d, sum * d)); +#else + data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d)); +#endif } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index bdd7db2d6..41197e930 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -10,9 +10,9 @@ layout (constant_id = 1) const bool do_multiply = false; layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sum[BLOCK_SIZE]; +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; -void main() { +void rms_norm(uint num_iters) { const uint ncols = p.ne00; const uint nrows = gl_NumWorkGroups.x; const uint nchannels = gl_NumWorkGroups.y; @@ -30,38 +30,76 @@ void main() { uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); - sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]); - sum[tid] += xi * xi; + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + FLOAT_TYPE xi = FLOAT_TYPE(0); + if (col < ncols) { + xi = FLOAT_TYPE(data_a[a_offset + col]); + } + sum += xi * xi; } + sumsh[tid] = sum; // sum up partial sums and write back result barrier(); [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { if (tid < s) { - sum[tid] += sum[tid + s]; + sum += sumsh[tid + s]; + sumsh[tid] = sum; } barrier(); } + sum = sumsh[0]; - const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); if (do_multiply) { if (ncols > p.ne10) { - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); } } else { - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } 
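// ncols > p.ne10 means src1 has fewer columns than the row being normalized, so data_b is
// broadcast by indexing with fastmod(col, p.ne10); the else branch below handles the
// matching-width case and indexes data_b with col directly.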
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); } } } else { - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); } } } + +void main() { + // instantiate the rms_norm function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + rms_norm(num_blocks); + } else if (num_blocks > 16) { + rms_norm(32); + } else if (num_blocks > 8) { + rms_norm(16); + } else if (num_blocks > 4) { + rms_norm(8); + } else if (num_blocks == 4) { + rms_norm(4); + } else if (num_blocks == 3) { + rms_norm(3); + } else if (num_blocks == 2) { + rms_norm(2); + } else if (num_blocks == 1) { + rms_norm(1); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp new file mode 100644 index 000000000..ba4677c29 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp @@ -0,0 +1,65 @@ +#version 450 + +#include "generic_binary_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +#define BLOCK_SIZE 128 + +layout (constant_id = 1) const bool do_multiply = false; + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];}; + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void main() { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = 0; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + // The work is split across multiple workgroups in the x dimension. 
Each invocation + // processes one element + const uint tid = gl_GlobalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + uint32_t num_partials = p.param3; + for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) { + sum += partial_sums[i]; + } + sum = subgroupAdd(sum); + + uint col = tid; + if (col >= ncols) { + return; + } + + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + if (do_multiply) { + if (ncols > p.ne10) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp index 29bd77d7e..144ea58e6 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -20,6 +20,10 @@ void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + if (row >= p.KY) { + return; + } + FLOAT_TYPE scale = p.param1; // partial sums for thread in warp diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp new file mode 100644 index 000000000..4bc697b9b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp index 961e5ffa1..759204afa 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -1,9 +1,9 @@ #version 450 -#include "generic_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable + layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -11,16 +11,49 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (push_constant) uniform parameter +{ + uint n_cols; + uint ne01, ne02; + uint nb01, nb02, nb03; + uint nb11, nb12, nb13; + float weight; + uint misalign_offsets; + uint ne0_12mp, ne0_12L; + uint ne0_1mp, ne0_1L; +} p; + +uint get_aoffset() { return 
p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + + shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint col = gl_LocalInvocationID.x; + const float weight = p.weight; - tmp[col] = FLOAT_TYPE(0.0f); + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; - for (uint i = col; i < p.KX; i += BLOCK_SIZE) { - tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]); + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + tmp[col] = FLOAT_TYPE(0.0); + + for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) { + tmp[col] += FLOAT_TYPE(data_a[src_idx + i]); } barrier(); @@ -32,6 +65,6 @@ void main() { } if (col == 0) { - data_d[row] = D_TYPE(tmp[0]); + data_d[dst_idx] = D_TYPE(tmp[0] * weight); } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp index 79e065a93..ce8e09442 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -24,11 +24,12 @@ void main() { const uint j = gl_GlobalInvocationID.x; const uint d_offset = i * p.nb1; - if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) { - data_d[d_offset + p.dim] = 0.f; + const uint half_dim = p.dim / 2; + + if (p.dim % 2 != 0 && j == half_dim) { + data_d[d_offset + 2 * half_dim] = 0.f; } - const uint half_dim = p.dim / 2; if (j >= half_dim) { return; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index a36c33e26..2fa54ce51 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -11,12 +11,12 @@ #define QUANT_K 1 #define QUANT_R 1 -#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 -#define A_TYPE float -#elif LOAD_VEC_A == 4 +#if LOAD_VEC_A == 4 #define A_TYPE vec4 #elif LOAD_VEC_A == 8 #define A_TYPE mat2x4 +#else +#define A_TYPE float #endif #endif @@ -24,12 +24,12 @@ #define QUANT_K 1 #define QUANT_R 1 -#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 -#define A_TYPE float16_t -#elif LOAD_VEC_A == 4 +#if LOAD_VEC_A == 4 #define A_TYPE f16vec4 #elif LOAD_VEC_A == 8 #define A_TYPE f16mat2x4 +#else +#define A_TYPE float16_t #endif #endif @@ -37,12 +37,12 @@ #define QUANT_K 1 #define QUANT_R 1 -#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 -#define A_TYPE uint16_t -#elif LOAD_VEC_A == 4 +#if LOAD_VEC_A == 4 #define A_TYPE u16vec4 #elif LOAD_VEC_A == 8 #error unsupported +#else +#define A_TYPE uint16_t #endif #endif @@ -207,6 +207,18 @@ struct block_q8_1_packed32 int32_t qs[8]; }; +// 4 blocks in one to allow 16-byte/128-bit alignment and loads +struct block_q8_1_x4 +{ + f16vec2 ds[4]; + int32_t qs[32]; +}; +struct block_q8_1_x4_packed128 +{ + f16vec2 ds[4]; + ivec4 qs[8]; +}; + // K-quants #define QUANT_K_Q2_K 256 @@ -233,6 +245,7 @@ struct block_q2_K_packed32 #if 
defined(DATA_A_Q2_K) #define QUANT_K QUANT_K_Q2_K +#define QUANT_R 1 #define A_TYPE block_q2_K #define A_TYPE_PACKED16 block_q2_K_packed16 #define A_TYPE_PACKED32 block_q2_K_packed32 @@ -258,6 +271,7 @@ struct block_q3_K_packed16 #if defined(DATA_A_Q3_K) #define QUANT_K QUANT_K_Q3_K +#define QUANT_R 1 #define A_TYPE block_q3_K #define A_TYPE_PACKED16 block_q3_K_packed16 #endif @@ -292,6 +306,7 @@ struct block_q4_K_packed128 #if defined(DATA_A_Q4_K) #define QUANT_K QUANT_K_Q4_K +#define QUANT_R 1 #define A_TYPE block_q4_K #define A_TYPE_PACKED16 block_q4_K_packed16 #define A_TYPE_PACKED32 block_q4_K_packed32 @@ -322,6 +337,7 @@ struct block_q5_K_packed128 #if defined(DATA_A_Q5_K) #define QUANT_K QUANT_K_Q5_K +#define QUANT_R 1 #define A_TYPE block_q5_K #define A_TYPE_PACKED16 block_q5_K_packed16 #endif @@ -346,6 +362,7 @@ struct block_q6_K_packed16 #if defined(DATA_A_Q6_K) #define QUANT_K QUANT_K_Q6_K +#define QUANT_R 1 #define A_TYPE block_q6_K #define A_TYPE_PACKED16 block_q6_K_packed16 #endif @@ -1412,6 +1429,11 @@ float bf16_to_fp32(uint32_t u) return uintBitsToFloat(u << 16); } +vec4 bf16_to_fp32(uvec4 u) +{ + return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w)); +} + float e8m0_to_fp32(uint8_t x) { uint32_t bits; @@ -1425,4 +1447,19 @@ float e8m0_to_fp32(uint8_t x) { return uintBitsToFloat(bits); } +#if BDA + +#extension GL_EXT_buffer_reference : enable +#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable + +#define BDA_STORAGE_T uint64_t +#define BDA_OFFSET_T uint64_t + +#else + +#define BDA_STORAGE_T uvec2 +#define BDA_OFFSET_T uint + +#endif + #endif // !defined(GGML_TYPES_COMP) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp new file mode 100644 index 000000000..dc4a1e6d9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp @@ -0,0 +1,25 @@ +#ifndef UTILS_COMP +#define UTILS_COMP + +// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1 +uint fastmod(uint a, uint b) { + if ((b & (b-1)) == 0) { + return a & (b-1); + } + return a % b; +} + +uint fastdiv(uint a, uint b) { + return (a < b) ? 
0 : (a / b); +} + +void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03, uint ne00, uint ne01, uint ne02, uint ne03) { + i03 = fastdiv(idx, (ne02*ne01*ne00)); + const uint i03_offset = i03 * ne02*ne01*ne00; + i02 = fastdiv((idx - i03_offset), (ne01*ne00)); + const uint i02_offset = i02*ne01*ne00; + i01 = (idx - i03_offset - i02_offset) / ne00; + i00 = idx - i03_offset - i02_offset - i01*ne00; +} + +#endif // UTILS_COMP diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 4cd94c51e..84bb9df9a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -68,6 +68,12 @@ const std::vector type_names = { "bf16", }; +enum MatMulIdType { + NONE, + DEFAULT, + SUBGROUP, +}; + namespace { void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { #ifdef _WIN32 @@ -200,6 +206,22 @@ bool string_ends_with(const std::string& str, const std::string& suffix) { return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); } +bool is_quantized_type(const std::string& type_name) { + return type_name != "f32" && type_name != "f16" && type_name != "bf16"; +} + +bool is_legacy_quant(const std::string& type_name) { + return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0"; +} + +bool is_k_quant(const std::string& type_name) { + return string_ends_with(type_name, "_k"); +} + +bool is_iq_quant(const std::string& type_name) { + return string_starts_with(type_name, "iq"); +} + static const char path_separator = '/'; std::string join_paths(const std::string& path1, const std::string& path2) { @@ -223,7 +245,8 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 - std::string opt_level = coopmat ? "" : "-O"; + // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344 + std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O"; #ifdef _WIN32 std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; @@ -292,26 +315,32 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); } -void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) { +void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; - std::map base_dict = { - {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? 
"f16vec2" : "vec2"}, - }; + std::map base_dict; std::string shader_name = "matmul"; - if (matmul_id) { + if (matmul_id_type == MatMulIdType::DEFAULT) { base_dict["MUL_MAT_ID"] = "1"; shader_name = "matmul_id"; + } else if (matmul_id_type == MatMulIdType::SUBGROUP) { + base_dict["MUL_MAT_ID"] = "1"; + base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1"; + shader_name = "matmul_id_subgroup"; } if (fp16) { base_dict["FLOAT16"] = "1"; } - base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2"; + if (f16acc) { + base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; + } if (coopmat) { base_dict["COOPMAT"] = "1"; @@ -319,43 +348,96 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; - auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string { - if (t == "bf16") { - // scalar path promotes to float - if (!coopmat && !coopmat2) { - return "float"; + auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string { + switch (vec) { + case 1: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "float"; + } + return "bfloat16_t"; } - return "bfloat16_t"; + if (coopmat2 || fp16) { + return "float16_t"; + } + return "float"; + case 2: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "vec2"; + } + return "bf16vec2"; + } + if (coopmat2 || fp16) { + return "f16vec2"; + } + return "vec2"; + case 4: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "vec4"; + } + return "bf16vec4"; + } + if (coopmat2 || fp16) { + return "f16vec4"; + } + return "vec4"; + case 8: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "mat2x4"; + } + throw std::runtime_error("bf16 vec8 not supported"); + } + if (coopmat2 || fp16) { + return "f16mat2x4"; + } + return "mat2x4"; + default: + throw std::runtime_error("invalid vector size"); } - if (coopmat2 || fp16) { - return "float16_t"; - } - return "float"; + }; + + const std::map float_type_dict_f16 = { + {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")}, + {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")}, }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", 
source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { - std::string load_vec_a_unaligned = "1"; // For aligned matmul loads std::string load_vec_a = coopmat2 ? "1" : "4"; // scalar path promotes to float std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; + const std::map float_type_dict_bf16 = { + {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")}, + }; + // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader #if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) if (!(coopmat || coopmat2)) #endif { - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } } @@ -376,20 +458,27 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool // For aligned matmul loads std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? 
load_vec : load_vec_quant; + const std::map float_type_dict = { + {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)}, + {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)}, + }; + // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) { - string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) { + string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif } @@ -400,32 +489,38 @@ void process_shaders() { std::map base_dict = {{"FLOAT_TYPE", "float"}}; // matmul - for (const auto& matmul_id : {false, true}) { + for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, 
MatMulIdType::SUBGROUP}) { // No coopmats // fp32 - matmul_shaders(false, matmul_id, false, false, false); + matmul_shaders(false, matmul_id_type, false, false, false); // fp16, fp32acc and fp16acc - matmul_shaders(true, matmul_id, false, false, false); - matmul_shaders(true, matmul_id, false, false, true); + matmul_shaders(true, matmul_id_type, false, false, false); + matmul_shaders(true, matmul_id_type, false, false, true); + if (matmul_id_type != MatMulIdType::DEFAULT) { #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - // Coopmat, fp32acc and fp16acc - matmul_shaders(true, matmul_id, true, false, false); - matmul_shaders(true, matmul_id, true, false, true); + // Coopmat, fp32acc and fp16acc + matmul_shaders(true, matmul_id_type, true, false, false); + matmul_shaders(true, matmul_id_type, true, false, true); #endif #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - // Coopmat2, fp32acc and fp16acc - matmul_shaders(true, matmul_id, false, true, false); - matmul_shaders(true, matmul_id, false, true, true); + // Coopmat2, fp32acc and fp16acc + matmul_shaders(true, matmul_id_type, false, true, false); + matmul_shaders(true, matmul_id_type, false, true, true); #endif + } } // flash attention for (const auto& f16acc : {false, true}) { - std::string acctype = f16acc ? "float16_t" : "float"; - std::string acctypev4 = f16acc ? "f16vec4" : "vec4"; + std::map fa_base_dict = base_dict; + fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; + if (f16acc) { + fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; + } for (const auto& tname : type_names) { if (tname == "f32") { @@ -436,30 +531,30 @@ void process_shaders() { #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc); } else { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } #endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); } else if (tname == "q4_0" || tname == "q8_0") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", 
"float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); } #endif if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); } else if (tname == "q4_0" || tname == "q8_0") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); } } } @@ -472,23 +567,36 @@ void process_shaders() { string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + // mul mat vec with integer dot product +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (is_legacy_quant(tname)) { + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + } +#endif + // Dequant shaders if (tname 
!= "f16" && tname != "bf16") { string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); } - if (!string_ends_with(tname, "_k")) { - shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp"; + shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp"; - if (tname == "f16") { - string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); - } else { - string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); - } - string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); + if (tname == "f16") { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); + } else { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); } + string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); } string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); @@ -499,6 +607,7 @@ void process_shaders() { string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -508,10 +617,14 @@ void process_shaders() { string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, 
{"DATA_D_BF16", "1"}}); + string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -520,8 +633,10 @@ void process_shaders() { } for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { - string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); } auto get_type_str = [](bool f16) { @@ -534,13 +649,15 @@ void process_shaders() { s += std::string(dst_f16 ? "_f16" : "_f32"); return s; }; - for (std::string op : {"add", "sub", "mul", "div"}) { + for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) { for (auto src0_f16 : {false, true}) { for (auto src1_f16 : {false, true}) { for (auto dst_f16 : {false, true}) { for (auto rte : {false, true}) { + auto source = op == "add_rms" ? std::string("add") : op; auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); - string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + auto add_rms = op == "add_rms" ? "1" : "0"; + string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? 
"1" : "0"}, {"ADD_RMS" , add_rms}}); } } } @@ -553,7 +670,12 @@ void process_shaders() { string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); + string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}}); + + string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}}); + string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}}); string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -566,6 +688,8 @@ void process_shaders() { string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -580,6 +704,11 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + for (auto rte : {false, true}) { + std::string suffix = rte ? "_rte" : ""; + string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}}); + } string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -594,6 +723,10 @@ void process_shaders() { string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); for (auto rte : {false, true}) { std::string suffix = rte ? "_rte" : ""; @@ -642,9 +775,15 @@ void process_shaders() { string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); - string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); - string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + for (std::string dim_str : {"", "_3d"}) { + for (bool bda : {false, true}) { + std::string bda_str = bda ? 
"_bda" : ""; + std::string bda_def = bda ? "1" : "0"; + string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}})); + string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}})); + string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}})); + } + } string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -657,25 +796,41 @@ void process_shaders() { string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); - string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); - - string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); - string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); - + for (auto transpose : {false, true}) { + for (auto unroll : {false, true}) { + for (auto a_f16 : {false, true}) { + std::map defines = { + {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, + {"USE_COLLECTIVES", "1"}, {"UNROLL", unroll ? "[[unroll]]" : ""}, + }; + if (transpose) defines["TRANSPOSE"] = "1"; + std::string name = std::string(transpose ? "conv_transpose_2d": "conv2d") + + (a_f16 ? "_f16" : "") + "_f32"; + string_to_spv(name + (unroll ? 
"_unroll" : ""), "conv2d_mm.comp", defines); #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); - string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); + if (unroll) { + defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv2d_mm.comp", defines, true, false, true); + } #endif + } + } + } string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); + string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); + for (auto &c : compiles) { c.wait(); } @@ -732,7 +887,7 @@ void write_output_files() { } std::string suffixes[2] = {"_f32", "_f16"}; - for (const char *op : {"add", "sub", "mul", "div"}) { + for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) { fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op); fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op); std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = "; @@ -784,6 +939,27 @@ void write_output_files() { fputs(data.c_str(), src); fputs(len.c_str(), src); } + + std::vector btypes = {"f16", "f32"}; + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + btypes.push_back("q8_1"); +#endif + + for (const std::string& btype : btypes) { + for (const auto& tname : type_names) { + if (btype == "q8_1" && !is_legacy_quant(tname)) { + continue; + } + fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str()); + fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str()); + std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_data};\n"; + std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_len};\n"; + fputs(data.c_str(), src); + fputs(len.c_str(), src); + } + } + fclose(hdr); 
fclose(src); } From 8ad169403b6115e5052ee0129d349b54e781da42 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 19:25:34 +0200 Subject: [PATCH 140/172] update build windows script --- scripts/build_windows.ps1 | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 171acace2..36c95c0a0 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -169,7 +169,8 @@ function buildROCm() { -DCMAKE_C_COMPILER=clang ` -DCMAKE_CXX_COMPILER=clang++ ` -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` - -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" + -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` + --install-prefix $script:DIST_DIR if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} $env:HIPCXX="" $env:HIP_PLATFORM="" @@ -327,4 +328,4 @@ try { } finally { set-location $script:SRC_DIR $env:PKG_VERSION="" -} +} \ No newline at end of file From f8551bc631835b9d95b3b5ec9bf88fa557d01032 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 21:28:15 +0200 Subject: [PATCH 141/172] merge fixes --- discover/runner.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/discover/runner.go b/discover/runner.go index 7a659367c..3842392ec 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -167,6 +167,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev devices = append(devices[:i], devices[i+1:]...) needsDelete = append(needsDelete[:i], needsDelete[i+1:]...) i-- + } else if devices[i].Library == "ROCm" { if _, err := strconv.Atoi(devices[i].ID); err == nil { // Replace the numeric ID with the post-filtered IDs devices[i].FilteredID = devices[i].ID @@ -176,6 +177,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev } } + // Now filter out any overlap with different libraries (favor CUDA/HIP over others) for i := 0; i < len(devices); i++ { for j := i + 1; j < len(devices); j++ { // For this pass, we only drop exact duplicates @@ -268,7 +270,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev devCheck: for _, dev := range deviceIDs { for i := range devices { - if dev == devices[i].DeviceID { + if dev.ID == devices[i].ID && dev.Library == devices[i].Library { if !updated[i] { skip = false break devCheck @@ -289,7 +291,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev slog.Debug("existing runner discovery took", "duration", time.Since(start)) for _, u := range updatedDevices { for i := range devices { - if u.DeviceID == devices[i].DeviceID { + if u.Library == devices[i].Library && u.ID == devices[i].ID { updated[i] = true devices[i].FreeMemory = u.FreeMemory break @@ -313,7 +315,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev updatedDevices := bootstrapDevices(ctx, []string{LibOllamaPath, dir}, nil) for _, u := range updatedDevices { for i := range devices { - if u.DeviceID == devices[i].DeviceID { + if u.Library == devices[i].Library && u.ID == devices[i].ID { updated[i] = true devices[i].FreeMemory = u.FreeMemory break From 6bef63b0f982cf9e3755d41072d56ffa50f22c4a Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sat, 4 Oct 2025 21:45:06 +0200 Subject: [PATCH 142/172] fix format --- discover/runner.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/discover/runner.go 
b/discover/runner.go index 3842392ec..405d72964 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -167,7 +167,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev devices = append(devices[:i], devices[i+1:]...) needsDelete = append(needsDelete[:i], needsDelete[i+1:]...) i-- - } else if devices[i].Library == "ROCm" { + } else if devices[i].Library == "ROCm" { if _, err := strconv.Atoi(devices[i].ID); err == nil { // Replace the numeric ID with the post-filtered IDs devices[i].FilteredID = devices[i].ID @@ -177,7 +177,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev } } - // Now filter out any overlap with different libraries (favor CUDA/HIP over others) + // Now filter out any overlap with different libraries (favor CUDA/HIP over others) for i := 0; i < len(devices); i++ { for j := i + 1; j < len(devices); j++ { // For this pass, we only drop exact duplicates From 908b31814d48a8e2b3fba3bbc3f4da9813dc3836 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 5 Oct 2025 11:01:26 +0200 Subject: [PATCH 143/172] fixed vulkan casing --- discover/gpu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index 9102bd65b..df47337c9 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -142,7 +142,7 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { ids := []string{} for _, info := range gpuInfo { - if info.Library != "VULKAN" { + if info.Library != "Vulkan" { continue } ids = append(ids, info.ID) From d5a2462c8e27dfcebfb395524b77da14c04a046c Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 5 Oct 2025 16:20:10 +0200 Subject: [PATCH 144/172] handle igpu as gpu --- llama/llama.go | 4 +++- ml/backend/ggml/ggml.go | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/llama/llama.go b/llama/llama.go index e0804ebdd..6bbfa7e37 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -69,7 +69,9 @@ func EnumerateGPUs() []ml.DeviceID { for i := range C.ggml_backend_dev_count() { device := C.ggml_backend_dev_get(i) - if C.ggml_backend_dev_type(device) == C.GGML_BACKEND_DEVICE_TYPE_GPU { + switch C.ggml_backend_dev_type(device) { + case C.GGML_BACKEND_DEVICE_TYPE_GPU: + case C.GGML_BACKEND_DEVICE_TYPE_IGPU: var props C.struct_ggml_backend_dev_props C.ggml_backend_dev_get_props(device, &props) ids = append(ids, ml.DeviceID{ diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index dc71c8de4..315bacd20 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -58,6 +58,7 @@ var initDevices = sync.OnceFunc(func() { case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: accels = append(accels, d) case C.GGML_BACKEND_DEVICE_TYPE_GPU: + case C.GGML_BACKEND_DEVICE_TYPE_IGPU: gpus = append(gpus, d) } @@ -470,7 +471,9 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { // Mimic llama runner logs summarizing layers and memory gpuLayers := 0 for layer := range maps.Values(b.layers) { - if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU { + switch C.ggml_backend_dev_type(layer.d) { + case C.GGML_BACKEND_DEVICE_TYPE_GPU: + case C.GGML_BACKEND_DEVICE_TYPE_IGPU: gpuLayers++ } } @@ -480,6 +483,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { case C.GGML_BACKEND_DEVICE_TYPE_CPU: slog.Info("offloading output layer to CPU") case C.GGML_BACKEND_DEVICE_TYPE_GPU: + case C.GGML_BACKEND_DEVICE_TYPE_IGPU: slog.Info("offloading output layer 
to GPU") gpuLayers++ case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: From cafdb5c0d6d4735f1704fc57a9b14b84ca17aeeb Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 5 Oct 2025 16:46:55 +0200 Subject: [PATCH 145/172] improve case --- llama/llama.go | 4 ++-- ml/backend/ggml/ggml.go | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/llama/llama.go b/llama/llama.go index 6bbfa7e37..4af0fd117 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -70,8 +70,8 @@ func EnumerateGPUs() []ml.DeviceID { device := C.ggml_backend_dev_get(i) switch C.ggml_backend_dev_type(device) { - case C.GGML_BACKEND_DEVICE_TYPE_GPU: - case C.GGML_BACKEND_DEVICE_TYPE_IGPU: + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: var props C.struct_ggml_backend_dev_props C.ggml_backend_dev_get_props(device, &props) ids = append(ids, ml.DeviceID{ diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 315bacd20..07e55dd3c 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -57,8 +57,8 @@ var initDevices = sync.OnceFunc(func() { } case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: accels = append(accels, d) - case C.GGML_BACKEND_DEVICE_TYPE_GPU: - case C.GGML_BACKEND_DEVICE_TYPE_IGPU: + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: gpus = append(gpus, d) } @@ -472,8 +472,8 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { gpuLayers := 0 for layer := range maps.Values(b.layers) { switch C.ggml_backend_dev_type(layer.d) { - case C.GGML_BACKEND_DEVICE_TYPE_GPU: - case C.GGML_BACKEND_DEVICE_TYPE_IGPU: + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: gpuLayers++ } } @@ -482,8 +482,8 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { switch C.ggml_backend_dev_type(b.output) { case C.GGML_BACKEND_DEVICE_TYPE_CPU: slog.Info("offloading output layer to CPU") - case C.GGML_BACKEND_DEVICE_TYPE_GPU: - case C.GGML_BACKEND_DEVICE_TYPE_IGPU: + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: slog.Info("offloading output layer to GPU") gpuLayers++ case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: From 218e57974f44740a21a903a24c8ef0313d36ea1f Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 5 Oct 2025 17:04:12 +0200 Subject: [PATCH 146/172] print out unknown library --- discover/runner.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/discover/runner.go b/discover/runner.go index 405d72964..aea618fec 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -132,6 +132,8 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev envVar = "CUDA_VISIBLE_DEVICES" } else if devices[i].Library == "Vulkan" { envVar = "GGML_VK_VISIBLE_DEVICES" + } else { + slog.Error("Unknown Library:" + devices[i].Library) } extraEnvs := []string{ @@ -444,6 +446,9 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr } + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + // cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator))) pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) From 690461a12fd5e93295d174c97edefb2bc33285b1 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 5 Oct 2025 20:29:38 +0200 Subject: [PATCH 147/172] rturn Vulkan for vulkan library --- 
llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 1 + ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch index be83b6371..d52dff291 100644 --- a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch +++ b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -88,6 +88,7 @@ index 061cd078..adea7783 100644 ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); + ctx->id = ggml_backend_vk_get_device_id(i); ++ ctx->library = GGML_VK_NAME; devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index adea7783d..e9478c841 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -13002,6 +13002,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); ctx->id = ggml_backend_vk_get_device_id(i); + ctx->library = GGML_VK_NAME; devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, From 3f38cdb5901589f646ec95355bbe422ccf47f3e9 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 5 Oct 2025 20:38:07 +0200 Subject: [PATCH 148/172] Revert "rturn Vulkan for vulkan library" This reverts commit 690461a12fd5e93295d174c97edefb2bc33285b1. 
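The Vulkan device context struct used elsewhere in this series has no `library` member, which is
likely why this tag did not stick; a follow-up commit reports the library from the device's
get_props callback instead. An excerpt-style sketch of that shape, assuming the Ollama-specific
`library` field on ggml_backend_dev_props used throughout this series:

    // Sketch only: tag Vulkan devices via the reported properties, not the device context.
    static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev,
                                                 struct ggml_backend_dev_props * props) {
        ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *) dev->context;
        props->name        = ggml_backend_vk_device_get_name(dev);
        props->description = ggml_backend_vk_device_get_description(dev);
        props->id          = ggml_backend_vk_device_get_id(dev);
        props->library     = GGML_VK_NAME;  // library name the Go discovery code matches on
        props->type        = ggml_backend_vk_device_get_type(dev);
        props->device_id   = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
        ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
        // caps and remaining fields as in the existing implementation
    }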
--- llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 1 - ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch index d52dff291..be83b6371 100644 --- a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch +++ b/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -88,7 +88,6 @@ index 061cd078..adea7783 100644 ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); + ctx->id = ggml_backend_vk_get_device_id(i); -+ ctx->library = GGML_VK_NAME; devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e9478c841..adea7783d 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -13002,7 +13002,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); ctx->id = ggml_backend_vk_get_device_id(i); - ctx->library = GGML_VK_NAME; devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, From 66d1033610eb0b22bd0bca29a1a2f93c02f2e2e5 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 5 Oct 2025 20:41:05 +0200 Subject: [PATCH 149/172] fixed patch number --- ...-v0.11.5.patch => 0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename llama/patches/{0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch => 0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch} (100%) diff --git a/llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch similarity index 100% rename from llama/patches/0026-vulkan-get-GPU-ID-ollama-v0.11.5.patch rename to llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch From d02a08aa7caab64153b74966b7153ff277f6ca97 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 5 Oct 2025 20:55:28 +0200 Subject: [PATCH 150/172] return Library Name --- llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 1 + ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch index be83b6371..29a8cd8e6 100644 --- a/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch +++ b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -80,6 +80,7 @@ index 061cd078..adea7783 100644 props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); + props->id = ggml_backend_vk_device_get_id(dev); ++ props->library = GGML_VK_NAME; props->type = ggml_backend_vk_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? 
nullptr : ctx->pci_bus_id.c_str(); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index adea7783d..128e8e0a0 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12555,6 +12555,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); props->id = ggml_backend_vk_device_get_id(dev); + props->library = GGML_VK_NAME; props->type = ggml_backend_vk_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); From 37206cdf3241a0a85face18f9a3a2fbadcec11d2 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 5 Oct 2025 20:56:21 +0200 Subject: [PATCH 151/172] remvoe debug code --- discover/runner.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/discover/runner.go b/discover/runner.go index aea618fec..b920a08b5 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -446,8 +446,6 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr } - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout // cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator))) From fd648506c1aa514d1b7d664e1a88db37f56cfeef Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Sun, 5 Oct 2025 21:13:21 +0200 Subject: [PATCH 152/172] return integrated in vulkan backend --- llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 1 + ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch index 29a8cd8e6..d24129211 100644 --- a/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch +++ b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -81,6 +81,7 @@ index 061cd078..adea7783 100644 props->description = ggml_backend_vk_device_get_description(dev); + props->id = ggml_backend_vk_device_get_id(dev); + props->library = GGML_VK_NAME; ++ props->integrated = ctx->is_integrated_gpu; props->type = ggml_backend_vk_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 128e8e0a0..7891482ce 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12556,6 +12556,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->description = ggml_backend_vk_device_get_description(dev); props->id = ggml_backend_vk_device_get_id(dev); props->library = GGML_VK_NAME; + props->integrated = ctx->is_integrated_gpu; props->type = ggml_backend_vk_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? 
nullptr : ctx->pci_bus_id.c_str(); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); From e9828e6b11b3a5935c15ce15202fd9271edc2d29 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 09:53:51 +0200 Subject: [PATCH 153/172] Return pci Properties --- .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7891482ce..3f7d5342a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12502,6 +12502,18 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { return std::string(pci_bus_id); } +static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { + if (id.empty()) return false; + unsigned int d = 0, b = 0, dev = 0, func = 0; + // Expected format: dddd:bb:dd.f (all hex) + int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); + if (n < 4) return false; + if (domain) *domain = (int) d; + if (bus) *bus = (int) b; + if (device) *device = (int) dev; + return true; +} + ////////////////////////// struct ggml_backend_vk_device_context { @@ -12509,6 +12521,12 @@ struct ggml_backend_vk_device_context { std::string name; std::string description; bool is_integrated_gpu; + // PCI information (if available via VK_EXT_pci_bus_info) + // Numeric components for convenience/interop with higher layers + int pciBusID = 0; + int pciDeviceID = 0; + int pciDomainID = 0; + // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) std::string pci_bus_id; std::string id; }; @@ -12559,6 +12577,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->integrated = ctx->is_integrated_gpu; props->type = ggml_backend_vk_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? 
nullptr : ctx->pci_bus_id.c_str(); + props->pci_bus_id = ctx->pciBusID; + props->pci_device_id = ctx->pciDeviceID; + props->pci_domain_id = ctx->pciDomainID; ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { /* .async = */ false, @@ -13003,6 +13024,17 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); + // Parse numeric PCI components if available + int d = 0, b = 0, devn = 0; + if (ggml_backend_vk_parse_pci_bus_id(ctx->pci_bus_id, &d, &b, &devn)) { + ctx->pciDomainID = d; + ctx->pciBusID = b; + ctx->pciDeviceID = devn; + } else { + ctx->pciDomainID = 0; + ctx->pciBusID = 0; + ctx->pciDeviceID = 0; + } ctx->id = ggml_backend_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, From 2acedf1756d8095e3aa03d7191a72d57449c5451 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 10:01:28 +0200 Subject: [PATCH 154/172] update patch --- ...027-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 56 ++++++++++++++++--- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch index d24129211..94f32b822 100644 --- a/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch +++ b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -5,11 +5,11 @@ Subject: [PATCH] vulkan: get GPU ID (ollama v0.11.5) Signed-off-by: Xiaodong Ye --- - ggml/src/ggml-vulkan/ggml-vulkan.cpp | 37 ++++++++++++++++++++++++++++ - 1 file changed, 37 insertions(+) + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 71 ++++++++++++++++++++++++++++ + 1 file changed, 71 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 061cd078..adea7783 100644 +index 061cd0788..3f7d5342a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -11588,6 +11588,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ @@ -55,15 +55,41 @@ index 061cd078..adea7783 100644 void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { GGML_ASSERT(device < (int) vk_instance.device_indices.size()); GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); -@@ -12481,6 +12510,7 @@ struct ggml_backend_vk_device_context { +@@ -12473,6 +12502,18 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { + return std::string(pci_bus_id); + } + ++static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { ++ if (id.empty()) return false; ++ unsigned int d = 0, b = 0, dev = 0, func = 0; ++ // Expected format: dddd:bb:dd.f (all hex) ++ int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); ++ if (n < 4) return false; ++ if (domain) *domain = (int) d; ++ if (bus) *bus = (int) b; ++ if (device) *device = (int) dev; ++ return true; ++} ++ + ////////////////////////// + + struct ggml_backend_vk_device_context { +@@ -12480,7 +12521,14 @@ struct ggml_backend_vk_device_context { + std::string name; std::string description; bool is_integrated_gpu; ++ // PCI information (if available via VK_EXT_pci_bus_info) ++ // Numeric components for convenience/interop with higher layers ++ int pciBusID = 0; ++ int 
pciDeviceID = 0; ++ int pciDomainID = 0; ++ // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) std::string pci_bus_id; + std::string id; }; static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { -@@ -12493,6 +12523,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de +@@ -12493,6 +12541,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de return ctx->description.c_str(); } @@ -75,7 +101,7 @@ index 061cd078..adea7783 100644 static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; ggml_backend_vk_get_device_memory(ctx->device, free, total); -@@ -12519,6 +12554,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml +@@ -12519,8 +12572,14 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); @@ -84,11 +110,27 @@ index 061cd078..adea7783 100644 + props->integrated = ctx->is_integrated_gpu; props->type = ggml_backend_vk_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ++ props->pci_bus_id = ctx->pciBusID; ++ props->pci_device_id = ctx->pciDeviceID; ++ props->pci_domain_id = ctx->pciDomainID; ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); -@@ -12965,6 +13001,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + props->caps = { + /* .async = */ false, +@@ -12965,6 +13024,18 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); ++ // Parse numeric PCI components if available ++ int d = 0, b = 0, devn = 0; ++ if (ggml_backend_vk_parse_pci_bus_id(ctx->pci_bus_id, &d, &b, &devn)) { ++ ctx->pciDomainID = d; ++ ctx->pciBusID = b; ++ ctx->pciDeviceID = devn; ++ } else { ++ ctx->pciDomainID = 0; ++ ctx->pciBusID = 0; ++ ctx->pciDeviceID = 0; ++ } + ctx->id = ggml_backend_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, From 8e0624851f5ed7d9f74518f574dfb422e4dd4dc2 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 14:54:50 +0200 Subject: [PATCH 155/172] directly get pci proeprties without parsing --- .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 36 ++++++------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3f7d5342a..e47cbad91 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12464,7 +12464,7 @@ static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) { return props.properties.deviceType; } -static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { +static std::string ggml_backend_vk_get_device_pci_id(int device_idx, int *domain_pci, int *bus_pci, int *device_pci) { GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); vk::PhysicalDevice device = 
vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; @@ -12496,24 +12496,17 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { const uint32_t pci_device = pci_bus_info.pciDevice; const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning + // Safely convert uint32_t to int (PCI IDs are small, so this is safe) + if (domain_pci) *domain_pci = static_cast(pci_bus_info.pciDomain); + if (bus_pci) *bus_pci = static_cast(pci_bus_info.pciBus); + if (device_pci) *device_pci = static_cast(pci_bus_info.pciDevice); + char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); return std::string(pci_bus_id); } -static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { - if (id.empty()) return false; - unsigned int d = 0, b = 0, dev = 0, func = 0; - // Expected format: dddd:bb:dd.f (all hex) - int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); - if (n < 4) return false; - if (domain) *domain = (int) d; - if (bus) *bus = (int) b; - if (device) *device = (int) dev; - return true; -} - ////////////////////////// struct ggml_backend_vk_device_context { @@ -13023,18 +13016,11 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; - ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); - // Parse numeric PCI components if available - int d = 0, b = 0, devn = 0; - if (ggml_backend_vk_parse_pci_bus_id(ctx->pci_bus_id, &d, &b, &devn)) { - ctx->pciDomainID = d; - ctx->pciBusID = b; - ctx->pciDeviceID = devn; - } else { - ctx->pciDomainID = 0; - ctx->pciBusID = 0; - ctx->pciDeviceID = 0; - } + int domain = 0, bus = 0, device = 0; + ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i,&domain, &bus, &device); + ctx->pciDomainID = domain; + ctx->pciBusID = bus; + ctx->pciDeviceID = device; ctx->id = ggml_backend_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, From a97b0923c9d69c308be2c44413b0abf9c12055dd Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 14:55:56 +0200 Subject: [PATCH 156/172] workaround for filtering devices. 
Correct way is to have a LibraryPosition Parameter in the deviceInfo --- discover/runner.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/discover/runner.go b/discover/runner.go index b920a08b5..0e460152b 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -112,6 +112,9 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev needsDelete := make([]bool, len(devices)) supportedMu := sync.Mutex{} supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index + rocmPosition := 0 + cudaPosition := 0 + vulkanPosition := 0 for i := range devices { libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1] if devices[i].Library == "Metal" { @@ -122,23 +125,30 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev go func(i int) { defer wg.Done() var envVar string + libraryPosition := 0 if devices[i].Library == "ROCm" { + libraryPosition = rocmPosition + rocmPosition++ if runtime.GOOS != "linux" { envVar = "HIP_VISIBLE_DEVICES" } else { envVar = "ROCR_VISIBLE_DEVICES" } } else if devices[i].Library == "CUDA" { + libraryPosition = cudaPosition + cudaPosition++ envVar = "CUDA_VISIBLE_DEVICES" } else if devices[i].Library == "Vulkan" { + libraryPosition = vulkanPosition + vulkanPosition++ envVar = "GGML_VK_VISIBLE_DEVICES" } else { slog.Error("Unknown Library:" + devices[i].Library) } extraEnvs := []string{ - "GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs - envVar + "=" + devices[i].ID, // Filter to just this one GPU + "GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs + envVar + "=" + strconv.Itoa(libraryPosition), // Filter to just this one GPU } if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 { needsDelete[i] = true From f2454a33ed1f7c0b99355a25f460c9ba50e0afa9 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 15:00:24 +0200 Subject: [PATCH 157/172] Revert "directly get pci proeprties without parsing" This reverts commit 8e0624851f5ed7d9f74518f574dfb422e4dd4dc2. 
--- .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e47cbad91..3f7d5342a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12464,7 +12464,7 @@ static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) { return props.properties.deviceType; } -static std::string ggml_backend_vk_get_device_pci_id(int device_idx, int *domain_pci, int *bus_pci, int *device_pci) { +static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; @@ -12496,17 +12496,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx, int *domain const uint32_t pci_device = pci_bus_info.pciDevice; const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning - // Safely convert uint32_t to int (PCI IDs are small, so this is safe) - if (domain_pci) *domain_pci = static_cast(pci_bus_info.pciDomain); - if (bus_pci) *bus_pci = static_cast(pci_bus_info.pciBus); - if (device_pci) *device_pci = static_cast(pci_bus_info.pciDevice); - char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); return std::string(pci_bus_id); } +static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { + if (id.empty()) return false; + unsigned int d = 0, b = 0, dev = 0, func = 0; + // Expected format: dddd:bb:dd.f (all hex) + int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); + if (n < 4) return false; + if (domain) *domain = (int) d; + if (bus) *bus = (int) b; + if (device) *device = (int) dev; + return true; +} + ////////////////////////// struct ggml_backend_vk_device_context { @@ -13016,11 +13023,18 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; - int domain = 0, bus = 0, device = 0; - ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i,&domain, &bus, &device); - ctx->pciDomainID = domain; - ctx->pciBusID = bus; - ctx->pciDeviceID = device; + ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); + // Parse numeric PCI components if available + int d = 0, b = 0, devn = 0; + if (ggml_backend_vk_parse_pci_bus_id(ctx->pci_bus_id, &d, &b, &devn)) { + ctx->pciDomainID = d; + ctx->pciBusID = b; + ctx->pciDeviceID = devn; + } else { + ctx->pciDomainID = 0; + ctx->pciBusID = 0; + ctx->pciDeviceID = 0; + } ctx->id = ggml_backend_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, From 0b6b03f6718699dcef1e5a4f66ac9dc25aeadfe3 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 17:12:24 +0200 Subject: [PATCH 158/172] Set FilteredID for Environment Filtering --- discover/gpu.go | 6 +++++- discover/runner.go | 18 ++++++------------ discover/types.go | 2 +- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/discover/gpu.go b/discover/gpu.go index 
df47337c9..6239fd54b 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -145,7 +145,11 @@ func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { if info.Library != "Vulkan" { continue } - ids = append(ids, info.ID) + if info.filterID != "" { + ids = append(ids, info.filterID) + } else { + ids = append(ids, info.ID) + } } if len(ids) == 0 { return "" diff --git a/discover/runner.go b/discover/runner.go index 0e460152b..f860292c3 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -112,9 +112,6 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev needsDelete := make([]bool, len(devices)) supportedMu := sync.Mutex{} supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index - rocmPosition := 0 - cudaPosition := 0 - vulkanPosition := 0 for i := range devices { libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1] if devices[i].Library == "Metal" { @@ -125,30 +122,23 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev go func(i int) { defer wg.Done() var envVar string - libraryPosition := 0 if devices[i].Library == "ROCm" { - libraryPosition = rocmPosition - rocmPosition++ if runtime.GOOS != "linux" { envVar = "HIP_VISIBLE_DEVICES" } else { envVar = "ROCR_VISIBLE_DEVICES" } } else if devices[i].Library == "CUDA" { - libraryPosition = cudaPosition - cudaPosition++ envVar = "CUDA_VISIBLE_DEVICES" } else if devices[i].Library == "Vulkan" { - libraryPosition = vulkanPosition - vulkanPosition++ envVar = "GGML_VK_VISIBLE_DEVICES" } else { slog.Error("Unknown Library:" + devices[i].Library) } extraEnvs := []string{ - "GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs - envVar + "=" + strconv.Itoa(libraryPosition), // Filter to just this one GPU + "GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs + envVar + "=" + devices[i].FilteredID, // Filter to just this one GPU } if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 { needsDelete[i] = true @@ -508,6 +498,10 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s } } logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices) + // Enumerate returned devices starting at 0 and assign the index as FilteredID + for i := range devices { + devices[i].FilteredID = strconv.Itoa(i) + } return devices } diff --git a/discover/types.go b/discover/types.go index a294f26b7..79fb9e0b0 100644 --- a/discover/types.go +++ b/discover/types.go @@ -37,7 +37,7 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? 
UnreliableFreeMemory bool // GPU information - filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices + filterID string // AMD/Vulkan Workaround: The numeric ID of the device used to filter out other devices Name string `json:"name"` // user friendly name if available Compute string `json:"compute"` // Compute Capability or gfx From da65fb27225202714e366ae5f3c0a28656a94e97 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 19:26:16 +0200 Subject: [PATCH 159/172] ROCm Library is named ROCm --- discover/gpu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discover/gpu.go b/discover/gpu.go index 6239fd54b..55419f236 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -115,7 +115,7 @@ func (l GpuInfoList) GetVisibleDevicesEnv() []string { func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { ids := []string{} for _, info := range gpuInfo { - if info.Library != "HIP" { + if info.Library != "ROCm" { continue } // If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number From fc17120fc6c16d4ba0b3ea7064fb0eddda54ebbb Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 19:52:30 +0200 Subject: [PATCH 160/172] revert changes in patch --- ...027-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 70 ++++--------------- 1 file changed, 13 insertions(+), 57 deletions(-) diff --git a/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch index 94f32b822..997dd3860 100644 --- a/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch +++ b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -5,17 +5,17 @@ Subject: [PATCH] vulkan: get GPU ID (ollama v0.11.5) Signed-off-by: Xiaodong Ye --- - ggml/src/ggml-vulkan/ggml-vulkan.cpp | 71 ++++++++++++++++++++++++++++ - 1 file changed, 71 insertions(+) + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 37 ++++++++++++++++++++++++++++ + 1 file changed, 37 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 061cd0788..3f7d5342a 100644 +index 061cd078..adea7783 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -11588,6 +11588,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ snprintf(description, description_size, "%s", props.deviceName.data()); } - + +static std::string ggml_vk_get_device_id(int device) { + ggml_vk_instance_init(); + @@ -40,12 +40,12 @@ index 061cd0788..3f7d5342a 100644 +} + // backend interface - + #define UNUSED GGML_UNUSED @@ -12394,6 +12417,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size ggml_vk_get_device_description(dev_idx, description, description_size); } - + +std::string ggml_backend_vk_get_device_id(int device) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + int dev_idx = vk_instance.device_indices[device]; @@ -55,44 +55,18 @@ index 061cd0788..3f7d5342a 100644 void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { GGML_ASSERT(device < (int) vk_instance.device_indices.size()); GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); -@@ -12473,6 +12502,18 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { - return std::string(pci_bus_id); - } - -+static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { -+ if (id.empty()) return false; -+ 
unsigned int d = 0, b = 0, dev = 0, func = 0; -+ // Expected format: dddd:bb:dd.f (all hex) -+ int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); -+ if (n < 4) return false; -+ if (domain) *domain = (int) d; -+ if (bus) *bus = (int) b; -+ if (device) *device = (int) dev; -+ return true; -+} -+ - ////////////////////////// - - struct ggml_backend_vk_device_context { -@@ -12480,7 +12521,14 @@ struct ggml_backend_vk_device_context { - std::string name; +@@ -12481,6 +12510,7 @@ struct ggml_backend_vk_device_context { std::string description; bool is_integrated_gpu; -+ // PCI information (if available via VK_EXT_pci_bus_info) -+ // Numeric components for convenience/interop with higher layers -+ int pciBusID = 0; -+ int pciDeviceID = 0; -+ int pciDomainID = 0; -+ // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) std::string pci_bus_id; + std::string id; }; - + static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { -@@ -12493,6 +12541,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de +@@ -12493,6 +12523,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de return ctx->description.c_str(); } - + +static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->id.c_str(); @@ -101,36 +75,18 @@ index 061cd0788..3f7d5342a 100644 static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; ggml_backend_vk_get_device_memory(ctx->device, free, total); -@@ -12519,8 +12572,14 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml - +@@ -12519,6 +12554,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml + props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); + props->id = ggml_backend_vk_device_get_id(dev); -+ props->library = GGML_VK_NAME; -+ props->integrated = ctx->is_integrated_gpu; props->type = ggml_backend_vk_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? 
nullptr : ctx->pci_bus_id.c_str(); -+ props->pci_bus_id = ctx->pciBusID; -+ props->pci_device_id = ctx->pciDeviceID; -+ props->pci_domain_id = ctx->pciDomainID; ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); - props->caps = { - /* .async = */ false, -@@ -12965,6 +13024,18 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -12965,6 +13001,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); -+ // Parse numeric PCI components if available -+ int d = 0, b = 0, devn = 0; -+ if (ggml_backend_vk_parse_pci_bus_id(ctx->pci_bus_id, &d, &b, &devn)) { -+ ctx->pciDomainID = d; -+ ctx->pciBusID = b; -+ ctx->pciDeviceID = devn; -+ } else { -+ ctx->pciDomainID = 0; -+ ctx->pciBusID = 0; -+ ctx->pciDeviceID = 0; -+ } + ctx->id = ggml_backend_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, From 235613903713b9d4934ba3ebafca0e2ebc15aaff Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 20:06:42 +0200 Subject: [PATCH 161/172] Create 0028-vulkan-pci-and-memory.patch --- .../patches/0028-vulkan-pci-and-memory.patch | 221 ++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 llama/patches/0028-vulkan-pci-and-memory.patch diff --git a/llama/patches/0028-vulkan-pci-and-memory.patch b/llama/patches/0028-vulkan-pci-and-memory.patch new file mode 100644 index 000000000..56743006a --- /dev/null +++ b/llama/patches/0028-vulkan-pci-and-memory.patch @@ -0,0 +1,221 @@ +commit 0000000000000000000000000000000000000000 +Author: Daniel Hiltgen +Date: Fri Sep 5 08:25:03 2025 -0700 + + WIP - wire up Vulkan with the new engine based discovery + + Not a complete implementation - free VRAM is better, but not accurate on + windows + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index d73cdf17..3b0a0891 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -4123,7 +4123,6 @@ static void ggml_vk_instance_init() { + } + } else { + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); +- + // If no vulkan devices are found, return early + if (devices.empty()) { + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); +@@ -10821,14 +10820,90 @@ std::string ggml_backend_vk_get_device_id(int device) { + return ggml_vk_get_device_id(dev_idx); + } + +-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { +- GGML_ASSERT(device < (int) vk_instance.device_indices.size()); ++////////////////////////// ++ ++struct ggml_backend_vk_device_context { ++ size_t device; ++ std::string name; ++ std::string description; ++ std::string id; ++ std::string uuid; ++ int major; ++ int minor; ++ int driver_major; ++ int driver_minor; ++ int integrated; ++ int pci_bus_id; ++ int pci_device_id; ++ int pci_domain_id; ++}; ++ ++void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { ++ GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); + +- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; ++ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; + + 
vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); ++ vk::PhysicalDeviceProperties2 props2; ++ vkdev.getProperties2(&props2); + +- for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { ++ // Use vendor specific management libraries for best VRAM reporting if available ++ switch (props2.properties.vendorID) { ++ case VK_VENDOR_ID_AMD: ++ if (ggml_hip_mgmt_init() == 0) { ++ int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_hip_mgmt_release(); ++ return; ++ } ++ ggml_hip_mgmt_release(); ++ } ++ break; ++ case VK_VENDOR_ID_NVIDIA: ++ if (ggml_nvml_init() == 0) { ++ int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_nvml_release(); ++ return; ++ } ++ ggml_nvml_release(); ++ } ++ break; ++ } ++ // else fallback to memory budget if supported ++ ++ *total = 0; ++ *free = 0; ++ vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; ++ vk::PhysicalDeviceMemoryProperties2 memprops2; ++ memprops2.pNext = &mem_budget_props; ++ vkdev.getMemoryProperties2(&memprops2); ++ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { ++ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { ++ *total += memprops2.memoryProperties.memoryHeaps[i].size; ++ } else if (ctx->integrated) { ++ // Include shared memory on iGPUs ++ *total += memprops2.memoryProperties.memoryHeaps[i].size; ++ } ++ } ++ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { ++ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { ++ *free += mem_budget_props.heapBudget[i]; ++ } else if (ctx->integrated) { ++ *free += mem_budget_props.heapBudget[i]; ++ } ++ } ++ if (*total > 0 && *free > 0) { ++ return; ++ } else if (*total > 0) { ++ *free = *total; ++ return; ++ } ++ ++ // else just report the physical memory ++ for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { + if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total = heap.size; + *free = heap.size; +@@ -10837,14 +10912,6 @@ void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total + } + } + +-////////////////////////// +- +-struct ggml_backend_vk_device_context { +- size_t device; +- std::string name; +- std::string description; +- std::string id; +-}; + + static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; +@@ -10863,7 +10930,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { + + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; +- ggml_backend_vk_get_device_memory(ctx->device, free, total); ++ ggml_backend_vk_get_device_memory(ctx, free, total); + } + + static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { +@@ -10881,6 +10948,7 @@ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_d + return GGML_BACKEND_DEVICE_TYPE_GPU; + } + ++#define GGML_VULKAN_NAME "VULKAN" + static void 
ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_vk_device_get_name(dev); + props->description = ggml_backend_vk_device_get_description(dev); +@@ -10893,6 +10961,18 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; ++ ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ props->id = ctx->id.c_str(); ++ props->compute_major = ctx->major; ++ props->compute_minor = ctx->minor; ++ props->driver_major = ctx->driver_major; ++ props->driver_minor = ctx->driver_minor; ++ props->integrated = ctx->integrated; ++ props->pci_bus_id = ctx->pci_bus_id; ++ props->pci_device_id = ctx->pci_device_id; ++ props->pci_domain_id = ctx->pci_domain_id; ++ props->library = GGML_VULKAN_NAME; + } + + static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { +@@ -11296,6 +11376,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { ++ std::vector vk_devices = vk_instance.instance.enumeratePhysicalDevices(); ++ + for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { + ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; + char desc[256]; +@@ -11309,6 +11391,44 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + /* .reg = */ reg, + /* .context = */ ctx, + }); ++ ++ // Gather additional information about the device ++ int dev_idx = vk_instance.device_indices[i]; ++ vk::PhysicalDeviceProperties props1; ++ vk_devices[dev_idx].getProperties(&props1); ++ vk::PhysicalDeviceProperties2 props2; ++ vk::PhysicalDeviceIDProperties device_id_props; ++ vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props; ++ vk::PhysicalDeviceDriverProperties driver_props; ++ props2.pNext = &device_id_props; ++ device_id_props.pNext = &pci_bus_props; ++ pci_bus_props.pNext = &driver_props; ++ vk_devices[dev_idx].getProperties2(&props2); ++ std::ostringstream oss; ++ oss << std::hex << std::setfill('0'); ++ oss << "GPU-"; ++ int byteIdx = 0; ++ for (int i = 0; i < 16; ++i, ++byteIdx) { ++ oss << std::setw(2) << static_cast(device_id_props.deviceUUID[i]); ++ if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) { ++ oss << '-'; ++ } ++ } ++ ctx->uuid = oss.str(); ++ ctx->pci_bus_id = pci_bus_props.pciBus; ++ ctx->pci_device_id = pci_bus_props.pciDevice; ++ ctx->pci_domain_id = pci_bus_props.pciDomain; ++ ctx->id = std::to_string(i); ++ if (props1.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) { ++ ctx->integrated = 1; ++ } else { ++ ctx->integrated = 0; ++ } ++ ctx->major = 0; ++ ctx->minor = 0; ++ // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string ++ ctx->driver_major = 0; ++ ctx->driver_minor = 0; + } + initialized = true; + } From 06bb04873733c087dad8e7699aa06a25ed9d9023 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 20:16:27 +0200 Subject: [PATCH 162/172] vulkan memory patch --- llama/patches/0028-vulkan-pci-and-memory.patch | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/llama/patches/0028-vulkan-pci-and-memory.patch b/llama/patches/0028-vulkan-pci-and-memory.patch index 56743006a..a0b808e03 100644 --- a/llama/patches/0028-vulkan-pci-and-memory.patch +++ b/llama/patches/0028-vulkan-pci-and-memory.patch @@ -1,11 +1,14 @@ 
-commit 0000000000000000000000000000000000000000 +commit 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 Author: Daniel Hiltgen Date: Fri Sep 5 08:25:03 2025 -0700 +Subject: [PATCH] vulkan PCI and Memory - WIP - wire up Vulkan with the new engine based discovery - - Not a complete implementation - free VRAM is better, but not accurate on - windows +WIP - wire up Vulkan with the new engine based discovery +Not a complete implementation - free VRAM is better, but not accurate on +windows +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 37 ++++++++++++++++++++++++++++ + 1 file changed, 37 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d73cdf17..3b0a0891 100644 From 23137c1d41204faa20f1358da6850f937cbd17a9 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 20:17:20 +0200 Subject: [PATCH 163/172] casing fix --- llama/patches/0028-vulkan-pci-and-memory.patch | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama/patches/0028-vulkan-pci-and-memory.patch b/llama/patches/0028-vulkan-pci-and-memory.patch index a0b808e03..0665c9a8c 100644 --- a/llama/patches/0028-vulkan-pci-and-memory.patch +++ b/llama/patches/0028-vulkan-pci-and-memory.patch @@ -1,7 +1,7 @@ commit 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 Author: Daniel Hiltgen Date: Fri Sep 5 08:25:03 2025 -0700 -Subject: [PATCH] vulkan PCI and Memory +Subject: [PATCH] Vulkan PCI and Memory WIP - wire up Vulkan with the new engine based discovery Not a complete implementation - free VRAM is better, but not accurate on From 511bbad6cda59736b5e3af931cd1bd2cbb450ba1 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Mon, 6 Oct 2025 23:48:14 +0200 Subject: [PATCH 164/172] Add more pci properties --- .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 80 +++++++++++++------ 1 file changed, 56 insertions(+), 24 deletions(-) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3f7d5342a..7f2bdea80 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12521,14 +12521,17 @@ struct ggml_backend_vk_device_context { std::string name; std::string description; bool is_integrated_gpu; - // PCI information (if available via VK_EXT_pci_bus_info) - // Numeric components for convenience/interop with higher layers - int pciBusID = 0; - int pciDeviceID = 0; - int pciDomainID = 0; // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) - std::string pci_bus_id; + std::string pci_id; std::string id; + std::string uuid; + int major; + int minor; + int driver_major; + int driver_minor; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; }; static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { @@ -12573,13 +12576,8 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); props->id = ggml_backend_vk_device_get_id(dev); - props->library = GGML_VK_NAME; - props->integrated = ctx->is_integrated_gpu; props->type = ggml_backend_vk_device_get_type(dev); - props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); - props->pci_bus_id = ctx->pciBusID; - props->pci_device_id = ctx->pciDeviceID; - props->pci_domain_id = ctx->pciDomainID; + props->device_id = ctx->pci_id.empty() ? 
nullptr : ctx->pci_id.c_str(); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { /* .async = */ false, @@ -12587,6 +12585,16 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml /* .buffer_from_host_ptr = */ false, /* .events = */ false, }; + + props->compute_major = ctx->major; + props->compute_minor = ctx->minor; + props->driver_major = ctx->driver_major; + props->driver_minor = ctx->driver_minor; + props->integrated = ctx->is_integrated_gpu; + props->pci_bus_id = ctx->pci_bus_id; + props->pci_device_id = ctx->pci_device_id; + props->pci_domain_id = ctx->pci_domain_id; + props->library = GGML_VK_NAME; } static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { @@ -13015,6 +13023,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { + std::vector vk_devices = vk_instance.instance.enumeratePhysicalDevices(); + for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; char desc[256]; @@ -13023,24 +13033,46 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; - ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); - // Parse numeric PCI components if available - int d = 0, b = 0, devn = 0; - if (ggml_backend_vk_parse_pci_bus_id(ctx->pci_bus_id, &d, &b, &devn)) { - ctx->pciDomainID = d; - ctx->pciBusID = b; - ctx->pciDeviceID = devn; - } else { - ctx->pciDomainID = 0; - ctx->pciBusID = 0; - ctx->pciDeviceID = 0; - } + ctx->pci_id = ggml_backend_vk_get_device_pci_id(i); ctx->id = ggml_backend_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, /* .context = */ ctx, }); + + // Gather additional information about the device + int dev_idx = vk_instance.device_indices[i]; + vk::PhysicalDeviceProperties props1; + vk_devices[dev_idx].getProperties(&props1); + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceIDProperties device_id_props; + vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props; + vk::PhysicalDeviceDriverProperties driver_props; + props2.pNext = &device_id_props; + device_id_props.pNext = &pci_bus_props; + pci_bus_props.pNext = &driver_props; + vk_devices[dev_idx].getProperties2(&props2); + std::ostringstream oss; + oss << std::hex << std::setfill('0'); + oss << "GPU-"; + int byteIdx = 0; + for (int i = 0; i < 16; ++i, ++byteIdx) { + oss << std::setw(2) << static_cast(device_id_props.deviceUUID[i]); + if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) { + oss << '-'; + } + } + ctx->uuid = oss.str(); + ctx->pci_bus_id = pci_bus_props.pciBus; + ctx->pci_device_id = pci_bus_props.pciDevice; + ctx->pci_domain_id = pci_bus_props.pciDomain; + ctx->id = std::to_string(i); + ctx->major = 0; + ctx->minor = 0; + // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string + ctx->driver_major = 0; + ctx->driver_minor = 0; } initialized = true; } From 5a417dfe925b88ec7219206a7c4e03b3913e426b Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 7 Oct 2025 00:04:11 +0200 Subject: [PATCH 165/172] Added better memory management --- .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 123 
++++++++++++------ 1 file changed, 84 insertions(+), 39 deletions(-) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7f2bdea80..57dd19ce0 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12423,31 +12423,96 @@ std::string ggml_backend_vk_get_device_id(int device) { return ggml_vk_get_device_id(dev_idx); } -void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { - GGML_ASSERT(device < (int) vk_instance.device_indices.size()); - GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); +////////////////////////// - vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; - vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; - vk::PhysicalDeviceMemoryProperties2 memprops = {}; - bool membudget_supported = vk_instance.device_supports_membudget[device]; +struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; + bool is_integrated_gpu; + // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) + std::string pci_id; + std::string id; + std::string uuid; + int major; + int minor; + int driver_major; + int driver_minor; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; +}; - if (membudget_supported) { - memprops.pNext = &budgetprops; +void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { + GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); + GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size()); + + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; + + vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + vk::PhysicalDeviceProperties2 props2; + vkdev.getProperties2(&props2); + + // Use vendor specific management libraries for best VRAM reporting if available + switch (props2.properties.vendorID) { + case VK_VENDOR_ID_AMD: + if (ggml_hip_mgmt_init() == 0) { + int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_hip_mgmt_release(); + return; + } + ggml_hip_mgmt_release(); + } + break; + case VK_VENDOR_ID_NVIDIA: + if (ggml_nvml_init() == 0) { + int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_nvml_release(); + return; + } + ggml_nvml_release(); + } + break; } - vkdev.getMemoryProperties2(&memprops); + // else fallback to memory budget if supported - for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) { - const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i]; + *total = 0; + *free = 0; + vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; + vk::PhysicalDeviceMemoryProperties2 memprops2; + memprops2.pNext = &mem_budget_props; + vkdev.getMemoryProperties2(&memprops2); + for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { + if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total += memprops2.memoryProperties.memoryHeaps[i].size; + } else if 
(ctx->is_integrated_gpu) { + // Include shared memory on iGPUs + *total += memprops2.memoryProperties.memoryHeaps[i].size; + } + } + for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { + if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *free += mem_budget_props.heapBudget[i]; + } else if (ctx->is_integrated_gpu) { + *free += mem_budget_props.heapBudget[i]; + } + } + if (*total > 0 && *free > 0) { + return; + } else if (*total > 0) { + *free = *total; + return; + } + // else just report the physical memory + for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { *total = heap.size; - - if (membudget_supported && i < budgetprops.heapUsage.size()) { - *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i]; - } else { - *free = heap.size; - } + *free = heap.size; break; } } @@ -12514,26 +12579,6 @@ static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain return true; } -////////////////////////// - -struct ggml_backend_vk_device_context { - size_t device; - std::string name; - std::string description; - bool is_integrated_gpu; - // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) - std::string pci_id; - std::string id; - std::string uuid; - int major; - int minor; - int driver_major; - int driver_minor; - int pci_bus_id; - int pci_device_id; - int pci_domain_id; -}; - static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; return ctx->name.c_str(); @@ -12551,7 +12596,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; - ggml_backend_vk_get_device_memory(ctx->device, free, total); + ggml_backend_vk_get_device_memory(ctx, free, total); } static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { From 8575a66178d33ed5f17c06986726abb24393f603 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 7 Oct 2025 00:21:04 +0200 Subject: [PATCH 166/172] Added better memory managament --- .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 41 ++++++++++--------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 57dd19ce0..fb7204ce9 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12453,30 +12453,33 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size vk::PhysicalDeviceProperties2 props2; vkdev.getProperties2(&props2); - // Use vendor specific management libraries for best VRAM reporting if available - switch (props2.properties.vendorID) { - case VK_VENDOR_ID_AMD: - if (ggml_hip_mgmt_init() == 0) { - int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); - if (status == 0) { - GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); + if (!ctx->is_integrated_gpu) + { + // Use vendor specific management libraries for best VRAM reporting if available + switch (props2.properties.vendorID) { + case VK_VENDOR_ID_AMD: + if 
(ggml_hip_mgmt_init() == 0) { + int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_hip_mgmt_release(); + return; + } ggml_hip_mgmt_release(); - return; } - ggml_hip_mgmt_release(); - } - break; - case VK_VENDOR_ID_NVIDIA: - if (ggml_nvml_init() == 0) { - int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); - if (status == 0) { - GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); + break; + case VK_VENDOR_ID_NVIDIA: + if (ggml_nvml_init() == 0) { + int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_nvml_release(); + return; + } ggml_nvml_release(); - return; } - ggml_nvml_release(); + break; } - break; } // else fallback to memory budget if supported From 1e2a5188bcefc4f0f33290572fbd8c84e33d42d7 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 7 Oct 2025 00:34:52 +0200 Subject: [PATCH 167/172] fixed patch --- .../patches/0028-vulkan-pci-and-memory.patch | 178 ++++++++++-------- 1 file changed, 104 insertions(+), 74 deletions(-) diff --git a/llama/patches/0028-vulkan-pci-and-memory.patch b/llama/patches/0028-vulkan-pci-and-memory.patch index 0665c9a8c..a55594fca 100644 --- a/llama/patches/0028-vulkan-pci-and-memory.patch +++ b/llama/patches/0028-vulkan-pci-and-memory.patch @@ -3,44 +3,37 @@ Author: Daniel Hiltgen Date: Fri Sep 5 08:25:03 2025 -0700 Subject: [PATCH] Vulkan PCI and Memory -WIP - wire up Vulkan with the new engine based discovery -Not a complete implementation - free VRAM is better, but not accurate on -windows + --- - ggml/src/ggml-vulkan/ggml-vulkan.cpp | 37 ++++++++++++++++++++++++++++ - 1 file changed, 37 insertions(+) + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 176 ++++++++++++++++++++++----- + 1 file changed, 145 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index d73cdf17..3b0a0891 100644 +index adea7783..fb7204ce 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -4123,7 +4123,6 @@ static void ggml_vk_instance_init() { - } - } else { - std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); -- - // If no vulkan devices are found, return early - if (devices.empty()) { - GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); -@@ -10821,14 +10820,90 @@ std::string ggml_backend_vk_get_device_id(int device) { +@@ -12423,31 +12423,99 @@ std::string ggml_backend_vk_get_device_id(int device) { return ggml_vk_get_device_id(dev_idx); } - + -void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { - GGML_ASSERT(device < (int) vk_instance.device_indices.size()); +- GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); +////////////////////////// + +struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; ++ bool is_integrated_gpu; ++ // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) ++ std::string pci_id; + std::string id; + std::string uuid; + int major; + int minor; + int driver_major; + int driver_minor; -+ int integrated; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; @@ -48,42 +41,53 @@ index d73cdf17..3b0a0891 100644 + 
+void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { + GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); - -- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; ++ GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size()); ++ + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; - - vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + +- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; +- vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; +- vk::PhysicalDeviceMemoryProperties2 memprops = {}; +- bool membudget_supported = vk_instance.device_supports_membudget[device]; ++ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + vk::PhysicalDeviceProperties2 props2; + vkdev.getProperties2(&props2); - -- for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { -+ // Use vendor specific management libraries for best VRAM reporting if available -+ switch (props2.properties.vendorID) { -+ case VK_VENDOR_ID_AMD: -+ if (ggml_hip_mgmt_init() == 0) { -+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); -+ if (status == 0) { -+ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); + +- if (membudget_supported) { +- memprops.pNext = &budgetprops; ++ if (!ctx->is_integrated_gpu) ++ { ++ // Use vendor specific management libraries for best VRAM reporting if available ++ switch (props2.properties.vendorID) { ++ case VK_VENDOR_ID_AMD: ++ if (ggml_hip_mgmt_init() == 0) { ++ int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_hip_mgmt_release(); ++ return; ++ } + ggml_hip_mgmt_release(); -+ return; + } -+ ggml_hip_mgmt_release(); -+ } -+ break; -+ case VK_VENDOR_ID_NVIDIA: -+ if (ggml_nvml_init() == 0) { -+ int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); -+ if (status == 0) { -+ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ break; ++ case VK_VENDOR_ID_NVIDIA: ++ if (ggml_nvml_init() == 0) { ++ int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_nvml_release(); ++ return; ++ } + ggml_nvml_release(); -+ return; + } -+ ggml_nvml_release(); ++ break; + } -+ break; -+ } + } +- vkdev.getMemoryProperties2(&memprops); + // else fallback to memory budget if supported -+ + +- for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) { +- const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i]; + *total = 0; + *free = 0; + vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; @@ -93,7 +97,7 @@ index d73cdf17..3b0a0891 100644 + for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { + if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total += memprops2.memoryProperties.memoryHeaps[i].size; -+ } else if (ctx->integrated) { ++ } else if (ctx->is_integrated_gpu) { + // Include shared memory on iGPUs + 
*total += memprops2.memoryProperties.memoryHeaps[i].size; + } @@ -101,7 +105,7 @@ index d73cdf17..3b0a0891 100644 + for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { + if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *free += mem_budget_props.heapBudget[i]; -+ } else if (ctx->integrated) { ++ } else if (ctx->is_integrated_gpu) { + *free += mem_budget_props.heapBudget[i]; + } + } @@ -111,64 +115,85 @@ index d73cdf17..3b0a0891 100644 + *free = *total; + return; + } -+ + + // else just report the physical memory + for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { *total = heap.size; - *free = heap.size; -@@ -10837,14 +10912,6 @@ void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total +- +- if (membudget_supported && i < budgetprops.heapUsage.size()) { +- *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i]; +- } else { +- *free = heap.size; +- } ++ *free = heap.size; + break; + } } +@@ -12502,16 +12570,17 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { + return std::string(pci_bus_id); } - + -////////////////////////// - -struct ggml_backend_vk_device_context { - size_t device; - std::string name; - std::string description; +- bool is_integrated_gpu; +- std::string pci_bus_id; - std::string id; -}; - ++static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { ++ if (id.empty()) return false; ++ unsigned int d = 0, b = 0, dev = 0, func = 0; ++ // Expected format: dddd:bb:dd.f (all hex) ++ int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); ++ if (n < 4) return false; ++ if (domain) *domain = (int) d; ++ if (bus) *bus = (int) b; ++ if (device) *device = (int) dev; ++ return true; ++} + static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -@@ -10863,7 +10930,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { - +@@ -12530,7 +12599,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; - ggml_backend_vk_get_device_memory(ctx->device, free, total); + ggml_backend_vk_get_device_memory(ctx, free, total); } - + static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { -@@ -10881,6 +10948,7 @@ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_d - return GGML_BACKEND_DEVICE_TYPE_GPU; - } - -+#define GGML_VULKAN_NAME "VULKAN" - static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { - props->name = ggml_backend_vk_device_get_name(dev); +@@ -12556,7 +12625,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->description = ggml_backend_vk_device_get_description(dev); -@@ -10893,6 +10961,18 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml + props->id = ggml_backend_vk_device_get_id(dev); + props->type = ggml_backend_vk_device_get_type(dev); +- props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ++ props->device_id = ctx->pci_id.empty() ? 
nullptr : ctx->pci_id.c_str(); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, +@@ -12564,6 +12633,16 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml /* .buffer_from_host_ptr = */ false, /* .events = */ false, }; + -+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -+ props->id = ctx->id.c_str(); + props->compute_major = ctx->major; + props->compute_minor = ctx->minor; + props->driver_major = ctx->driver_major; + props->driver_minor = ctx->driver_minor; -+ props->integrated = ctx->integrated; ++ props->integrated = ctx->is_integrated_gpu; + props->pci_bus_id = ctx->pci_bus_id; + props->pci_device_id = ctx->pci_device_id; + props->pci_domain_id = ctx->pci_domain_id; -+ props->library = GGML_VULKAN_NAME; ++ props->library = GGML_VK_NAME; } - + static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { -@@ -11296,6 +11376,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -12992,6 +13071,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { @@ -177,7 +202,15 @@ index d73cdf17..3b0a0891 100644 for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; char desc[256]; -@@ -11309,6 +11391,44 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -13000,13 +13081,46 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + ctx->name = GGML_VK_NAME + std::to_string(i); + ctx->description = desc; + ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; +- ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); ++ ctx->pci_id = ggml_backend_vk_get_device_pci_id(i); + ctx->id = ggml_backend_vk_get_device_id(i); + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, /* .context = */ ctx, }); @@ -209,11 +242,6 @@ index d73cdf17..3b0a0891 100644 + ctx->pci_device_id = pci_bus_props.pciDevice; + ctx->pci_domain_id = pci_bus_props.pciDomain; + ctx->id = std::to_string(i); -+ if (props1.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) { -+ ctx->integrated = 1; -+ } else { -+ ctx->integrated = 0; -+ } + ctx->major = 0; + ctx->minor = 0; + // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string @@ -222,3 +250,5 @@ index d73cdf17..3b0a0891 100644 } initialized = true; } +-- +2.51.0 \ No newline at end of file From 85ce59ae7cfb71aa6571eb2455424344f7363951 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 7 Oct 2025 00:41:02 +0200 Subject: [PATCH 168/172] Fixed patch --- .../patches/0028-vulkan-pci-and-memory.patch | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/llama/patches/0028-vulkan-pci-and-memory.patch b/llama/patches/0028-vulkan-pci-and-memory.patch index a55594fca..337eb847b 100644 --- a/llama/patches/0028-vulkan-pci-and-memory.patch +++ b/llama/patches/0028-vulkan-pci-and-memory.patch @@ -1,9 +1,8 @@ -commit 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -Author: Daniel Hiltgen +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen Date: Fri Sep 5 08:25:03 2025 -0700 Subject: [PATCH] Vulkan PCI and Memory - 
--- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 176 ++++++++++++++++++++++----- 1 file changed, 145 insertions(+), 31 deletions(-) @@ -15,7 +14,7 @@ index adea7783..fb7204ce 100644 @@ -12423,31 +12423,99 @@ std::string ggml_backend_vk_get_device_id(int device) { return ggml_vk_get_device_id(dev_idx); } - + -void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { - GGML_ASSERT(device < (int) vk_instance.device_indices.size()); - GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); @@ -44,7 +43,7 @@ index adea7783..fb7204ce 100644 + GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size()); + + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; - + - vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; - vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; - vk::PhysicalDeviceMemoryProperties2 memprops = {}; @@ -52,7 +51,7 @@ index adea7783..fb7204ce 100644 + vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + vk::PhysicalDeviceProperties2 props2; + vkdev.getProperties2(&props2); - + - if (membudget_supported) { - memprops.pNext = &budgetprops; + if (!ctx->is_integrated_gpu) @@ -85,7 +84,7 @@ index adea7783..fb7204ce 100644 } - vkdev.getMemoryProperties2(&memprops); + // else fallback to memory budget if supported - + - for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) { - const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i]; + *total = 0; @@ -115,7 +114,7 @@ index adea7783..fb7204ce 100644 + *free = *total; + return; + } - + + // else just report the physical memory + for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { @@ -133,7 +132,7 @@ index adea7783..fb7204ce 100644 @@ -12502,16 +12570,17 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { return std::string(pci_bus_id); } - + -////////////////////////// - -struct ggml_backend_vk_device_context { @@ -155,17 +154,17 @@ index adea7783..fb7204ce 100644 + if (device) *device = (int) dev; + return true; +} - + static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; @@ -12530,7 +12599,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { - + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; - ggml_backend_vk_get_device_memory(ctx->device, free, total); + ggml_backend_vk_get_device_memory(ctx, free, total); } - + static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { @@ -12556,7 +12625,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->description = ggml_backend_vk_device_get_description(dev); @@ -191,7 +190,7 @@ index adea7783..fb7204ce 100644 + props->pci_domain_id = ctx->pci_domain_id; + props->library = GGML_VK_NAME; } - + static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { @@ -12992,6 +13071,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, static std::mutex mutex; @@ -250,5 +249,5 @@ index adea7783..fb7204ce 100644 } initialized = true; } --- 
+-- 2.51.0 \ No newline at end of file From ad09cd8fbcf63483ef97f31a9077641a1f86d045 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 7 Oct 2025 01:57:30 +0200 Subject: [PATCH 169/172] FilterID creation group by library --- discover/runner.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/discover/runner.go b/discover/runner.go index f860292c3..02fbff1e0 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -498,9 +498,13 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s } } logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices) - // Enumerate returned devices starting at 0 and assign the index as FilteredID + + // Enumerate returned devices starting at 0 per library and assign the per-library index as FilteredID + libCounts := make(map[string]int) for i := range devices { - devices[i].FilteredID = strconv.Itoa(i) + lib := devices[i].Library + devices[i].FilteredID = strconv.Itoa(libCounts[lib]) + libCounts[lib]++ } return devices } From bb7c4329707ce72dc12f5c332008654b7756b070 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 7 Oct 2025 02:19:27 +0200 Subject: [PATCH 170/172] filter out vulkan supported by other gpu --- discover/runner.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/discover/runner.go b/discover/runner.go index 02fbff1e0..64ac39a3e 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -158,6 +158,8 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev wg.Wait() logutil.Trace("supported GPU library combinations", "supported", supported) + filterOutVulkanThatAreSupportedByOtherGPU(needsDelete) + // Mark for deletion any overlaps - favoring the library version that can cover all GPUs if possible filterOverlapByLibrary(supported, needsDelete) @@ -341,6 +343,37 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev return devices } +func filterOutVulkanThatAreSupportedByOtherGPU(needsDelete []bool) { + // Filter out Vulkan devices that share a PCI ID with a non-Vulkan device that is not marked for deletion + for i := range devices { + if devices[i].Library != "Vulkan" || needsDelete[i] { + continue + } + if devices[i].PCIID == "" { + continue + } + for j := range devices { + if i == j { + continue + } + if devices[j].PCIID == "" { + continue + } + if devices[j].PCIID == devices[i].PCIID && devices[j].Library != "Vulkan" && !needsDelete[j] { + needsDelete[i] = true + slog.Debug("dropping Vulkan duplicate by PCI ID", + "vulkan_id", devices[i].ID, + "vulkan_libdir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1], + "pci_id", devices[i].PCIID, + "kept_library", devices[j].Library, + "kept_id", devices[j].ID, + ) + break + } + } + } +} + func filterOverlapByLibrary(supported map[string]map[string]map[string]int, needsDelete []bool) { // For multi-GPU systems, use the newest version that supports all the GPUs for _, byLibDirs := range supported { From 8bb01725554bc3c212efe03a07e55ef410e23fe5 Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 7 Oct 2025 02:36:02 +0200 Subject: [PATCH 171/172] fixing deviceid compare --- discover/runner.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/discover/runner.go b/discover/runner.go index 64ac39a3e..1379446a4 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -274,7 +274,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev devCheck: for _, 
dev := range deviceIDs { for i := range devices { - if dev.ID == devices[i].ID && dev.Library == devices[i].Library { + if dev == devices[i].DeviceID { if !updated[i] { skip = false break devCheck @@ -295,7 +295,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev slog.Debug("existing runner discovery took", "duration", time.Since(start)) for _, u := range updatedDevices { for i := range devices { - if u.Library == devices[i].Library && u.ID == devices[i].ID { + if u.DeviceID == devices[i].DeviceID { updated[i] = true devices[i].FreeMemory = u.FreeMemory break @@ -319,7 +319,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev updatedDevices := bootstrapDevices(ctx, []string{LibOllamaPath, dir}, nil) for _, u := range updatedDevices { for i := range devices { - if u.Library == devices[i].Library && u.ID == devices[i].ID { + if u.DeviceID == devices[i].DeviceID { updated[i] = true devices[i].FreeMemory = u.FreeMemory break From c14680095cee596bb514dc797b31c7c56c4f71bf Mon Sep 17 00:00:00 2001 From: Inforithmics Date: Tue, 7 Oct 2025 12:01:38 +0200 Subject: [PATCH 172/172] Vulkan Fix FA coopmat1 invalid array indexing --- ...x-FA-coopmat1-invalid-array-indexing.patch | 30 +++++++++++++++++++ .../vulkan-shaders/flash_attn_cm1.comp | 4 +-- 2 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 llama/patches/0029-vulkan-Fix-FA-coopmat1-invalid-array-indexing.patch diff --git a/llama/patches/0029-vulkan-Fix-FA-coopmat1-invalid-array-indexing.patch b/llama/patches/0029-vulkan-Fix-FA-coopmat1-invalid-array-indexing.patch new file mode 100644 index 000000000..af4085440 --- /dev/null +++ b/llama/patches/0029-vulkan-Fix-FA-coopmat1-invalid-array-indexing.patch @@ -0,0 +1,30 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jeff Bolz +Date: Fri, 3 Oct 2025 04:52:46 -0500 +Subject: [PATCH] vulkan: Fix FA coopmat1 invalid array indexing (#16365) + +When computing sinks, the cm1 shader was looping r from 0 to Br rather than +to rows_per_thread. I must have copied this from the scalar path (where it is +correct), and somehow it wasn't causing failures on current drivers. 
+--- + ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp | 4 ++-- + 1 file changed, 2 insertions(+), 2 deletions(-) + +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +index e76dbb4de..0507df2d8 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +@@ -358,8 +358,8 @@ void main() { + } + + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { +- [[unroll]] for (uint32_t r = 0; r < Br; ++r) { +- float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); ++ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { ++ float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); + + float ms = 1.0f; + float vs = 1.0f; +-- +2.51.0.windows.1 + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index ddb1246e0..ef1ce0503 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -356,8 +356,8 @@ void main() { } if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); float ms = 1.0f; float vs = 1.0f;