mirror of https://github.com/ollama/ollama.git
				
				
				
			
		
			
	
	
		
			107 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Go
		
	
	
	
		
		
			
		
	
	
			107 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Go
		
	
	
	
| 
								 | 
							
								package server
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import (
							 | 
						||
| 
								 | 
							
									"bytes"
							 | 
						||
| 
								 | 
							
									"encoding/binary"
							 | 
						||
| 
								 | 
							
									"errors"
							 | 
						||
| 
								 | 
							
									"os"
							 | 
						||
| 
								 | 
							
									"path/filepath"
							 | 
						||
| 
								 | 
							
									"strings"
							 | 
						||
| 
								 | 
							
									"testing"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									"github.com/ollama/ollama/api"
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func TestConvertFromSafetensors(t *testing.T) {
							 | 
						||
| 
								 | 
							
									t.Setenv("OLLAMA_MODELS", t.TempDir())
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// Helper function to create a new layer and return its digest
							 | 
						||
| 
								 | 
							
									makeTemp := func(content string) string {
							 | 
						||
| 
								 | 
							
										l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
							 | 
						||
| 
								 | 
							
										if err != nil {
							 | 
						||
| 
								 | 
							
											t.Fatalf("Failed to create layer: %v", err)
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										return l.Digest
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// Create a safetensors compatible file with empty JSON content
							 | 
						||
| 
								 | 
							
									var buf bytes.Buffer
							 | 
						||
| 
								 | 
							
									headerSize := int64(len("{}"))
							 | 
						||
| 
								 | 
							
									binary.Write(&buf, binary.LittleEndian, headerSize)
							 | 
						||
| 
								 | 
							
									buf.WriteString("{}")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									model := makeTemp(buf.String())
							 | 
						||
| 
								 | 
							
									config := makeTemp(`{
							 | 
						||
| 
								 | 
							
										"architectures": ["LlamaForCausalLM"], 
							 | 
						||
| 
								 | 
							
										"vocab_size": 32000
							 | 
						||
| 
								 | 
							
									}`)
							 | 
						||
| 
								 | 
							
									tokenizer := makeTemp(`{
							 | 
						||
| 
								 | 
							
										"version": "1.0",
							 | 
						||
| 
								 | 
							
										"truncation": null,
							 | 
						||
| 
								 | 
							
										"padding": null,
							 | 
						||
| 
								 | 
							
										"added_tokens": [
							 | 
						||
| 
								 | 
							
											{
							 | 
						||
| 
								 | 
							
												"id": 0,
							 | 
						||
| 
								 | 
							
												"content": "<|endoftext|>",
							 | 
						||
| 
								 | 
							
												"single_word": false,
							 | 
						||
| 
								 | 
							
												"lstrip": false,
							 | 
						||
| 
								 | 
							
												"rstrip": false,
							 | 
						||
| 
								 | 
							
												"normalized": false,
							 | 
						||
| 
								 | 
							
												"special": true
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
										]
							 | 
						||
| 
								 | 
							
									}`)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									tests := []struct {
							 | 
						||
| 
								 | 
							
										name     string
							 | 
						||
| 
								 | 
							
										filePath string
							 | 
						||
| 
								 | 
							
										wantErr  error
							 | 
						||
| 
								 | 
							
									}{
							 | 
						||
| 
								 | 
							
										// Invalid
							 | 
						||
| 
								 | 
							
										{
							 | 
						||
| 
								 | 
							
											name:     "InvalidRelativePathShallow",
							 | 
						||
| 
								 | 
							
											filePath: filepath.Join("..", "file.safetensors"),
							 | 
						||
| 
								 | 
							
											wantErr:  errFilePath,
							 | 
						||
| 
								 | 
							
										},
							 | 
						||
| 
								 | 
							
										{
							 | 
						||
| 
								 | 
							
											name:     "InvalidRelativePathDeep",
							 | 
						||
| 
								 | 
							
											filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
							 | 
						||
| 
								 | 
							
											wantErr:  errFilePath,
							 | 
						||
| 
								 | 
							
										},
							 | 
						||
| 
								 | 
							
										{
							 | 
						||
| 
								 | 
							
											name:     "InvalidNestedPath",
							 | 
						||
| 
								 | 
							
											filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
							 | 
						||
| 
								 | 
							
											wantErr:  errFilePath,
							 | 
						||
| 
								 | 
							
										},
							 | 
						||
| 
								 | 
							
										{
							 | 
						||
| 
								 | 
							
											name:     "AbsolutePathOutsideRoot",
							 | 
						||
| 
								 | 
							
											filePath: filepath.Join(os.TempDir(), "model.safetensors"),
							 | 
						||
| 
								 | 
							
											wantErr:  errFilePath, // Should fail since it's outside tmpDir
							 | 
						||
| 
								 | 
							
										},
							 | 
						||
| 
								 | 
							
										{
							 | 
						||
| 
								 | 
							
											name:     "ValidRelativePath",
							 | 
						||
| 
								 | 
							
											filePath: "model.safetensors",
							 | 
						||
| 
								 | 
							
											wantErr:  nil,
							 | 
						||
| 
								 | 
							
										},
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									for _, tt := range tests {
							 | 
						||
| 
								 | 
							
										t.Run(tt.name, func(t *testing.T) {
							 | 
						||
| 
								 | 
							
											// Create the minimum required file map for convertFromSafetensors
							 | 
						||
| 
								 | 
							
											files := map[string]string{
							 | 
						||
| 
								 | 
							
												tt.filePath:      model,
							 | 
						||
| 
								 | 
							
												"config.json":    config,
							 | 
						||
| 
								 | 
							
												"tokenizer.json": tokenizer,
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
											_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
											if (tt.wantErr == nil && err != nil) ||
							 | 
						||
| 
								 | 
							
												(tt.wantErr != nil && err == nil) ||
							 | 
						||
| 
								 | 
							
												(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
							 | 
						||
| 
								 | 
							
												t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
										})
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 |