mirror of https://github.com/ollama/ollama.git
				
				
				
			
		
			
				
	
	
		
			477 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			477 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
package template
 | 
						|
 | 
						|
import (
 | 
						|
	"bufio"
 | 
						|
	"bytes"
 | 
						|
	"encoding/json"
 | 
						|
	"io"
 | 
						|
	"os"
 | 
						|
	"path/filepath"
 | 
						|
	"slices"
 | 
						|
	"strings"
 | 
						|
	"testing"
 | 
						|
 | 
						|
	"github.com/google/go-cmp/cmp"
 | 
						|
 | 
						|
	"github.com/ollama/ollama/api"
 | 
						|
	"github.com/ollama/ollama/fs/ggml"
 | 
						|
)
 | 
						|
 | 
						|
func TestNamed(t *testing.T) {
 | 
						|
	f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	defer f.Close()
 | 
						|
 | 
						|
	scanner := bufio.NewScanner(f)
 | 
						|
	for scanner.Scan() {
 | 
						|
		var ss map[string]string
 | 
						|
		if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
 | 
						|
			t.Fatal(err)
 | 
						|
		}
 | 
						|
 | 
						|
		for k, v := range ss {
 | 
						|
			t.Run(k, func(t *testing.T) {
 | 
						|
				kv := ggml.KV{"tokenizer.chat_template": v}
 | 
						|
				s := kv.ChatTemplate()
 | 
						|
				r, err := Named(s)
 | 
						|
				if err != nil {
 | 
						|
					t.Fatal(err)
 | 
						|
				}
 | 
						|
 | 
						|
				if r.Name != k {
 | 
						|
					t.Errorf("expected %q, got %q", k, r.Name)
 | 
						|
				}
 | 
						|
 | 
						|
				var b bytes.Buffer
 | 
						|
				if _, err := io.Copy(&b, r.Reader()); err != nil {
 | 
						|
					t.Fatal(err)
 | 
						|
				}
 | 
						|
 | 
						|
				tmpl, err := Parse(b.String())
 | 
						|
				if err != nil {
 | 
						|
					t.Fatal(err)
 | 
						|
				}
 | 
						|
 | 
						|
				if tmpl.Tree.Root.String() == "" {
 | 
						|
					t.Errorf("empty %s template", k)
 | 
						|
				}
 | 
						|
			})
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestTemplate(t *testing.T) {
 | 
						|
	cases := make(map[string][]api.Message)
 | 
						|
	for _, mm := range [][]api.Message{
 | 
						|
		{
 | 
						|
			{Role: "user", Content: "Hello, how are you?"},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			{Role: "user", Content: "Hello, how are you?"},
 | 
						|
			{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
 | 
						|
			{Role: "user", Content: "I'd like to show off how chat templating works!"},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			{Role: "system", Content: "You are a helpful assistant."},
 | 
						|
			{Role: "user", Content: "Hello, how are you?"},
 | 
						|
			{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
 | 
						|
			{Role: "user", Content: "I'd like to show off how chat templating works!"},
 | 
						|
		},
 | 
						|
	} {
 | 
						|
		var roles []string
 | 
						|
		for _, m := range mm {
 | 
						|
			roles = append(roles, m.Role)
 | 
						|
		}
 | 
						|
 | 
						|
		cases[strings.Join(roles, "-")] = mm
 | 
						|
	}
 | 
						|
 | 
						|
	matches, err := filepath.Glob("*.gotmpl")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	for _, match := range matches {
 | 
						|
		t.Run(match, func(t *testing.T) {
 | 
						|
			bts, err := os.ReadFile(match)
 | 
						|
			if err != nil {
 | 
						|
				t.Fatal(err)
 | 
						|
			}
 | 
						|
 | 
						|
			tmpl, err := Parse(string(bts))
 | 
						|
			if err != nil {
 | 
						|
				t.Fatal(err)
 | 
						|
			}
 | 
						|
 | 
						|
			for n, tt := range cases {
 | 
						|
				var actual bytes.Buffer
 | 
						|
				t.Run(n, func(t *testing.T) {
 | 
						|
					if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
 | 
						|
						t.Fatal(err)
 | 
						|
					}
 | 
						|
 | 
						|
					expect, err := os.ReadFile(filepath.Join("testdata", match, n))
 | 
						|
					if err != nil {
 | 
						|
						t.Fatal(err)
 | 
						|
					}
 | 
						|
 | 
						|
					bts := actual.Bytes()
 | 
						|
 | 
						|
					if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
 | 
						|
						t.Log("removing trailing space from output")
 | 
						|
						bts = bts[:len(bts)-1]
 | 
						|
					}
 | 
						|
 | 
						|
					if diff := cmp.Diff(bts, expect); diff != "" {
 | 
						|
						t.Errorf("mismatch (-got +want):\n%s", diff)
 | 
						|
					}
 | 
						|
				})
 | 
						|
 | 
						|
				t.Run("legacy", func(t *testing.T) {
 | 
						|
					t.Skip("legacy outputs are currently default outputs")
 | 
						|
					var legacy bytes.Buffer
 | 
						|
					if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
 | 
						|
						t.Fatal(err)
 | 
						|
					}
 | 
						|
 | 
						|
					legacyBytes := legacy.Bytes()
 | 
						|
					if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
 | 
						|
						t.Log("removing trailing space from legacy output")
 | 
						|
						legacyBytes = legacyBytes[:len(legacyBytes)-1]
 | 
						|
					} else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
 | 
						|
						t.Skip("legacy outputs cannot be compared to messages outputs")
 | 
						|
					}
 | 
						|
 | 
						|
					if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
 | 
						|
						t.Errorf("mismatch (-got +want):\n%s", diff)
 | 
						|
					}
 | 
						|
				})
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParse(t *testing.T) {
 | 
						|
	cases := []struct {
 | 
						|
		template string
 | 
						|
		vars     []string
 | 
						|
	}{
 | 
						|
		{"{{ .Prompt }}", []string{"prompt", "response"}},
 | 
						|
		{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
 | 
						|
		{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
 | 
						|
		{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
 | 
						|
		{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
 | 
						|
		{"{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role", "toolname"}},
 | 
						|
		{`{{- range .Messages }}
 | 
						|
{{- if eq .Role "system" }}SYSTEM:
 | 
						|
{{- else if eq .Role "user" }}USER:
 | 
						|
{{- else if eq .Role "assistant" }}ASSISTANT:
 | 
						|
{{- else if eq .Role "tool" }}TOOL: 
 | 
						|
{{- end }} {{ .Content }}
 | 
						|
{{- end }}`, []string{"content", "messages", "role"}},
 | 
						|
		{`{{- if .Messages }}
 | 
						|
{{- range .Messages }}<|im_start|>{{ .Role }}
 | 
						|
{{ .Content }}<|im_end|>
 | 
						|
{{ end }}<|im_start|>assistant
 | 
						|
{{ else -}}
 | 
						|
{{ if .System }}<|im_start|>system
 | 
						|
{{ .System }}<|im_end|>
 | 
						|
{{ end }}{{ if .Prompt }}<|im_start|>user
 | 
						|
{{ .Prompt }}<|im_end|>
 | 
						|
{{ end }}<|im_start|>assistant
 | 
						|
{{ .Response }}<|im_end|>
 | 
						|
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tt := range cases {
 | 
						|
		t.Run("", func(t *testing.T) {
 | 
						|
			tmpl, err := Parse(tt.template)
 | 
						|
			if err != nil {
 | 
						|
				t.Fatal(err)
 | 
						|
			}
 | 
						|
 | 
						|
			if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
 | 
						|
				t.Errorf("mismatch (-got +want):\n%s", diff)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestExecuteWithMessages(t *testing.T) {
 | 
						|
	type template struct {
 | 
						|
		name     string
 | 
						|
		template string
 | 
						|
	}
 | 
						|
	cases := []struct {
 | 
						|
		name      string
 | 
						|
		templates []template
 | 
						|
		values    Values
 | 
						|
		expected  string
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			"mistral",
 | 
						|
			[]template{
 | 
						|
				{"no response", `[INST] {{ if .System }}{{ .System }}
 | 
						|
 | 
						|
{{ end }}{{ .Prompt }}[/INST] `},
 | 
						|
				{"response", `[INST] {{ if .System }}{{ .System }}
 | 
						|
 | 
						|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
 | 
						|
				{"messages", `[INST] {{ if .System }}{{ .System }}
 | 
						|
 | 
						|
{{ end }}
 | 
						|
{{- range .Messages }}
 | 
						|
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
 | 
						|
{{- end }}`},
 | 
						|
			},
 | 
						|
			Values{
 | 
						|
				Messages: []api.Message{
 | 
						|
					{Role: "user", Content: "Hello friend!"},
 | 
						|
					{Role: "assistant", Content: "Hello human!"},
 | 
						|
					{Role: "user", Content: "What is your name?"},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			"mistral system",
 | 
						|
			[]template{
 | 
						|
				{"no response", `[INST] {{ if .System }}{{ .System }}
 | 
						|
 | 
						|
{{ end }}{{ .Prompt }}[/INST] `},
 | 
						|
				{"response", `[INST] {{ if .System }}{{ .System }}
 | 
						|
 | 
						|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
 | 
						|
				{"messages", `[INST] {{ if .System }}{{ .System }}
 | 
						|
 | 
						|
{{ end }}
 | 
						|
{{- range .Messages }}
 | 
						|
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
 | 
						|
{{- end }}`},
 | 
						|
			},
 | 
						|
			Values{
 | 
						|
				Messages: []api.Message{
 | 
						|
					{Role: "system", Content: "You are a helpful assistant!"},
 | 
						|
					{Role: "user", Content: "Hello friend!"},
 | 
						|
					{Role: "assistant", Content: "Hello human!"},
 | 
						|
					{Role: "user", Content: "What is your name?"},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			`[INST] You are a helpful assistant!
 | 
						|
 | 
						|
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			"mistral assistant",
 | 
						|
			[]template{
 | 
						|
				{"no response", `[INST] {{ .Prompt }}[/INST] `},
 | 
						|
				{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
 | 
						|
				{"messages", `
 | 
						|
{{- range $i, $m := .Messages }}
 | 
						|
{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
 | 
						|
{{- end }}`},
 | 
						|
			},
 | 
						|
			Values{
 | 
						|
				Messages: []api.Message{
 | 
						|
					{Role: "user", Content: "Hello friend!"},
 | 
						|
					{Role: "assistant", Content: "Hello human!"},
 | 
						|
					{Role: "user", Content: "What is your name?"},
 | 
						|
					{Role: "assistant", Content: "My name is Ollama and I"},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			"chatml",
 | 
						|
			[]template{
 | 
						|
				// this does not have a "no response" test because it's impossible to render the same output
 | 
						|
				{"response", `{{ if .System }}<|im_start|>system
 | 
						|
{{ .System }}<|im_end|>
 | 
						|
{{ end }}{{ if .Prompt }}<|im_start|>user
 | 
						|
{{ .Prompt }}<|im_end|>
 | 
						|
{{ end }}<|im_start|>assistant
 | 
						|
{{ .Response }}<|im_end|>
 | 
						|
`},
 | 
						|
				{"messages", `
 | 
						|
{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
 | 
						|
{{ .Content }}<|im_end|>
 | 
						|
{{ end }}<|im_start|>assistant
 | 
						|
`},
 | 
						|
			},
 | 
						|
			Values{
 | 
						|
				Messages: []api.Message{
 | 
						|
					{Role: "system", Content: "You are a helpful assistant!"},
 | 
						|
					{Role: "user", Content: "Hello friend!"},
 | 
						|
					{Role: "assistant", Content: "Hello human!"},
 | 
						|
					{Role: "user", Content: "What is your name?"},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			`<|im_start|>system
 | 
						|
You are a helpful assistant!<|im_end|>
 | 
						|
<|im_start|>user
 | 
						|
Hello friend!<|im_end|>
 | 
						|
<|im_start|>assistant
 | 
						|
Hello human!<|im_end|>
 | 
						|
<|im_start|>user
 | 
						|
What is your name?<|im_end|>
 | 
						|
<|im_start|>assistant
 | 
						|
`,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tt := range cases {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			for _, ttt := range tt.templates {
 | 
						|
				t.Run(ttt.name, func(t *testing.T) {
 | 
						|
					tmpl, err := Parse(ttt.template)
 | 
						|
					if err != nil {
 | 
						|
						t.Fatal(err)
 | 
						|
					}
 | 
						|
 | 
						|
					var b bytes.Buffer
 | 
						|
					if err := tmpl.Execute(&b, tt.values); err != nil {
 | 
						|
						t.Fatal(err)
 | 
						|
					}
 | 
						|
 | 
						|
					if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
 | 
						|
						t.Errorf("mismatch (-got +want):\n%s", diff)
 | 
						|
					}
 | 
						|
				})
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestExecuteWithSuffix(t *testing.T) {
 | 
						|
	tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
 | 
						|
{{- else }}{{ .Prompt }}
 | 
						|
{{- end }}`)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	cases := []struct {
 | 
						|
		name   string
 | 
						|
		values Values
 | 
						|
		expect string
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
 | 
						|
		},
 | 
						|
		{
 | 
						|
			"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tt := range cases {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			var b bytes.Buffer
 | 
						|
			if err := tmpl.Execute(&b, tt.values); err != nil {
 | 
						|
				t.Fatal(err)
 | 
						|
			}
 | 
						|
 | 
						|
			if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
 | 
						|
				t.Errorf("mismatch (-got +want):\n%s", diff)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestCollate(t *testing.T) {
 | 
						|
	cases := []struct {
 | 
						|
		name     string
 | 
						|
		msgs     []api.Message
 | 
						|
		expected []*api.Message
 | 
						|
		system   string
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			name: "consecutive user messages are merged",
 | 
						|
			msgs: []api.Message{
 | 
						|
				{Role: "user", Content: "Hello"},
 | 
						|
				{Role: "user", Content: "How are you?"},
 | 
						|
			},
 | 
						|
			expected: []*api.Message{
 | 
						|
				{Role: "user", Content: "Hello\n\nHow are you?"},
 | 
						|
			},
 | 
						|
			system: "",
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "consecutive tool messages are NOT merged",
 | 
						|
			msgs: []api.Message{
 | 
						|
				{Role: "tool", Content: "sunny", ToolName: "get_weather"},
 | 
						|
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
 | 
						|
			},
 | 
						|
			expected: []*api.Message{
 | 
						|
				{Role: "tool", Content: "sunny", ToolName: "get_weather"},
 | 
						|
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
 | 
						|
			},
 | 
						|
			system: "",
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "tool messages preserve all fields",
 | 
						|
			msgs: []api.Message{
 | 
						|
				{Role: "user", Content: "What's the weather?"},
 | 
						|
				{Role: "tool", Content: "sunny", ToolName: "get_conditions"},
 | 
						|
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
 | 
						|
			},
 | 
						|
			expected: []*api.Message{
 | 
						|
				{Role: "user", Content: "What's the weather?"},
 | 
						|
				{Role: "tool", Content: "sunny", ToolName: "get_conditions"},
 | 
						|
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
 | 
						|
			},
 | 
						|
			system: "",
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "mixed messages with system",
 | 
						|
			msgs: []api.Message{
 | 
						|
				{Role: "system", Content: "You are helpful"},
 | 
						|
				{Role: "user", Content: "Hello"},
 | 
						|
				{Role: "assistant", Content: "Hi there!"},
 | 
						|
				{Role: "user", Content: "What's the weather?"},
 | 
						|
				{Role: "tool", Content: "sunny", ToolName: "get_weather"},
 | 
						|
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
 | 
						|
				{Role: "user", Content: "Thanks"},
 | 
						|
			},
 | 
						|
			expected: []*api.Message{
 | 
						|
				{Role: "system", Content: "You are helpful"},
 | 
						|
				{Role: "user", Content: "Hello"},
 | 
						|
				{Role: "assistant", Content: "Hi there!"},
 | 
						|
				{Role: "user", Content: "What's the weather?"},
 | 
						|
				{Role: "tool", Content: "sunny", ToolName: "get_weather"},
 | 
						|
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
 | 
						|
				{Role: "user", Content: "Thanks"},
 | 
						|
			},
 | 
						|
			system: "You are helpful",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tt := range cases {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			system, collated := collate(tt.msgs)
 | 
						|
			if diff := cmp.Diff(system, tt.system); diff != "" {
 | 
						|
				t.Errorf("system mismatch (-got +want):\n%s", diff)
 | 
						|
			}
 | 
						|
 | 
						|
			// Compare the messages
 | 
						|
			if len(collated) != len(tt.expected) {
 | 
						|
				t.Errorf("expected %d messages, got %d", len(tt.expected), len(collated))
 | 
						|
				return
 | 
						|
			}
 | 
						|
 | 
						|
			for i := range collated {
 | 
						|
				if collated[i].Role != tt.expected[i].Role {
 | 
						|
					t.Errorf("message %d role mismatch: got %q, want %q", i, collated[i].Role, tt.expected[i].Role)
 | 
						|
				}
 | 
						|
				if collated[i].Content != tt.expected[i].Content {
 | 
						|
					t.Errorf("message %d content mismatch: got %q, want %q", i, collated[i].Content, tt.expected[i].Content)
 | 
						|
				}
 | 
						|
				if collated[i].ToolName != tt.expected[i].ToolName {
 | 
						|
					t.Errorf("message %d tool name mismatch: got %q, want %q", i, collated[i].ToolName, tt.expected[i].ToolName)
 | 
						|
				}
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 |