2024-06-11 05:54:42 +08:00
package template
import (
"bufio"
"bytes"
"encoding/json"
"io"
"os"
"path/filepath"
"slices"
2024-06-28 05:15:17 +08:00
"strings"
2024-06-11 05:54:42 +08:00
"testing"
2024-06-28 05:15:17 +08:00
"github.com/google/go-cmp/cmp"
2024-08-02 05:52:15 +08:00
2024-06-18 01:38:55 +08:00
"github.com/ollama/ollama/api"
2025-02-14 08:31:21 +08:00
"github.com/ollama/ollama/fs/ggml"
2024-06-11 05:54:42 +08:00
)
func TestNamed ( t * testing . T ) {
f , err := os . Open ( filepath . Join ( "testdata" , "templates.jsonl" ) )
if err != nil {
t . Fatal ( err )
}
defer f . Close ( )
scanner := bufio . NewScanner ( f )
for scanner . Scan ( ) {
var ss map [ string ] string
if err := json . Unmarshal ( scanner . Bytes ( ) , & ss ) ; err != nil {
t . Fatal ( err )
}
for k , v := range ss {
t . Run ( k , func ( t * testing . T ) {
2025-02-14 08:31:21 +08:00
kv := ggml . KV { "tokenizer.chat_template" : v }
2024-06-11 05:54:42 +08:00
s := kv . ChatTemplate ( )
r , err := Named ( s )
if err != nil {
t . Fatal ( err )
}
if r . Name != k {
t . Errorf ( "expected %q, got %q" , k , r . Name )
}
var b bytes . Buffer
if _ , err := io . Copy ( & b , r . Reader ( ) ) ; err != nil {
t . Fatal ( err )
}
2024-06-28 05:15:17 +08:00
tmpl , err := Parse ( b . String ( ) )
2024-06-11 05:54:42 +08:00
if err != nil {
t . Fatal ( err )
}
if tmpl . Tree . Root . String ( ) == "" {
t . Errorf ( "empty %s template" , k )
}
} )
}
}
}
2024-06-28 05:15:17 +08:00
func TestTemplate ( t * testing . T ) {
cases := make ( map [ string ] [ ] api . Message )
for _ , mm := range [ ] [ ] api . Message {
{
{ Role : "user" , Content : "Hello, how are you?" } ,
} ,
{
{ Role : "user" , Content : "Hello, how are you?" } ,
{ Role : "assistant" , Content : "I'm doing great. How can I help you today?" } ,
{ Role : "user" , Content : "I'd like to show off how chat templating works!" } ,
} ,
{
{ Role : "system" , Content : "You are a helpful assistant." } ,
{ Role : "user" , Content : "Hello, how are you?" } ,
{ Role : "assistant" , Content : "I'm doing great. How can I help you today?" } ,
{ Role : "user" , Content : "I'd like to show off how chat templating works!" } ,
} ,
} {
var roles [ ] string
for _ , m := range mm {
roles = append ( roles , m . Role )
}
cases [ strings . Join ( roles , "-" ) ] = mm
}
matches , err := filepath . Glob ( "*.gotmpl" )
if err != nil {
t . Fatal ( err )
}
for _ , match := range matches {
t . Run ( match , func ( t * testing . T ) {
bts , err := os . ReadFile ( match )
if err != nil {
t . Fatal ( err )
}
tmpl , err := Parse ( string ( bts ) )
if err != nil {
t . Fatal ( err )
}
for n , tt := range cases {
2024-07-11 02:00:07 +08:00
var actual bytes . Buffer
2024-06-28 05:15:17 +08:00
t . Run ( n , func ( t * testing . T ) {
if err := tmpl . Execute ( & actual , Values { Messages : tt } ) ; err != nil {
t . Fatal ( err )
}
expect , err := os . ReadFile ( filepath . Join ( "testdata" , match , n ) )
if err != nil {
t . Fatal ( err )
}
2024-07-12 04:11:40 +08:00
bts := actual . Bytes ( )
if slices . Contains ( [ ] string { "chatqa.gotmpl" , "llama2-chat.gotmpl" , "mistral-instruct.gotmpl" , "openchat.gotmpl" , "vicuna.gotmpl" } , match ) && bts [ len ( bts ) - 1 ] == ' ' {
t . Log ( "removing trailing space from output" )
bts = bts [ : len ( bts ) - 1 ]
}
if diff := cmp . Diff ( bts , expect ) ; diff != "" {
2024-06-28 05:15:17 +08:00
t . Errorf ( "mismatch (-got +want):\n%s" , diff )
}
} )
2024-07-11 02:00:07 +08:00
t . Run ( "legacy" , func ( t * testing . T ) {
2024-07-12 04:10:13 +08:00
t . Skip ( "legacy outputs are currently default outputs" )
2024-07-11 02:00:07 +08:00
var legacy bytes . Buffer
if err := tmpl . Execute ( & legacy , Values { Messages : tt , forceLegacy : true } ) ; err != nil {
t . Fatal ( err )
}
legacyBytes := legacy . Bytes ( )
if slices . Contains ( [ ] string { "chatqa.gotmpl" , "openchat.gotmpl" , "vicuna.gotmpl" } , match ) && legacyBytes [ len ( legacyBytes ) - 1 ] == ' ' {
t . Log ( "removing trailing space from legacy output" )
legacyBytes = legacyBytes [ : len ( legacyBytes ) - 1 ]
} else if slices . Contains ( [ ] string { "codellama-70b-instruct.gotmpl" , "llama2-chat.gotmpl" , "mistral-instruct.gotmpl" } , match ) {
t . Skip ( "legacy outputs cannot be compared to messages outputs" )
}
if diff := cmp . Diff ( legacyBytes , actual . Bytes ( ) ) ; diff != "" {
t . Errorf ( "mismatch (-got +want):\n%s" , diff )
}
} )
2024-06-28 05:15:17 +08:00
}
} )
}
}
2024-06-11 05:54:42 +08:00
func TestParse ( t * testing . T ) {
cases := [ ] struct {
2024-06-12 05:03:42 +08:00
template string
vars [ ] string
2024-06-11 05:54:42 +08:00
} {
2024-06-18 01:38:55 +08:00
{ "{{ .Prompt }}" , [ ] string { "prompt" , "response" } } ,
{ "{{ .System }} {{ .Prompt }}" , [ ] string { "prompt" , "response" , "system" } } ,
2024-06-11 05:54:42 +08:00
{ "{{ .System }} {{ .Prompt }} {{ .Response }}" , [ ] string { "prompt" , "response" , "system" } } ,
2024-06-18 01:38:55 +08:00
{ "{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}" , [ ] string { "prompt" , "response" , "system" , "tools" } } ,
2024-06-11 05:54:42 +08:00
{ "{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}" , [ ] string { "content" , "messages" , "role" } } ,
2025-07-08 06:53:42 +08:00
{ "{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}" , [ ] string { "content" , "messages" , "role" , "toolname" } } ,
2024-07-12 04:10:13 +08:00
{ ` { { - range . Messages } }
{ { - if eq . Role "system" } } SYSTEM :
{ { - else if eq . Role "user" } } USER :
{ { - else if eq . Role "assistant" } } ASSISTANT :
2025-07-08 06:53:42 +08:00
{ { - else if eq . Role "tool" } } TOOL :
2024-07-12 04:10:13 +08:00
{ { - end } } { { . Content } }
{ { - end } } ` , [ ] string { "content" , "messages" , "role" } } ,
2024-07-11 02:00:07 +08:00
{ ` { { - if . Messages } }
{ { - range . Messages } } < | im_start | > { { . Role } }
{ { . Content } } < | im_end | >
{ { end } } < | im_start | > assistant
{ { else - } }
{ { if . System } } < | im_start | > system
{ { . System } } < | im_end | >
{ { end } } { { if . Prompt } } < | im_start | > user
{ { . Prompt } } < | im_end | >
{ { end } } < | im_start | > assistant
{ { . Response } } < | im_end | >
{ { - end - } } ` , [ ] string { "content" , "messages" , "prompt" , "response" , "role" , "system" } } ,
2024-06-11 05:54:42 +08:00
}
for _ , tt := range cases {
t . Run ( "" , func ( t * testing . T ) {
tmpl , err := Parse ( tt . template )
if err != nil {
t . Fatal ( err )
}
2024-07-11 02:00:07 +08:00
if diff := cmp . Diff ( tmpl . Vars ( ) , tt . vars ) ; diff != "" {
t . Errorf ( "mismatch (-got +want):\n%s" , diff )
2024-06-11 05:54:42 +08:00
}
} )
}
}
2024-06-18 01:38:55 +08:00
func TestExecuteWithMessages ( t * testing . T ) {
2024-06-21 02:00:08 +08:00
type template struct {
name string
template string
}
2024-06-18 01:38:55 +08:00
cases := [ ] struct {
2024-06-21 02:00:08 +08:00
name string
templates [ ] template
2024-06-18 01:38:55 +08:00
values Values
expected string
} {
{
2024-06-21 02:00:08 +08:00
"mistral" ,
[ ] template {
2024-07-12 04:11:40 +08:00
{ "no response" , ` [ INST ] { { if . System } } { { . System } }
{ { end } } { { . Prompt } } [ / INST ] ` } ,
{ "response" , ` [ INST ] { { if . System } } { { . System } }
{ { end } } { { . Prompt } } [ / INST ] { { . Response } } ` } ,
2024-07-13 02:48:06 +08:00
{ "messages" , ` [ INST ] { { if . System } } { { . System } }
2024-07-12 04:11:40 +08:00
2024-07-13 02:48:06 +08:00
{ { end } }
{ { - range . Messages } }
{ { - if eq . Role "user" } } { { . Content } } [ / INST ] { { else if eq . Role "assistant" } } { { . Content } } [ INST ] { { end } }
2024-06-21 02:00:08 +08:00
{ { - end } } ` } ,
2024-06-18 01:38:55 +08:00
} ,
Values {
Messages : [ ] api . Message {
{ Role : "user" , Content : "Hello friend!" } ,
{ Role : "assistant" , Content : "Hello human!" } ,
2024-06-21 02:00:08 +08:00
{ Role : "user" , Content : "What is your name?" } ,
2024-06-18 01:38:55 +08:00
} ,
} ,
2024-06-21 02:00:08 +08:00
` [INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] ` ,
2024-06-18 01:38:55 +08:00
} ,
{
2024-06-21 02:00:08 +08:00
"mistral system" ,
[ ] template {
2024-07-12 04:11:40 +08:00
{ "no response" , ` [ INST ] { { if . System } } { { . System } }
{ { end } } { { . Prompt } } [ / INST ] ` } ,
{ "response" , ` [ INST ] { { if . System } } { { . System } }
{ { end } } { { . Prompt } } [ / INST ] { { . Response } } ` } ,
2024-07-13 02:48:06 +08:00
{ "messages" , ` [ INST ] { { if . System } } { { . System } }
2024-07-12 04:11:40 +08:00
2024-07-13 02:48:06 +08:00
{ { end } }
{ { - range . Messages } }
{ { - if eq . Role "user" } } { { . Content } } [ / INST ] { { else if eq . Role "assistant" } } { { . Content } } [ INST ] { { end } }
2024-06-21 02:00:08 +08:00
{ { - end } } ` } ,
2024-06-18 01:38:55 +08:00
} ,
Values {
Messages : [ ] api . Message {
{ Role : "system" , Content : "You are a helpful assistant!" } ,
{ Role : "user" , Content : "Hello friend!" } ,
{ Role : "assistant" , Content : "Hello human!" } ,
2024-06-21 02:00:08 +08:00
{ Role : "user" , Content : "What is your name?" } ,
2024-06-18 01:38:55 +08:00
} ,
} ,
2024-07-11 02:00:07 +08:00
` [ INST ] You are a helpful assistant !
2024-06-18 01:38:55 +08:00
2024-07-11 02:00:07 +08:00
Hello friend ! [ / INST ] Hello human ! [ INST ] What is your name ? [ / INST ] ` ,
2024-06-18 01:38:55 +08:00
} ,
2024-07-20 11:19:26 +08:00
{
"mistral assistant" ,
[ ] template {
{ "no response" , ` [INST] {{ .Prompt }} [/INST] ` } ,
{ "response" , ` [INST] {{ .Prompt }} [/INST] {{ .Response }} ` } ,
{ "messages" , `
{ { - range $ i , $ m := . Messages } }
{ { - if eq . Role "user" } } [ INST ] { { . Content } } [ / INST ] { { else if eq . Role "assistant" } } { { . Content } } { { end } }
{ { - end } } ` } ,
} ,
Values {
Messages : [ ] api . Message {
{ Role : "user" , Content : "Hello friend!" } ,
{ Role : "assistant" , Content : "Hello human!" } ,
{ Role : "user" , Content : "What is your name?" } ,
{ Role : "assistant" , Content : "My name is Ollama and I" } ,
} ,
} ,
` [INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I ` ,
} ,
2024-06-18 01:38:55 +08:00
{
2024-06-21 02:00:08 +08:00
"chatml" ,
[ ] template {
// this does not have a "no response" test because it's impossible to render the same output
{ "response" , ` { { if . System } } < | im_start | > system
2024-06-18 01:38:55 +08:00
{ { . System } } < | im_end | >
{ { end } } { { if . Prompt } } < | im_start | > user
{ { . Prompt } } < | im_end | >
{ { end } } < | im_start | > assistant
{ { . Response } } < | im_end | >
2024-06-21 02:00:08 +08:00
` } ,
{ "messages" , `
2024-07-12 04:11:40 +08:00
{ { - range $ index , $ _ := . Messages } } < | im_start | > { { . Role } }
{ { . Content } } < | im_end | >
{ { end } } < | im_start | > assistant
2024-06-21 02:00:08 +08:00
` } ,
2024-06-18 01:38:55 +08:00
} ,
Values {
Messages : [ ] api . Message {
{ Role : "system" , Content : "You are a helpful assistant!" } ,
{ Role : "user" , Content : "Hello friend!" } ,
{ Role : "assistant" , Content : "Hello human!" } ,
2024-06-21 02:00:08 +08:00
{ Role : "user" , Content : "What is your name?" } ,
2024-06-18 01:38:55 +08:00
} ,
} ,
2024-07-11 02:00:07 +08:00
` < | im_start | > system
You are a helpful assistant ! < | im_end | >
< | im_start | > user
2024-06-18 01:38:55 +08:00
Hello friend ! < | im_end | >
< | im_start | > assistant
Hello human ! < | im_end | >
< | im_start | > user
2024-06-21 02:00:08 +08:00
What is your name ? < | im_end | >
2024-06-18 01:38:55 +08:00
< | im_start | > assistant
` ,
} ,
}
for _ , tt := range cases {
2024-06-21 02:00:08 +08:00
t . Run ( tt . name , func ( t * testing . T ) {
for _ , ttt := range tt . templates {
t . Run ( ttt . name , func ( t * testing . T ) {
tmpl , err := Parse ( ttt . template )
2024-06-18 01:38:55 +08:00
if err != nil {
t . Fatal ( err )
}
var b bytes . Buffer
if err := tmpl . Execute ( & b , tt . values ) ; err != nil {
t . Fatal ( err )
}
2024-07-11 02:00:07 +08:00
if diff := cmp . Diff ( b . String ( ) , tt . expected ) ; diff != "" {
t . Errorf ( "mismatch (-got +want):\n%s" , diff )
2024-06-18 01:38:55 +08:00
}
} )
}
} )
}
}
2024-06-21 10:13:36 +08:00
// TestExecuteWithSuffix exercises a fill-in-the-middle style template. When a
// Suffix is supplied, the prompt is wrapped in <PRE>/<SUF>/<MID> markers.
// Otherwise only the prompt is emitted; here it is taken from a single user
// message.
func TestExecuteWithSuffix(t *testing.T) {
	const fim = `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`

	parsed, err := Parse(fim)
	if err != nil {
		t.Fatal(err)
	}

	for _, tc := range []struct {
		name   string
		values Values
		expect string
	}{
		{
			name:   "message",
			values: Values{Messages: []api.Message{{Role: "user", Content: "hello"}}},
			expect: "hello",
		},
		{
			name:   "prompt suffix",
			values: Values{Prompt: "def add(", Suffix: "return x"},
			expect: "<PRE> def add( <SUF>return x <MID>",
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			var out bytes.Buffer
			if err := parsed.Execute(&out, tc.values); err != nil {
				t.Fatal(err)
			}

			if diff := cmp.Diff(out.String(), tc.expect); diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
			}
		})
	}
}
2025-07-08 06:53:42 +08:00
// TestCollate checks the system string and message list returned by collate.
// Consecutive user messages are merged with a blank line between them,
// consecutive tool messages stay separate, and the remaining messages come
// back in their original order. Only Role, Content and ToolName are compared
// per message.
func TestCollate(t *testing.T) {
	cases := []struct {
		name     string
		msgs     []api.Message
		expected []*api.Message
		system   string
	}{
		{
			name: "consecutive user messages are merged",
			msgs: []api.Message{
				{Role: "user", Content: "Hello"},
				{Role: "user", Content: "How are you?"},
			},
			expected: []*api.Message{
				{Role: "user", Content: "Hello\n\nHow are you?"},
			},
		},
		{
			name: "consecutive tool messages are NOT merged",
			msgs: []api.Message{
				{Role: "tool", Content: "sunny", ToolName: "get_weather"},
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
			},
			expected: []*api.Message{
				{Role: "tool", Content: "sunny", ToolName: "get_weather"},
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
			},
		},
		{
			name: "tool messages preserve all fields",
			msgs: []api.Message{
				{Role: "user", Content: "What's the weather?"},
				{Role: "tool", Content: "sunny", ToolName: "get_conditions"},
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
			},
			expected: []*api.Message{
				{Role: "user", Content: "What's the weather?"},
				{Role: "tool", Content: "sunny", ToolName: "get_conditions"},
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
			},
		},
		{
			name: "mixed messages with system",
			msgs: []api.Message{
				{Role: "system", Content: "You are helpful"},
				{Role: "user", Content: "Hello"},
				{Role: "assistant", Content: "Hi there!"},
				{Role: "user", Content: "What's the weather?"},
				{Role: "tool", Content: "sunny", ToolName: "get_weather"},
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
				{Role: "user", Content: "Thanks"},
			},
			expected: []*api.Message{
				{Role: "system", Content: "You are helpful"},
				{Role: "user", Content: "Hello"},
				{Role: "assistant", Content: "Hi there!"},
				{Role: "user", Content: "What's the weather?"},
				{Role: "tool", Content: "sunny", ToolName: "get_weather"},
				{Role: "tool", Content: "72F", ToolName: "get_temperature"},
				{Role: "user", Content: "Thanks"},
			},
			system: "You are helpful",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotSystem, collated := collate(tc.msgs)
			if diff := cmp.Diff(gotSystem, tc.system); diff != "" {
				t.Errorf("system mismatch (-got +want):\n%s", diff)
			}

			// A length mismatch makes the per-message comparison meaningless.
			if got, want := len(collated), len(tc.expected); got != want {
				t.Errorf("expected %d messages, got %d", want, got)
				return
			}

			for i, got := range collated {
				want := tc.expected[i]
				if got.Role != want.Role {
					t.Errorf("message %d role mismatch: got %q, want %q", i, got.Role, want.Role)
				}
				if got.Content != want.Content {
					t.Errorf("message %d content mismatch: got %q, want %q", i, got.Content, want.Content)
				}
				if got.ToolName != want.ToolName {
					t.Errorf("message %d tool name mismatch: got %q, want %q", i, got.ToolName, want.ToolName)
				}
			}
		})
	}
}