2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//go:build linux || windows
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								package gpu
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								/*
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-14 09:26:47 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#cgo windows LDFLAGS: -lpthread
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "gpu_info.h"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								*/
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import "C"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import (
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									"fmt"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									"log"
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-24 03:35:44 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									"runtime"
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									"sync"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									"unsafe"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									"github.com/jmorganca/ollama/api"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								type handles struct {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									cuda *C.cuda_handle_t
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									rocm *C.rocm_handle_t
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								var gpuMutex sync.Mutex
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								var gpuHandles *handles = nil
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								// Note: gpuMutex must already be held
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								func initGPUHandles() {
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-14 09:26:47 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									log.Printf("Detecting GPU type")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									gpuHandles = &handles{nil, nil}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									var resp C.cuda_init_resp_t
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									C.cuda_init(&resp)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									if resp.err != nil {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										log.Printf("CUDA not detected: %s", C.GoString(resp.err))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										C.free(unsafe.Pointer(resp.err))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										var resp C.rocm_init_resp_t
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										C.rocm_init(&resp)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										if resp.err != nil {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											log.Printf("ROCm not detected: %s", C.GoString(resp.err))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											C.free(unsafe.Pointer(resp.err))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										} else {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											log.Printf("Radeon GPU detected")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											rocm := resp.rh
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											gpuHandles.rocm = &rocm
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									} else {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										log.Printf("Nvidia GPU detected")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										cuda := resp.ch
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										gpuHandles.cuda = &cuda
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								func GetGPUInfo() GpuInfo {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									// TODO - consider exploring lspci (and equivalent on windows) to check for
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									gpuMutex.Lock()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									defer gpuMutex.Unlock()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									if gpuHandles == nil {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										initGPUHandles()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									var memInfo C.mem_info_t
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-23 07:43:31 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									resp := GpuInfo{}
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									if gpuHandles.cuda != nil {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-14 09:26:47 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										if memInfo.err != nil {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											log.Printf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											C.free(unsafe.Pointer(memInfo.err))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										} else {
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-24 03:35:44 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
											resp.Library = "cuda"
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-14 09:26:47 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										}
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									} else if gpuHandles.rocm != nil {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-14 09:26:47 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										if memInfo.err != nil {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											log.Printf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											C.free(unsafe.Pointer(memInfo.err))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										} else {
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-24 03:35:44 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
											resp.Library = "rocm"
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-14 09:26:47 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-24 03:35:44 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									if resp.Library == "" {
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
										C.cpu_check_ram(&memInfo)
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-21 02:36:01 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										// In the future we may offer multiple CPU variants to tune CPU features
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-24 03:35:44 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										if runtime.GOOS == "windows" {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											resp.Library = "cpu"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										} else {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											resp.Library = "default"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										}
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									if memInfo.err != nil {
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-14 09:26:47 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										log.Printf("error looking up CPU memory: %s", C.GoString(memInfo.err))
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
										C.free(unsafe.Pointer(memInfo.err))
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-14 09:26:47 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										return resp
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									resp.FreeMemory = uint64(memInfo.free)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									resp.TotalMemory = uint64(memInfo.total)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									return resp
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-23 07:43:31 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								func getCPUMem() (memInfo, error) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									var ret memInfo
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									var info C.mem_info_t
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									C.cpu_check_ram(&info)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									if info.err != nil {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										defer C.free(unsafe.Pointer(info.err))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										return ret, fmt.Errorf(C.GoString(info.err))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									ret.FreeMemory = uint64(info.free)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									ret.TotalMemory = uint64(info.total)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									return ret, nil
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								func CheckVRAM() (int64, error) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									gpuInfo := GetGPUInfo()
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-24 03:35:44 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
										return int64(gpuInfo.FreeMemory), nil
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									if opts.NumGPU != -1 {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										return opts.NumGPU
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									info := GetGPUInfo()
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-24 03:35:44 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									if info.Library == "cpu" || info.Library == "default" {
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
										return 0
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									/*
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										Calculate bytes per layer, this will roughly be the size of the model file divided by the number of layers.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										We can store the model weights and the kv cache in vram,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										to enable kv chache vram storage add two additional layers to the number of layers retrieved from the model file.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									*/
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									bytesPerLayer := uint64(fileSizeBytes / numLayer)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									// 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									layers := int(info.FreeMemory/bytesPerLayer) * 3 / 4
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-24 03:35:44 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									log.Printf("%d MB VRAM available, loading up to %d %s GPU layers out of %d", info.FreeMemory/(1024*1024), layers, info.Library, numLayer)
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-30 03:00:37 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									return layers
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 |