mirror of https://github.com/ollama/ollama.git
deepseek tests
This commit is contained in:
parent
a40d427bce
commit
909232168d
|
|
@ -0,0 +1,284 @@
|
|||
package deepseek3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
typemodel "github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// args holds the command-line options shared by all tests in this package.
// The flags are registered in TestMain, e.g.:
//
//	go test -run TestForward -model <name> -prompt "..."
var args struct {
	// model is the ollama model reference to load; empty means "not provided".
	// prompt is the text fed to tokenization/forward tests.
	model,
	prompt string
	// layers is registered as the -layers flag in TestMain but is not read
	// by any test visible in this file — presumably consumed elsewhere or
	// reserved for future use; TODO confirm.
	layers int
}
|
||||
|
||||
// TestMain registers the package-level test flags and configures structured
// logging before running the tests. Note -model has no default, so tests that
// require a model must be invoked with it set explicitly.
func TestMain(m *testing.M) {
	flag.StringVar(&args.model, "model", "", "path to model")
	flag.StringVar(&args.prompt, "prompt", "The capital of France is", "model prompt")
	// math.MaxInt means "offload as many layers as possible" by convention;
	// the flag is not read in this file — TODO confirm where it is consumed.
	flag.IntVar(&args.layers, "layers", math.MaxInt, "num of gpu layers")
	flag.Parse()

	// Route slog output through the project logger at the level selected by
	// the OLLAMA_DEBUG / log-level environment configuration.
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))

	os.Exit(m.Run())
}
|
||||
|
||||
func blob(tb testing.TB, model string) string {
|
||||
tb.Helper()
|
||||
|
||||
models := envconfig.Models()
|
||||
manifest, err := os.Open(filepath.Join(models, "manifests", typemodel.ParseName(model).Filepath()))
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
defer manifest.Close()
|
||||
|
||||
var m struct {
|
||||
Layers []struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
} `json:"layers"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(manifest).Decode(&m); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.model" {
|
||||
tb.Log("using model blob", layer.Digest)
|
||||
return filepath.Join(models, "blobs", strings.ReplaceAll(layer.Digest, ":", "-"))
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// loadFloatsFromBinary reads filename in full and interprets its contents as a
// packed little-endian []float32 (4 bytes per element, no header). It returns
// an error if the file cannot be read or its size is not a multiple of 4.
//
// The previous implementation used binary.Read into a []float32, which decodes
// through reflection; decoding each word directly is simpler and faster, and
// os.ReadFile replaces the manual Open/Stat/Read/Close sequence.
func loadFloatsFromBinary(filename string) ([]float32, error) {
	data, err := os.ReadFile(filename)
	if err != nil {
		return nil, err
	}
	if len(data)%4 != 0 {
		return nil, fmt.Errorf("file size %d not multiple of 4", len(data))
	}

	floats := make([]float32, len(data)/4)
	for i := range floats {
		floats[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[4*i:]))
	}
	return floats, nil
}
|
||||
|
||||
func TestForward(t *testing.T) {
|
||||
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("%+v", m.(*Transformer).TransformerBlocks[0].Attention.QA)
|
||||
|
||||
attentionBlock := m.(*Transformer).TransformerBlocks[0].Attention
|
||||
ctx := m.Backend().NewContext()
|
||||
filePath := "/Users/graceguo/Downloads/hidden_states.bin"
|
||||
|
||||
hsFloats, err := loadFloatsFromBinary(filePath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("hs len=%d, expected=%d", len(hsFloats), 7168*4*1)
|
||||
t.Logf("DEBUG: hsFloats: %v", hsFloats[:10])
|
||||
hiddenStates := ctx.Input().FromFloatSlice(hsFloats, 7168, 4, 1)
|
||||
t.Logf("DEBUG: hiddenStates.shape: %v", hiddenStates.Shape())
|
||||
positionIndices := []int32{0, 1, 2, 3}
|
||||
positions := ctx.Input().FromIntSlice(positionIndices, 4)
|
||||
|
||||
qLoraRankVal := 1536
|
||||
options := &Options{
|
||||
kvLoraRank: 512,
|
||||
qkNopeHeadDim: 128,
|
||||
qkRopeHeadDim: 64,
|
||||
kqNopeHeadDim: 128, // key part dimension (256 - 128 = 128)
|
||||
qkHeadDim: 128 + 64, // qk_nope_head_dim + qk_rope_head_dim
|
||||
qLoraRank: &qLoraRankVal,
|
||||
attnImplementation: "sdpa",
|
||||
vHeadDim: 128,
|
||||
hiddenSize: 7168,
|
||||
numHeads: 128,
|
||||
numKVHeads: 128,
|
||||
keyLength: 128,
|
||||
valueLength: 128,
|
||||
eps: 1e-06,
|
||||
ropeBase: 10000,
|
||||
ropeScale: 40,
|
||||
|
||||
yarn_log_multiplier: 0.1,
|
||||
originalContextLength: 4096,
|
||||
}
|
||||
result := attentionBlock.Forward(ctx, hiddenStates, positions, nil, options)
|
||||
result = result.Contiguous(ctx)
|
||||
ctx.Forward(result).Compute(result)
|
||||
|
||||
t.Logf("shape=%v dtype=%v", result.Shape(), result.DType())
|
||||
|
||||
// filePath = "/Users/graceguo/workspace/ollama/model/models/deepseek3/hello5.bin"
|
||||
// print("DEBUG: filePath: %v\n", filePath)
|
||||
// err = os.WriteFile(filePath, result.Bytes(), 0644)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
|
||||
t.Logf("Forward pass completed, result shape: %v", result.Shape())
|
||||
}
|
||||
|
||||
func TestTopKIndicesComplex(t *testing.T) {
|
||||
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mlp := m.(*Transformer).TransformerBlocks[3].MLP
|
||||
ctx := m.Backend().NewContext()
|
||||
|
||||
filePath := "/Users/graceguo/Downloads/hidden_states.bin"
|
||||
|
||||
hsFloats, err := loadFloatsFromBinary(filePath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("hs len=%d, expected=%d", len(hsFloats), 7168*4*1)
|
||||
t.Logf("DEBUG: hsFloats: %v", hsFloats[:10])
|
||||
hiddenStates := ctx.Input().FromFloatSlice(hsFloats, 7168, 4, 1)
|
||||
t.Logf("DEBUG: hiddenStates.shape: %v", hiddenStates.Shape())
|
||||
|
||||
options := &Options{
|
||||
numExperts: 256,
|
||||
numExpertsUsed: 8,
|
||||
normTopKProb: true,
|
||||
routedScalingFactor: 2.5,
|
||||
}
|
||||
|
||||
result := mlp.Forward(ctx, hiddenStates, options)
|
||||
result = result.Contiguous(ctx)
|
||||
ctx.Forward(result).Compute(result)
|
||||
|
||||
t.Logf("shape=%v dtype=%v", result.Shape(), result.DType())
|
||||
|
||||
// filePath = "/Users/graceguo/workspace/ollama/model/models/deepseek3/post_moe.bin"
|
||||
// print("DEBUG: filePath: %v\n", filePath)
|
||||
// err = os.WriteFile(filePath, result.Bytes(), 0644)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
|
||||
t.Logf("Forward pass completed, result shape: %v", result.Shape())
|
||||
t.Logf("Result shape: %v", result.Shape())
|
||||
}
|
||||
|
||||
func TestFullForward(t *testing.T) {
|
||||
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := m.Backend().NewContext()
|
||||
|
||||
prompt := args.prompt
|
||||
if prompt == "" {
|
||||
prompt = "Hello world! How's it going? 123 一二三"
|
||||
}
|
||||
|
||||
tp := m.(model.TextProcessor)
|
||||
tokens, err := tp.Encode(prompt, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("tokens: %q", tokens)
|
||||
|
||||
decoded, err := tp.Decode(tokens)
|
||||
if err != nil { t.Fatal(err) }
|
||||
t.Logf("decoded: %q", decoded)
|
||||
|
||||
inputsTensor := ctx.Input().FromIntSlice(tokens, len(tokens))
|
||||
positions := make([]int32, len(tokens))
|
||||
sequences := make([]int, len(tokens))
|
||||
for i := range tokens {
|
||||
positions[i] = int32(i)
|
||||
sequences[i] = 0
|
||||
}
|
||||
outputs := ctx.Input().FromIntSlice([]int32{int32(len(tokens) - 1)}, 1)
|
||||
|
||||
batch := input.Batch{
|
||||
Inputs: inputsTensor,
|
||||
Positions: positions,
|
||||
Sequences: sequences,
|
||||
Outputs: outputs,
|
||||
}
|
||||
if cache := m.Config().Cache; cache != nil {
|
||||
cache.Init(m.Backend(), ml.DTypeF16, 1, 4096, len(tokens))
|
||||
}
|
||||
|
||||
result, err := model.Forward(ctx, m, batch)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result = result.Contiguous(ctx)
|
||||
ctx.Forward(result).Compute(result)
|
||||
|
||||
t.Logf("Forward pass completed, result shape: %v", result.Shape())
|
||||
}
|
||||
|
||||
func TestTokenization(t *testing.T) {
|
||||
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
prompt := args.prompt
|
||||
if prompt == "" {
|
||||
prompt = "hello"
|
||||
}
|
||||
|
||||
tp := m.(model.TextProcessor)
|
||||
tokens, err := tp.Encode(prompt, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("tokens: %v", tokens)
|
||||
}
|
||||
Loading…
Reference in New Issue