diff --git a/envconfig/config.go b/envconfig/config.go
index 09243ab95..dd25ed80d 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -145,8 +145,8 @@ func Remotes() []string {
 	return r
 }
 
-func Bool(k string) func() bool {
-	return func() bool {
+func BoolWithDefault(k string) func(defaultValue bool) bool {
+	return func(defaultValue bool) bool {
 		if s := Var(k); s != "" {
 			b, err := strconv.ParseBool(s)
 			if err != nil {
@@ -156,7 +156,14 @@ func Bool(k string) func() bool {
 			return b
 		}
 
-		return false
+		return defaultValue
+	}
+}
+
+func Bool(k string) func() bool {
+	withDefault := BoolWithDefault(k)
+	return func() bool {
+		return withDefault(false)
 	}
 }
 
@@ -177,7 +184,7 @@ func LogLevel() slog.Level {
 
 var (
 	// FlashAttention enables the experimental flash attention feature.
-	FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
+	FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
 	// KvCacheType is the quantization type for the K/V cache.
 	KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
 	// NoHistory disables readline history.
@@ -263,7 +270,7 @@ type EnvVar struct {
 func AsMap() map[string]EnvVar {
 	ret := map[string]EnvVar{
 		"OLLAMA_DEBUG":           {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
-		"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
+		"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
 		"OLLAMA_KV_CACHE_TYPE":   {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
 		"OLLAMA_GPU_OVERHEAD":    {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
 		"OLLAMA_HOST":            {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
diff --git a/llm/memory.go b/llm/memory.go
index 6f192b35d..4a54b3312 100644
--- a/llm/memory.go
+++ b/llm/memory.go
@@ -195,7 +195,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		slog.Warn("model missing blk.0 layer size")
 	}
 
-	useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
+	useFlashAttention := envconfig.FlashAttention(f.FlashAttention()) &&
 		(discover.GpuInfoList)(gpus).FlashAttentionSupported() &&
 		f.SupportsFlashAttention()
 
diff --git a/llm/server.go b/llm/server.go
index 63ad6085c..a7e8049b7 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -196,14 +196,10 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		loadRequest.ProjectorPath = projectors[0]
 	}
 
+	fa := envconfig.FlashAttention(f.FlashAttention())
+
 	// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
 	// that can handle it.
-	fa := envconfig.FlashAttention()
-	if f.FlashAttention() {
-		slog.Info("model wants flash attention")
-		fa = true
-	}
-
 	if fa && !gpus.FlashAttentionSupported() {
 		slog.Warn("flash attention enabled but not supported by gpu")
 		fa = false