mirror of https://github.com/ollama/ollama.git
kvcache: Don't shift empty batches
When we perform a context shift, we delete half the context and apply RoPE with an offset to the other half. We used to apply RoPE across the entire context in a single pass, with a zero offset for the deleted section. Now that shifting is done in batches, we can skip any batch in which all of the offsets would be zero. This typically reduces the number of operations by half.
This commit is contained in:
parent
3515cc377c
commit
c116a7523d
|
@ -646,18 +646,31 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
|||
seqRange := c.cellRanges[seq]
|
||||
|
||||
for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
|
||||
ctx := c.backend.NewContext()
|
||||
|
||||
size := min(seqRange.max-start+1, c.maxBatch)
|
||||
offsets := make([]int32, size)
|
||||
|
||||
var batchFirst, batchLast int
|
||||
|
||||
batchFirst = -1
|
||||
for i := range offsets {
|
||||
cell := c.cells[start+i]
|
||||
|
||||
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||
offsets[i] = offset
|
||||
if batchFirst < 0 {
|
||||
batchFirst = i
|
||||
}
|
||||
batchLast = i
|
||||
}
|
||||
}
|
||||
|
||||
if batchFirst < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
offsets = offsets[batchFirst : batchLast+1]
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||
|
||||
for i, key := range c.keys {
|
||||
|
@ -669,10 +682,10 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
|||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
key = key.View(ctx, rowSize*start,
|
||||
key = key.View(ctx, rowSize*(start+batchFirst),
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
size,
|
||||
len(offsets),
|
||||
)
|
||||
|
||||
roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||
|
|
Loading…
Reference in New Issue