kvcache: Don't shift empty batches

When we context shift, we delete half the context and apply RoPE
with an offset to the other half. We used to RoPE across the entire
context in a single pass with a zero offset for the deleted
section. With the change to shifting in batches, we can skip any
batches where all of the offsets would be zero. This typically
reduces the number of operations by half.
This commit is contained in:
Jesse Gross 2025-07-28 11:29:25 -07:00 committed by Jesse Gross
parent 3515cc377c
commit c116a7523d
1 changed files with 17 additions and 4 deletions

View File

@ -646,18 +646,31 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
seqRange := c.cellRanges[seq]
for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
ctx := c.backend.NewContext()
size := min(seqRange.max-start+1, c.maxBatch)
offsets := make([]int32, size)
var batchFirst, batchLast int
batchFirst = -1
for i := range offsets {
cell := c.cells[start+i]
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
offsets[i] = offset
if batchFirst < 0 {
batchFirst = i
}
batchLast = i
}
}
if batchFirst < 0 {
continue
}
offsets = offsets[batchFirst : batchLast+1]
ctx := c.backend.NewContext()
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
for i, key := range c.keys {
@ -669,10 +682,10 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
key = key.View(ctx, rowSize*start,
key = key.View(ctx, rowSize*(start+batchFirst),
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size,
len(offsets),
)
roped, err := c.shiftFn(ctx, i, key, kShift)