ml: Panic rather than return error on tensor allocation failure

FromFloatSlice and FromIntSlice return an error if the shape doesn't
match the passed data or if memory can't be allocated. Since these
are inputs, the memory being allocated is system memory rather than VRAM.

In many cases, the caller can't really handle the error and panics.

Empty and Zeros directly panic if they can't allocate memory.

This makes things consistent by panicing for the first two cases,
removing a fair amount of error handling code. This is also consistent
with how Go typically handles these situations.
This commit is contained in:
Jesse Gross 2025-05-19 10:43:56 -07:00 committed by Jesse Gross
parent 73d6a82cce
commit 1f371ea92f
20 changed files with 68 additions and 209 deletions

View File

@ -211,10 +211,9 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curCellRange.max = len(c.cells) - 1 c.curCellRange.max = len(c.cells) - 1
} }
var err error c.curMask = c.buildMask(ctx)
c.curMask, err = c.buildMask(ctx)
return err return nil
} }
func newRange() cellRange { func newRange() cellRange {
@ -297,7 +296,7 @@ func roundUp(length, pad int) int {
// Builds a mask of history x batch indicating whether for each token in the batch the // Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the // token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch). // position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
// Align and pad the two dimensions as required by the backend // Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
@ -325,10 +324,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
mask[i] = float32(math.Inf(-1)) mask[i] = float32(math.Inf(-1))
} }
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize) maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 { if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...) out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
@ -336,7 +332,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
maskTensor = out maskTensor = out
} }
return maskTensor, nil return maskTensor
} }
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) { func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
@ -491,12 +487,7 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) { if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts c.opts = opts
if ctx != nil { if ctx != nil {
var err error c.curMask = c.buildMask(ctx)
c.curMask, err = c.buildMask(ctx)
if err != nil {
// This error should never occur because we have previously built a mask with the same shape
panic(fmt.Errorf("SetCausal: %w", err))
}
} }
} }
} }
@ -652,10 +643,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
} }
} }
kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets)) kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
if err != nil {
return err
}
for i, key := range c.keys { for i, key := range c.keys {
if key == nil { if key == nil {

View File

@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor, _ := context.FromFloatSlice(test.in, test.inShape...) tensor := context.FromFloatSlice(test.in, test.inShape...)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
out, _, mask := cache.Get(context) out, _, mask := cache.Get(context)
@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) {
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet // with window size 4, nothing has slid out of the window yet
@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) {
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows // only the latest position has overlapping windows
@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...) return c.Empty(dtype, shape...)
} }
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
t := c.Empty(ml.DTypeF32, shape...).(*testTensor) t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s) copy(t.data, s)
return t, nil return t
} }
func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor {
f := make([]float32, len(s)) f := make([]float32, len(s))
for i := range f { for i := range f {
f[i] = float32(s[i]) f[i] = float32(s[i])
} }
out, _ := c.FromFloatSlice(f, shape...) out := c.FromFloatSlice(f, shape...)
out.(*testTensor).dtype = ml.DTypeI32 out.(*testTensor).dtype = ml.DTypeI32
return out, nil return out
} }
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso
s = append(s, i) s = append(s, i)
} }
out, _ := c.FromFloatSlice(s, len(s)) out := c.FromFloatSlice(s, len(s))
out.(*testTensor).dtype = dtype out.(*testTensor).dtype = dtype
return out return out
} }

View File

@ -171,8 +171,8 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) {
type Context interface { type Context interface {
Empty(dtype DType, shape ...int) Tensor Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor Zeros(dtype DType, shape ...int) Tensor
FromFloatSlice(s []float32, shape ...int) (Tensor, error) FromFloatSlice(s []float32, shape ...int) Tensor
FromIntSlice(s []int32, shape ...int) (Tensor, error) FromIntSlice(s []int32, shape ...int) Tensor
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step. // Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
Arange(start, stop, step float32, dtype DType) Tensor Arange(start, stop, step float32, dtype DType) Tensor

View File

@ -729,11 +729,11 @@ func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return t return t
} }
func checkShape[S ~[]E, E any](s S, shape ...int) error { func checkShape[S ~[]E, E any](s S, shape ...int) {
n := len(s) n := len(s)
if n == 0 { if n == 0 {
return nil return
} }
for _, v := range shape { for _, v := range shape {
@ -741,16 +741,12 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
} }
if n != 1 { if n != 1 {
return fmt.Errorf("invalid shape: %v", shape) panic(fmt.Errorf("invalid shape: %v", shape))
} }
return nil
} }
func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
if err := checkShape(s, shape...); err != nil { checkShape(s, shape...)
return nil, err
}
t := c.newTensor(ml.DTypeF32, shape) t := c.newTensor(ml.DTypeF32, shape)
@ -758,13 +754,11 @@ func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
} }
return t, nil return t
} }
func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor {
if err := checkShape(s, shape...); err != nil { checkShape(s, shape...)
return nil, err
}
t := c.newTensor(ml.DTypeI32, shape) t := c.newTensor(ml.DTypeI32, shape)
@ -772,7 +766,7 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
} }
return t, nil return t
} }
func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
@ -790,12 +784,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
arange = append(arange, int32(i)) arange = append(arange, int32(i))
} }
t, err := c.Input().FromIntSlice(arange, len(arange)) return c.Input().FromIntSlice(arange, len(arange))
if err != nil {
panic(err)
}
return t
default: default:
panic("unsupported dtype for arange") panic("unsupported dtype for arange")
} }

View File

@ -287,11 +287,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
return nil, errors.New("batch size cannot be less than 1") return nil, errors.New("batch size cannot be less than 1")
} }
var err error batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
if err != nil {
return nil, err
}
cache := m.Config().Cache cache := m.Config().Cache
if cache != nil { if cache != nil {

View File

@ -175,15 +175,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))

View File

@ -101,14 +101,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err return nil, err
} }
pixelValues, err := ctx.Input().FromFloatSlice(f32s, pixelValues := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize, m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize, m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels, m.ImageProcessor.numChannels,
) )
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps) visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
@ -144,15 +141,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
} }

View File

@ -142,10 +142,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
@ -154,10 +151,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
} }
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)

View File

@ -77,10 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err return nil, err
} }
tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels) tilesLocal := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
if err != nil {
return nil, err
}
ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize
@ -91,11 +88,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
pixelValues := tilesLocal pixelValues := tilesLocal
if len(pixelsGlobal) > 0 { if len(pixelsGlobal) > 0 {
tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels) tilesGlobal := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
if err != nil {
return nil, err
}
pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3) pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3)
} }
@ -182,15 +175,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
} }

View File

@ -223,11 +223,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0) scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0)
} }
var err error attentionScales = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
if err != nil {
panic(err)
}
} }
for i, layer := range m.Layers { for i, layer := range m.Layers {

View File

@ -245,10 +245,7 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
} }
} }
ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2) ropeFreqs := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
if err != nil {
panic(err)
}
ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches) ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)

View File

@ -114,10 +114,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err return nil, err
} }
pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) pixelValues := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size) features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
@ -161,15 +158,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
} }

View File

@ -110,15 +110,8 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor)
} }
} }
h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2) h := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
if err != nil { w := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
panic(err)
}
w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
if err != nil {
panic(err)
}
h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
@ -151,10 +144,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
} }
} }
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions)) positionIDs := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
positionEmbedding := m.positionalEmbedding(ctx, positionIDs) positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx) cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)

View File

@ -80,15 +80,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles] f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles]
} }
pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles) pixelValues := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
if err != nil { aspectRatio := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
return nil, err
}
aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
if err != nil {
return nil, err
}
positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32) positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio) crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
@ -113,15 +106,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor
} }
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
// TODO: attention mask, cross attention mask // TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil

View File

@ -100,10 +100,7 @@ type Model struct {
// Forward implements model.Model. // Forward implements model.Model.
func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
@ -112,10 +109,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
} }
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)

View File

@ -69,10 +69,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) pixelValues := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
if err != nil {
return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
}
return pixelValues, grid, nil return pixelValues, grid, nil
} }
@ -142,15 +139,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
} }

View File

@ -1,7 +1,6 @@
package qwen25vl package qwen25vl
import ( import (
"fmt"
"math" "math"
"slices" "slices"
@ -44,10 +43,8 @@ func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int
} }
} }
mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength) mask := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
if err != nil {
panic(err)
}
// Reshape to match [seqLength, seqLength, 1] for broadcasting // Reshape to match [seqLength, seqLength, 1] for broadcasting
mask = mask.Reshape(ctx, seqLength, seqLength, 1) mask = mask.Reshape(ctx, seqLength, seqLength, 1)
@ -303,10 +300,7 @@ func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int)
} }
} }
t, err := ctx.Input().FromIntSlice(index, len(index)) t := ctx.Input().FromIntSlice(index, len(index))
if err != nil {
panic(err)
}
return t, bounds return t, bounds
} }
@ -326,10 +320,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim))) freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
} }
} }
freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize) freqs := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize)
if err != nil {
panic(fmt.Errorf("failed to create tensor from frequencies: %w", err))
}
// Create position coordinates (y,x pairs) for the grid // Create position coordinates (y,x pairs) for the grid
// In PyTorch: Equivalent to generating position ids with torch.arange() // In PyTorch: Equivalent to generating position ids with torch.arange()
@ -339,10 +330,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
coords = append(coords, int32(y), int32(x)) coords = append(coords, int32(y), int32(x))
} }
} }
pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height) pos := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
if err != nil {
panic(fmt.Errorf("failed to create tensor from positions: %w", err))
}
// Reshape and permute positions to match spatial merging pattern // Reshape and permute positions to match spatial merging pattern
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge) pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)

View File

@ -156,10 +156,7 @@ type Model struct {
// Forward implements model.Model. // Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
@ -168,10 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
} }
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)

View File

@ -102,7 +102,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
for i, t := range entry.mm { for i, t := range entry.mm {
if in == t.Tensor { if in == t.Tensor {
if !reserve { if !reserve {
return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...) return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil
} else { } else {
return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
} }

View File

@ -808,10 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error {
batch.Outputs[i] = int32(i) batch.Outputs[i] = int32(i)
} }
batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
if err != nil {
return err
}
cache := s.model.Config().Cache cache := s.model.Config().Cache
if cache != nil { if cache != nil {
@ -876,7 +873,8 @@ func (s *Server) load(
parallel int, parallel int,
kvCacheType string, kvCacheType string,
kvSize int, kvSize int,
multiUserCache bool) { multiUserCache bool,
) {
err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache) err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
if err != nil { if err != nil {
panic(err) panic(err)