ollama/ml/nn/pooling/pooling.go

37 lines
713 B
Go

package pooling
import (
"github.com/ollama/ollama/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
TypeRank
TypeUnknown = 0xFFFFFFFE
TypeUnspecified = 0xFFFFFFFF
)
func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
switch poolingType {
case TypeNone:
return hiddenStates
case TypeMean:
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
case TypeCLS:
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
case TypeLast:
panic("not implemented")
case TypeRank:
panic("not implemented")
default:
panic("not implemented")
}
}