works for 3.1, but regression in 3???

This commit is contained in:
gr4ceG 2025-09-26 14:35:06 -07:00
parent cd17efc9eb
commit c5cd7fbead
2 changed files with 1 additions and 6 deletions

View File

@ -1509,8 +1509,6 @@ func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) m
return &Tensor{b: t.b, t: tt}
}
// func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor { // add vmla
var kqMask *C.struct_ggml_tensor
if mask != nil {

View File

@ -194,10 +194,7 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
// attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
attention := nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache) // is there a better way to write this?
fmt.Printf("attention shape: %v\n", attention.Shape())
// func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
// the attention is where there is difference
// attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
fmt.Printf("attention shape: %v\n", attention.Shape())
return attn.Output.Forward(ctx, attention)