package ml

import (
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"log/slog"
	"math"
	"slices"
	"strconv"
	"strings"

	"github.com/ollama/ollama/fs"
)

type Backend interface {
	Load(ctx context.Context, progress func(float32)) error

	// BackendMemory returns the memory allocations that were made for this model
	BackendMemory() BackendMemory

	Config() fs.Config
	Get(name string) Tensor
	NewContext() Context
	NewContextSize(size int) Context
}
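
// A minimal usage sketch (illustrative only): construct a backend, then load
// the model weights with a progress callback. The progress value is assumed
// here to be a fraction in [0, 1].
//
//	b, err := ml.NewBackend("model.gguf", ml.BackendParams{NumThreads: 8})
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	if err := b.Load(context.TODO(), func(progress float32) {
//		fmt.Printf("loading: %.0f%%\n", progress*100)
//	}); err != nil {
//		log.Fatal(err)
//	}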

// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeF32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}
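
// As an illustrative sketch (fakeBackend is hypothetical, not part of this
// package), a backend that wants padded, permuted cache output could
// advertise it like this:
//
//	func (b *fakeBackend) CacheConfig() ml.CacheConfig {
//		return ml.CacheConfig{
//			CachePadding:     32,   // history from Get is rounded up to a multiple of 32
//			PermutedV:        true, // v tensors come back from Get already permuted
//			MaskBatchPadding: 32,   // extra mask positions are filled with -Inf
//		}
//	}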

// BackendParams controls how the backend loads and executes models
type BackendParams struct {
	// NumThreads sets the number of threads to use if running on the CPU
	NumThreads int

	// MainGPU is the index of the primary GPU to use
	MainGPU int

	// NumGPULayers is the number of layers to offload to GPUs
	NumGPULayers int

	// TensorSplit is the fraction of the model to offload to each GPU
	TensorSplit []float32

	// FlashAttention indicates that we should use a fused flash attention kernel
	FlashAttention bool
}
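
// For example, a caller splitting a model across two GPUs might configure
// the backend like this (the values are illustrative, not recommendations):
//
//	params := ml.BackendParams{
//		NumThreads:     8,
//		MainGPU:        0,
//		NumGPULayers:   32,
//		TensorSplit:    []float32{0.6, 0.4}, // 60/40 split across the two GPUs
//		FlashAttention: true,
//	}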

// ErrNoMem is returned when panicking due to insufficient memory. It includes
// the attempted memory allocation.
type ErrNoMem struct {
	BackendMemory
}

func (e ErrNoMem) Error() string {
	return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
}
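
// Because allocation failures surface as panics carrying ErrNoMem, a caller
// that wants to retry with different parameters can recover it along these
// lines (a sketch, assuming the panic value is an ErrNoMem):
//
//	defer func() {
//		if r := recover(); r != nil {
//			if noMem, ok := r.(ml.ErrNoMem); ok {
//				slog.Warn("insufficient memory", "required", noMem.BackendMemory)
//				// e.g. reduce NumGPULayers and try again
//			} else {
//				panic(r)
//			}
//		}
//	}()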

type AllocationStatus int

const (
	// Unallocated memory - have not yet attempted to allocate
	Unallocated AllocationStatus = iota

	// Failed memory - tried to allocate the memory and did not succeed
	Failed

	// Allocated memory - tried and succeeded to allocate memory
	Allocated
)

// Memory is the size of an allocation and whether it was successful.
type Memory struct {
	Size   uint64
	Status AllocationStatus
}

func (m Memory) String() string {
	s := fmt.Sprint(m.Size)

	switch m.Status {
	case Unallocated:
		s += "U"
	case Failed:
		s += "F"
	case Allocated:
		s += "A"
	}

	return s
}
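
// For example, a successful 1 GiB allocation formats as "1073741824A", while
// one that has not yet been attempted formats as "1073741824U".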

// DeviceMemory provides a breakdown of the memory needed
// per device, such as a CPU or GPU.
type DeviceMemory struct {
	// Name is the name of the device as labeled by the backend. It
	// may not be persistent across instances of the runner.
	Name string

	// ID is an identifier for the device for matching with system
	// management libraries.
	ID string

	// Weights is the per-layer memory needed for the model weights.
	Weights []Memory

	// Cache is the per-layer memory needed for the KV cache.
	Cache []Memory

	// Graph is the size of the compute graph. It is not per-layer.
	Graph Memory
}

func memoryPresent(mem []Memory) bool {
	return slices.ContainsFunc(mem, func(m Memory) bool { return m.Size != 0 })
}

func (m DeviceMemory) LogValue() slog.Value {
	var attrs []slog.Attr
	if memoryPresent(m.Weights) {
		attrs = append(attrs, slog.Any("Weights", m.Weights))
	}

	if memoryPresent(m.Cache) {
		attrs = append(attrs, slog.Any("Cache", m.Cache))
	}

	if m.Graph.Size != 0 {
		attrs = append(attrs, slog.Any("Graph", m.Graph))
	}

	if len(attrs) > 0 && m.ID != "" {
		attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
	}

	return slog.GroupValue(attrs...)
}
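
// A caller summing the total requirement for one device could use a small
// helper like this (a sketch, not part of the package API):
//
//	func totalSize(d ml.DeviceMemory) uint64 {
//		total := d.Graph.Size
//		for _, m := range d.Weights {
//			total += m.Size
//		}
//		for _, m := range d.Cache {
//			total += m.Size
//		}
//		return total
//	}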

// BackendMemory provides the amount of memory required to load the model
// per device based on the BackendParams. In some cases, not all required
// allocations will be known at this point. However, the size of the most recent
// allocation is guaranteed to be provided so that if it failed, the caller can
// accommodate that to make forward progress.
type BackendMemory struct {
	// InputWeights are always located on the CPU and cannot be moved
	InputWeights Memory

	// CPU model components are located in system memory. This does not
	// include unified memory allocated through the GPU.
	CPU DeviceMemory

	// GPU model components are located on one or more GPUs.
	GPUs []DeviceMemory
}

func (m BackendMemory) LogValue() slog.Value {
	var attrs []slog.Attr
	if m.InputWeights.Size != 0 {
		attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
	}

	attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
	for _, g := range m.GPUs {
		attrs = append(attrs, slog.Any(g.Name, g))
	}

	return slog.GroupValue(attrs...)
}

var backends = make(map[string]func(string, BackendParams) (Backend, error))

func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
	if _, ok := backends[name]; ok {
		panic("backend: backend already registered")
	}

	backends[name] = f
}
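
// A backend implementation typically registers itself from an init function.
// A sketch ("mybackend" and newMyBackend are hypothetical):
//
//	func init() {
//		ml.RegisterBackend("mybackend", func(modelPath string, params ml.BackendParams) (ml.Backend, error) {
//			return newMyBackend(modelPath, params)
//		})
//	}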

func NewBackend(modelPath string, params BackendParams) (Backend, error) {
	if backend, ok := backends["ggml"]; ok {
		return backend(modelPath, params)
	}

	return nil, fmt.Errorf("unsupported backend")
}

type Context interface {
	Empty(dtype DType, shape ...int) Tensor
	Zeros(dtype DType, shape ...int) Tensor
	FromFloatSlice(s []float32, shape ...int) Tensor
	FromIntSlice(s []int32, shape ...int) Tensor

	// Arange creates a 1D tensor with values within the interval [start, stop) increased by step.
	Arange(start, stop, step float32, dtype DType) Tensor

	Forward(...Tensor) Context
	Compute(...Tensor)

	// Reserve is analogous to Compute but rather than executing a
	// graph, simply preallocates memory. Typically called with a
	// worst case graph to ensure all resources are available for
	// future inference.
	Reserve()

	MaxGraphNodes() int
	Close()

	// Input returns a context appropriate for creating tensors that are
	// inputs to the model (which includes things like output locations)
	Input() Context

	// Layer returns a context appropriate for creating intermediate tensors
	Layer(int) Context
}
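
// The typical flow is to create tensors on a context, queue the outputs with
// Forward, and execute the graph with Compute. A minimal sketch, assuming a
// loaded Backend b:
//
//	ctx := b.NewContext()
//	defer ctx.Close()
//
//	x := ctx.Input().FromFloatSlice([]float32{1, 2, 3, 4}, 4)
//	y := x.Scale(ctx, 2)
//
//	ctx.Forward(y).Compute(y)
//	fmt.Println(y.Floats()) // expected: [2 4 6 8]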

type Tensor interface {
	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType

	Bytes() []byte
	Floats() []float32

	Neg(ctx Context) Tensor
	Add(ctx Context, t2 Tensor) Tensor
	Sub(ctx Context, t2 Tensor) Tensor
	Mul(ctx Context, t2 Tensor) Tensor
	Div(ctx Context, t2 Tensor) Tensor

	Mulmat(ctx Context, t2 Tensor) Tensor
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
	MulmatID(ctx Context, t2, ids Tensor) Tensor

	Softmax(ctx Context) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor
	SumRows(ctx Context) Tensor

	AvgPool2D(ctx Context, k, s int, p float32) Tensor
	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

	Sin(ctx Context) Tensor
	Cos(ctx Context) Tensor
	Tanh(ctx Context) Tensor
	GELU(ctx Context) Tensor
	QuickGELU(ctx Context) Tensor
	SILU(ctx Context) Tensor
	RELU(ctx Context) Tensor
	Sigmoid(ctx Context) Tensor

	Reshape(ctx Context, shape ...int) Tensor
	View(ctx Context, offset int, shape ...int) Tensor
	Permute(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context, shape ...int) Tensor
	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor

	Pad(ctx Context, shape ...int) Tensor

	Stack(ctx Context, dim int, s ...Tensor) Tensor

	// Repeat repeats the tensor n times along dimension dim
	Repeat(ctx Context, dim, n int) Tensor

	Concat(ctx Context, t2 Tensor, dim int) Tensor
	Rows(ctx Context, t2 Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
	Duplicate(ctx Context) Tensor

	TopK(ctx Context, k int) Tensor
	Argsort(ctx Context) Tensor
	Mean(ctx Context) Tensor
	Variance(ctx Context) Tensor
	Stddev(ctx Context) Tensor
	Sqr(ctx Context) Tensor
	Sqrt(ctx Context) Tensor
	Clamp(ctx Context, min, max float32) Tensor
}

// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
//	query = query.Permute(ctx, 0, 2, 1, 3)
//	key = key.Permute(ctx, 0, 2, 1, 3)
//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//	kq := key.MulmatFullPrec(ctx, query)
//
//	kq = kq.Scale(ctx, scale)
//
//	if mask != nil {
//		kq = kq.Add(ctx, mask)
//	}
//
//	kq = kq.Softmax(ctx)
//
//	kqv := value.Mulmat(ctx, kq)
//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}
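
// Callers can probe for the fused kernel and fall back to the unfused
// sequence above when it is not available, along these lines (a sketch):
//
//	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
//		attention = sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
//	} else {
//		// unfused path: permute, MulmatFullPrec, Scale, add mask, Softmax, Mulmat
//	}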

type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64 |
		~complex64 | ~complex128
}

// mul returns the product of its arguments; with no arguments it returns 1.
func mul[T number](s ...T) T {
	p := T(1)
	for _, v := range s {
		p *= v
	}

	return p
}

type DumpOptions func(*dumpOptions)

// DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
func DumpWithPrecision(n int) DumpOptions {
	return func(opts *dumpOptions) {
		opts.Precision = n
	}
}

// DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// beginning and end of each dimension will be printed.
func DumpWithThreshold(n int) DumpOptions {
	return func(opts *dumpOptions) {
		opts.Threshold = n
	}
}

// DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
func DumpWithEdgeItems(n int) DumpOptions {
	return func(opts *dumpOptions) {
		opts.EdgeItems = n
	}
}

type dumpOptions struct {
	Precision, Threshold, EdgeItems int
}

func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
	opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
	for _, optsFunc := range optsFuncs {
		optsFunc(&opts)
	}

	if mul(t.Shape()...) <= opts.Threshold {
		opts.EdgeItems = math.MaxInt
	}

	switch t.DType() {
	case DTypeF32:
		return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
		})
	case DTypeF16, DTypeQ80, DTypeQ40:
		f32 := ctx.Input().Empty(DTypeF32, t.Shape()...)
		f32 = t.Copy(ctx, f32)
		return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
		})
	case DTypeI32:
		return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
			return strconv.FormatInt(int64(i), 10)
		})
	default:
		return "<unsupported>"
	}
}
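
// For example, printing a tensor with two decimal places (a usage sketch):
//
//	fmt.Println(ml.Dump(ctx, t, ml.DumpWithPrecision(2)))
//
// Tensors with more than Threshold elements are abbreviated to EdgeItems
// entries at each end of every dimension, with "..." marking elided values.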

// dump reads the tensor's data, computing it first if necessary, and formats
// it recursively, printing up to items elements at each end of every dimension.
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
	if t.Bytes() == nil {
		ctx.Forward(t).Compute(t)
	}

	s := make(S, mul(t.Shape()...))
	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
		panic(err)
	}

	shape := t.Shape()
	slices.Reverse(shape)

	var sb strings.Builder
	var f func([]int, int)
	f = func(dims []int, stride int) {
		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
		sb.WriteString("[")
		defer func() { sb.WriteString("]") }()
		for i := 0; i < dims[0]; i++ {
			if i >= items && i < dims[0]-items {
				sb.WriteString("..., ")
				// skip to next printable element
				skip := dims[0] - 2*items
				if len(dims) > 1 {
					stride += mul(append(dims[1:], skip)...)
					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
				}
				i += skip - 1
			} else if len(dims) > 1 {
				f(dims[1:], stride)
				stride += mul(dims[1:]...)
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
				}
			} else {
				text := fn(s[stride+i])
				if len(text) > 0 && text[0] != '-' {
					sb.WriteString(" ")
				}

				sb.WriteString(text)
				if i < dims[0]-1 {
					sb.WriteString(", ")
				}
			}
		}
	}
	f(shape, 0)

	return sb.String()
}

type DType int

const (
	DTypeOther DType = iota
	DTypeF32
	DTypeF16
	DTypeQ80
	DTypeQ40
	DTypeI32
	DTypeMXFP4
)