| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | package kvcache | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2025-02-23 13:34:10 +08:00
										 |  |  | 	"fmt" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 	"github.com/ollama/ollama/ml" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
// EncoderCache stores K and V tensors that are position independent.
//
// The tensors can be of any shape and will be returned as they were stored.
// The mask is currently always nil.
//
// Not currently safe for multiple sequences.
type EncoderCache struct {
	// config controls mostly backend-specific optimizations
	config *ml.CacheConfig

	// ** current forward pass **

	// curLayer is the active layer for Get and Put.
	curLayer int

	// curPos is the position that will be recorded if something is
	// stored during this pass (but there is no guarantee anything
	// will be stored).
	curPos int32

	// ** cache metadata **

	// encoderCached reports whether something was stored in the cache.
	encoderCached bool

	// encoderPos is the position of the cached data.
	encoderPos int32

	// ** cache data storage **

	// backend creates the per-layer contexts; ctxs, keys and values
	// hold the lazily allocated storage, all keyed by layer index.
	backend      ml.Backend
	ctxs         map[int]ml.Context
	keys, values map[int]ml.Tensor
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func NewEncoderCache() *EncoderCache { | 
					
						
							| 
									
										
										
										
											2025-02-26 04:57:49 +08:00
										 |  |  | 	return &EncoderCache{ | 
					
						
							|  |  |  | 		ctxs:   make(map[int]ml.Context), | 
					
						
							|  |  |  | 		keys:   make(map[int]ml.Tensor), | 
					
						
							|  |  |  | 		values: make(map[int]ml.Tensor), | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { | 
					
						
							| 
									
										
										
										
											2025-02-23 13:34:10 +08:00
										 |  |  | 	if c.config == nil { | 
					
						
							|  |  |  | 		var config ml.CacheConfig | 
					
						
							|  |  |  | 		if cc, ok := backend.(ml.BackendCacheConfig); ok { | 
					
						
							|  |  |  | 			config = cc.CacheConfig() | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		c.config = &config | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if c.config.CachePadding != 0 && c.config.CachePadding != 1 { | 
					
						
							|  |  |  | 		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-26 04:57:49 +08:00
										 |  |  | 	c.backend = backend | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-23 13:34:10 +08:00
										 |  |  | func (c *EncoderCache) SetConfig(config ml.CacheConfig) { | 
					
						
							|  |  |  | 	if c.config != nil { | 
					
						
							|  |  |  | 		panic("config cannot be changed after being previously set, either by the model or backend") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	c.config = &config | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | func (c *EncoderCache) Close() { | 
					
						
							| 
									
										
										
										
											2025-02-26 04:57:49 +08:00
										 |  |  | 	for _, ctx := range c.ctxs { | 
					
						
							|  |  |  | 		ctx.Close() | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { | 
					
						
							|  |  |  | 	// The image is always in the first position
 | 
					
						
							|  |  |  | 	c.curPos = positions[0] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
// SetLayer selects the layer that subsequent Get and Put calls operate on.
func (c *EncoderCache) SetLayer(layer int) {
	c.curLayer = layer
}
					
						
							|  |  |  | 
 | 
					
						
// EncoderCached reports whether a previous forward pass stored encoder
// output in the cache (set by Put, cleared by Remove).
func (c *EncoderCache) EncoderCached() bool {
	return c.encoderCached
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { | 
					
						
							|  |  |  | 	return c.keys[c.curLayer], c.values[c.curLayer], nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { | 
					
						
							|  |  |  | 	c.encoderPos = c.curPos | 
					
						
							|  |  |  | 	c.encoderCached = true | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-23 13:34:10 +08:00
										 |  |  | 	if c.config.PermutedV { | 
					
						
							|  |  |  | 		value = value.Permute(ctx, 1, 2, 0, 3) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-26 04:57:49 +08:00
										 |  |  | 	if _, ok := c.ctxs[c.curLayer]; !ok { | 
					
						
							| 
									
										
										
										
											2025-02-26 08:06:32 +08:00
										 |  |  | 		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer) | 
					
						
							| 
									
										
										
										
											2025-02-26 04:57:49 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if _, ok := c.keys[c.curLayer]; !ok { | 
					
						
							|  |  |  | 		c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if _, ok := c.values[c.curLayer]; !ok { | 
					
						
							|  |  |  | 		c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...) | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-22 03:57:08 +08:00
										 |  |  | 	ctx.Forward( | 
					
						
							|  |  |  | 		key.Copy(ctx, c.keys[c.curLayer]), | 
					
						
							|  |  |  | 		value.Copy(ctx, c.values[c.curLayer]), | 
					
						
							|  |  |  | 	) | 
					
						
							| 
									
										
										
										
											2024-12-18 11:59:41 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) { | 
					
						
							|  |  |  | 	panic("encoder cache does not support multiple sequences") | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error { | 
					
						
							|  |  |  | 	if c.encoderPos >= beginIndex && c.encoderPos < endIndex { | 
					
						
							|  |  |  | 		c.encoderCached = false | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } |