mirror of https://github.com/ollama/ollama.git
works for 3.1, but regression in 3???
This commit is contained in:
parent cd17efc9eb
commit c5cd7fbead
@@ -1509,8 +1509,6 @@ func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) m
 	return &Tensor{b: t.b, t: tt}
 }
 
-// func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
-
 func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor { // add vmla
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
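The net effect of this hunk is one extra parameter on the fused attention call, so existing callers presumably just pass nil for vmla. A tiny illustrative sketch of the two call shapes at the ml.Tensor level (variable names here are placeholders, not from this commit):

// Old behaviour: no extra value projection, pass nil for vmla.
out := query.ScaledDotProductAttention(ctx, key, value, mask, sinks, nil, scale)

// MLA path: pass the projection weight, as the model code in the next hunk does with attn.VB.Weight.
out = query.ScaledDotProductAttention(ctx, key, value, mask, sinks, vb, scale)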
@@ -194,10 +194,7 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 	// attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
 	attention := nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache) // is there a better way to write this?
 	fmt.Printf("attention shape: %v\n", attention.Shape())
-	// func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
-	// the attention is where there is difference
 
-	// attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
 	fmt.Printf("attention shape: %v\n", attention.Shape())
 	return attn.Output.Forward(ctx, attention)
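The commented-out line removed above preserves the signature this commit codes against. A minimal sketch of what nn.AttentionWithVMLA could look like under that signature, assuming it mirrors nn.Attention's cache flow (Put the new key/value, then Get the cached key/value plus mask) and that ml.Tensor exposes the extended ScaledDotProductAttention directly; the real helper may go through a capability check or fallback path that this sketch skips:

package nn

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// AttentionWithVMLA is an illustration only: the extra vmla tensor is simply
// threaded through to the backend's fused attention call, everything else
// follows the shape of nn.Attention. The cache API usage is an assumption.
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	if cache != nil {
		cache.Put(ctx, key, value)
		key, value, mask := cache.Get(ctx)
		return query.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
	}
	// No cache: attend over the provided key/value with no mask.
	return query.ScaledDotProductAttention(ctx, key, value, nil, sinks, vmla, scale)
}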