2025-10-06 05:18:56 +08:00
package middleware
import (
"bytes"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
2025-10-08 06:38:58 +08:00
"github.com/google/go-cmp/cmp/cmpopts"
2025-10-06 05:18:56 +08:00
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
)
// Shared fixtures for the image-content chat test.
const (
	// prefix is the data-URI header that is concatenated with image to form
	// the OpenAI "image_url.url" value in the chat image test.
	prefix = `data:image/jpeg;base64,`
	// image is the base64 payload of that data URI; the chat test decodes it
	// with base64.StdEncoding to build the expected api.ImageData.
	// NOTE(review): the payload starts with "iVBORw0KGgo" (the PNG signature)
	// although prefix declares image/jpeg — confirm the mismatch is intentional.
	image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)

// Addressable booleans so test cases can populate *bool fields such as
// Stream (e.g. Stream: &False).
var (
	False = false
	True  = true
)
2025-10-08 06:38:58 +08:00
// makeArgs builds an api.ToolCallFunctionArguments from alternating key/value
// pairs, e.g. makeArgs("location", "Paris, France", "format", "celsius").
//
// It is a test helper: a malformed call is a bug in the test itself, so it
// panics with a descriptive message if pairs has an odd number of elements or
// if any key position does not hold a string.
func makeArgs(pairs ...any) api.ToolCallFunctionArguments {
	// Check up front so an odd argument count reports the real problem instead
	// of an opaque index-out-of-range panic on pairs[i+1].
	if len(pairs)%2 != 0 {
		panic("makeArgs: pairs must contain an even number of elements (key, value, ...)")
	}

	args := api.NewToolCallFunctionArguments()
	for i := 0; i < len(pairs); i += 2 {
		key, ok := pairs[i].(string)
		if !ok {
			panic("makeArgs: every key must be a string")
		}
		args.Set(key, pairs[i+1])
	}
	return args
}
2025-10-06 05:18:56 +08:00
// captureRequestMiddleware returns a gin handler that JSON-decodes the request
// body into capturedRequest and then restores the body so later handlers can
// still read it.
//
// capturedRequest is passed straight to json.Unmarshal, so it must be a
// non-nil pointer (the tests pass the address of a *T variable). If the body
// cannot be read or decoded, the request is aborted with HTTP 500 and no
// further handlers run.
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
	return func(c *gin.Context) {
		bodyBytes, err := io.ReadAll(c.Request.Body)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to read request body")
			return
		}
		// The original body reader is exhausted; hand downstream handlers a
		// fresh reader over the same bytes.
		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

		if err := json.Unmarshal(bodyBytes, capturedRequest); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
			return
		}
		c.Next()
	}
}
func TestChatMiddleware ( t * testing . T ) {
type testCase struct {
name string
body string
req api . ChatRequest
err openai . ErrorResponse
}
var capturedRequest * api . ChatRequest
testCases := [ ] testCase {
{
name : "chat handler" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "Hello" }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "Hello" ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
{
name : "chat handler with options" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "Hello" }
] ,
"stream" : true ,
"max_tokens" : 999 ,
"seed" : 123 ,
"stop" : [ "\n" , "stop" ] ,
"temperature" : 3.0 ,
"frequency_penalty" : 4.0 ,
"presence_penalty" : 5.0 ,
"top_p" : 6.0 ,
"response_format" : { "type" : "json_object" }
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "Hello" ,
} ,
} ,
Options : map [ string ] any {
"num_predict" : 999.0 , // float because JSON doesn't distinguish between float and int
"seed" : 123.0 ,
"stop" : [ ] any { "\n" , "stop" } ,
"temperature" : 3.0 ,
"frequency_penalty" : 4.0 ,
"presence_penalty" : 5.0 ,
"top_p" : 6.0 ,
} ,
Format : json . RawMessage ( ` "json" ` ) ,
Stream : & True ,
} ,
} ,
{
name : "chat handler with streaming usage" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "Hello" }
] ,
"stream" : true ,
"stream_options" : { "include_usage" : true } ,
"max_tokens" : 999 ,
"seed" : 123 ,
"stop" : [ "\n" , "stop" ] ,
"temperature" : 3.0 ,
"frequency_penalty" : 4.0 ,
"presence_penalty" : 5.0 ,
"top_p" : 6.0 ,
"response_format" : { "type" : "json_object" }
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "Hello" ,
} ,
} ,
Options : map [ string ] any {
"num_predict" : 999.0 , // float because JSON doesn't distinguish between float and int
"seed" : 123.0 ,
"stop" : [ ] any { "\n" , "stop" } ,
"temperature" : 3.0 ,
"frequency_penalty" : 4.0 ,
"presence_penalty" : 5.0 ,
"top_p" : 6.0 ,
} ,
Format : json . RawMessage ( ` "json" ` ) ,
Stream : & True ,
} ,
} ,
{
name : "chat handler with image content" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{
"role" : "user" ,
"content" : [
{
"type" : "text" ,
"text" : "Hello"
} ,
{
"type" : "image_url" ,
"image_url" : {
"url" : "` + prefix + image + `"
}
}
]
}
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "Hello" ,
} ,
{
Role : "user" ,
Images : [ ] api . ImageData {
func ( ) [ ] byte {
img , _ := base64 . StdEncoding . DecodeString ( image )
return img
} ( ) ,
} ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
{
name : "chat handler with tools" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris Today?" } ,
{ "role" : "assistant" , "tool_calls" : [ { "id" : "id" , "type" : "function" , "function" : { "name" : "get_current_weather" , "arguments" : "{\"location\": \"Paris, France\", \"format\": \"celsius\"}" } } ] }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris Today?" ,
} ,
{
Role : "assistant" ,
ToolCalls : [ ] api . ToolCall {
{
Function : api . ToolCallFunction {
Name : "get_current_weather" ,
2025-10-08 06:38:58 +08:00
Arguments : makeArgs ( "location" , "Paris, France" , "format" , "celsius" ) ,
2025-10-06 05:18:56 +08:00
} ,
} ,
} ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
{
name : "chat handler with tools and content" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris Today?" } ,
{ "role" : "assistant" , "content" : "Let's see what the weather is like in Paris" , "tool_calls" : [ { "id" : "id" , "type" : "function" , "function" : { "name" : "get_current_weather" , "arguments" : "{\"location\": \"Paris, France\", \"format\": \"celsius\"}" } } ] }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris Today?" ,
} ,
{
Role : "assistant" ,
Content : "Let's see what the weather is like in Paris" ,
ToolCalls : [ ] api . ToolCall {
{
Function : api . ToolCallFunction {
Name : "get_current_weather" ,
2025-10-08 06:38:58 +08:00
Arguments : makeArgs ( "location" , "Paris, France" , "format" , "celsius" ) ,
2025-10-06 05:18:56 +08:00
} ,
} ,
} ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
{
name : "chat handler with tools and empty content" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris Today?" } ,
{ "role" : "assistant" , "content" : "" , "tool_calls" : [ { "id" : "id" , "type" : "function" , "function" : { "name" : "get_current_weather" , "arguments" : "{\"location\": \"Paris, France\", \"format\": \"celsius\"}" } } ] }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris Today?" ,
} ,
{
Role : "assistant" ,
ToolCalls : [ ] api . ToolCall {
{
Function : api . ToolCallFunction {
Name : "get_current_weather" ,
2025-10-08 06:38:58 +08:00
Arguments : makeArgs ( "location" , "Paris, France" , "format" , "celsius" ) ,
2025-10-06 05:18:56 +08:00
} ,
} ,
} ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
{
name : "chat handler with tools and thinking content" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris Today?" } ,
{ "role" : "assistant" , "reasoning" : "Let's see what the weather is like in Paris" , "tool_calls" : [ { "id" : "id" , "type" : "function" , "function" : { "name" : "get_current_weather" , "arguments" : "{\"location\": \"Paris, France\", \"format\": \"celsius\"}" } } ] }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris Today?" ,
} ,
{
Role : "assistant" ,
Thinking : "Let's see what the weather is like in Paris" ,
ToolCalls : [ ] api . ToolCall {
{
Function : api . ToolCallFunction {
Name : "get_current_weather" ,
2025-10-08 06:38:58 +08:00
Arguments : makeArgs ( "location" , "Paris, France" , "format" , "celsius" ) ,
2025-10-06 05:18:56 +08:00
} ,
} ,
} ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
{
name : "tool response with call ID" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris Today?" } ,
{ "role" : "assistant" , "tool_calls" : [ { "id" : "id_abc" , "type" : "function" , "function" : { "name" : "get_current_weather" , "arguments" : "{\"location\": \"Paris, France\", \"format\": \"celsius\"}" } } ] } ,
{ "role" : "tool" , "tool_call_id" : "id_abc" , "content" : "The weather in Paris is 20 degrees Celsius" }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris Today?" ,
} ,
{
Role : "assistant" ,
ToolCalls : [ ] api . ToolCall {
{
Function : api . ToolCallFunction {
Name : "get_current_weather" ,
2025-10-08 06:38:58 +08:00
Arguments : makeArgs ( "location" , "Paris, France" , "format" , "celsius" ) ,
2025-10-06 05:18:56 +08:00
} ,
} ,
} ,
} ,
{
Role : "tool" ,
Content : "The weather in Paris is 20 degrees Celsius" ,
ToolName : "get_current_weather" ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
{
name : "tool response with name" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris Today?" } ,
{ "role" : "assistant" , "tool_calls" : [ { "id" : "id" , "type" : "function" , "function" : { "name" : "get_current_weather" , "arguments" : "{\"location\": \"Paris, France\", \"format\": \"celsius\"}" } } ] } ,
{ "role" : "tool" , "name" : "get_current_weather" , "content" : "The weather in Paris is 20 degrees Celsius" }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris Today?" ,
} ,
{
Role : "assistant" ,
ToolCalls : [ ] api . ToolCall {
{
Function : api . ToolCallFunction {
Name : "get_current_weather" ,
2025-10-08 06:38:58 +08:00
Arguments : makeArgs ( "location" , "Paris, France" , "format" , "celsius" ) ,
2025-10-06 05:18:56 +08:00
} ,
} ,
} ,
} ,
{
Role : "tool" ,
Content : "The weather in Paris is 20 degrees Celsius" ,
ToolName : "get_current_weather" ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
{
name : "chat handler with streaming tools" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris?" }
] ,
"stream" : true ,
"tools" : [ {
"type" : "function" ,
"function" : {
"name" : "get_weather" ,
"description" : "Get the current weather" ,
"parameters" : {
"type" : "object" ,
"required" : [ "location" ] ,
"properties" : {
"location" : {
"type" : "string" ,
"description" : "The city and state"
} ,
"unit" : {
"type" : "string" ,
"enum" : [ "celsius" , "fahrenheit" ]
}
}
}
}
} ]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris?" ,
} ,
} ,
Tools : [ ] api . Tool {
{
Type : "function" ,
Function : api . ToolFunction {
Name : "get_weather" ,
Description : "Get the current weather" ,
2025-10-08 06:38:58 +08:00
Parameters : api . NewToolFunctionParametersWithProps (
"object" ,
[ ] string { "location" } ,
func ( ) * api . ToolProperties {
props := api . NewToolProperties ( )
props . Set ( "location" , api . ToolProperty {
2025-10-06 05:18:56 +08:00
Type : api . PropertyType { "string" } ,
Description : "The city and state" ,
2025-10-08 06:38:58 +08:00
} )
props . Set ( "unit" , api . ToolProperty {
2025-10-06 05:18:56 +08:00
Type : api . PropertyType { "string" } ,
Enum : [ ] any { "celsius" , "fahrenheit" } ,
2025-10-08 06:38:58 +08:00
} )
return props
} ( ) ,
) ,
2025-10-06 05:18:56 +08:00
} ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & True ,
} ,
} ,
{
name : "chat handler error forwarding" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : 2 }
]
} ` ,
err : openai . ErrorResponse {
Error : openai . Error {
Message : "invalid message content type: float64" ,
Type : "invalid_request_error" ,
} ,
} ,
} ,
}
endpoint := func ( c * gin . Context ) {
c . Status ( http . StatusOK )
}
gin . SetMode ( gin . TestMode )
router := gin . New ( )
router . Use ( ChatMiddleware ( ) , captureRequestMiddleware ( & capturedRequest ) )
router . Handle ( http . MethodPost , "/api/chat" , endpoint )
for _ , tc := range testCases {
t . Run ( tc . name , func ( t * testing . T ) {
req , _ := http . NewRequest ( http . MethodPost , "/api/chat" , strings . NewReader ( tc . body ) )
req . Header . Set ( "Content-Type" , "application/json" )
defer func ( ) { capturedRequest = nil } ( )
resp := httptest . NewRecorder ( )
router . ServeHTTP ( resp , req )
var errResp openai . ErrorResponse
if resp . Code != http . StatusOK {
if err := json . Unmarshal ( resp . Body . Bytes ( ) , & errResp ) ; err != nil {
t . Fatal ( err )
}
return
}
2025-10-08 06:38:58 +08:00
if diff := cmp . Diff ( & tc . req , capturedRequest , cmpopts . IgnoreUnexported ( api . ToolCallFunctionArguments { } , api . ToolProperties { } , api . ToolFunctionParameters { } ) ) ; diff != "" {
2025-10-06 05:18:56 +08:00
t . Fatalf ( "requests did not match: %+v" , diff )
}
if diff := cmp . Diff ( tc . err , errResp ) ; diff != "" {
t . Fatalf ( "errors did not match for %s:\n%s" , tc . name , diff )
}
} )
}
}
// TestCompletionsMiddleware checks that CompletionsMiddleware translates
// OpenAI-style completion bodies into api.GenerateRequest values and forwards
// translation failures as OpenAI-style error responses.
//
// Each case sets either req (the request the downstream handler must see) or
// err (the error body the client must receive).
func TestCompletionsMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.GenerateRequest
		err  openai.ErrorResponse
	}

	var capturedRequest *api.GenerateRequest

	testCases := []testCase{
		{
			name: "completions handler",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
					"temperature":       0.8,
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &False,
			},
		},
		{
			name: "completions handler stream",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"stream": true,
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
					"temperature":       0.8,
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &True,
			},
		},
		{
			name: "completions handler stream with usage",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"stream": true,
				"stream_options": {"include_usage": true},
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
					"temperature":       0.8,
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &True,
			},
		},
		{
			name: "completions handler error forwarding",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"temperature": null,
				"stop": [1, 2],
				"suffix": "suffix"
			}`,
			err: openai.ErrorResponse{
				Error: openai.Error{
					Message: "invalid type for 'stop' field: float64",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/generate", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, err := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
			if err != nil {
				t.Fatal(err)
			}
			req.Header.Set("Content-Type", "application/json")

			// Reset in a defer: t.Fatal exits the subtest, so a trailing reset
			// would be skipped and the next case would see a stale request.
			defer func() { capturedRequest = nil }()

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var errResp openai.ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			// The capture middleware only runs when translation succeeded; an
			// unexpected failure is caught by the error comparison below.
			if capturedRequest != nil {
				if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
					t.Fatalf("requests did not match:\n%s", diff)
				}
			}

			if diff := cmp.Diff(tc.err, errResp); diff != "" {
				t.Fatalf("errors did not match:\n%s", diff)
			}
		})
	}
}
// TestEmbeddingsMiddleware checks that EmbeddingsMiddleware translates
// OpenAI-style embedding bodies into api.EmbedRequest values and forwards
// invalid input as an OpenAI-style error response.
//
// Each case sets either req (the request the downstream handler must see) or
// err (the error body the client must receive).
func TestEmbeddingsMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.EmbedRequest
		err  openai.ErrorResponse
	}

	var capturedRequest *api.EmbedRequest

	testCases := []testCase{
		{
			name: "embed handler single input",
			body: `{
				"input": "Hello",
				"model": "test-model"
			}`,
			req: api.EmbedRequest{
				Input: "Hello",
				Model: "test-model",
			},
		},
		{
			name: "embed handler batch input",
			body: `{
				"input": ["Hello", "World"],
				"model": "test-model"
			}`,
			req: api.EmbedRequest{
				Input: []any{"Hello", "World"},
				Model: "test-model",
			},
		},
		{
			name: "embed handler error forwarding",
			body: `{
				"model": "test-model"
			}`,
			err: openai.ErrorResponse{
				Error: openai.Error{
					Message: "invalid input",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/embed", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, err := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
			if err != nil {
				t.Fatal(err)
			}
			req.Header.Set("Content-Type", "application/json")

			// Reset in a defer: t.Fatal exits the subtest, so a trailing reset
			// would be skipped and the next case would see a stale request.
			defer func() { capturedRequest = nil }()

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var errResp openai.ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			// The capture middleware only runs when translation succeeded; an
			// unexpected failure is caught by the error comparison below.
			if capturedRequest != nil {
				if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
					t.Fatalf("requests did not match:\n%s", diff)
				}
			}

			if diff := cmp.Diff(tc.err, errResp); diff != "" {
				t.Fatalf("errors did not match:\n%s", diff)
			}
		})
	}
}
// TestListMiddleware checks that ListMiddleware rewrites an api.ListResponse
// produced by the downstream handler into an OpenAI-style model list.
//
// Responses are compared after decoding both sides into map[string]any, so
// JSON whitespace and key order do not matter.
func TestListMiddleware(t *testing.T) {
	type testCase struct {
		name     string
		endpoint func(c *gin.Context)
		resp     string
	}

	testCases := []testCase{
		{
			name: "list handler",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{
					Models: []api.ListModelResponse{
						{
							Name:       "test-model",
							ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
						},
					},
				})
			},
			resp: `{
				"object": "list",
				"data": [
					{
						"id": "test-model",
						"object": "model",
						"created": 1686935002,
						"owned_by": "library"
					}
				]
			}`,
		},
		{
			name: "list handler empty output",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{})
			},
			resp: `{
				"object": "list",
				"data": null
			}`,
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tc := range testCases {
		// Run each case as a named subtest so failures identify the case and
		// one failing case does not stop the others.
		t.Run(tc.name, func(t *testing.T) {
			router := gin.New()
			router.Use(ListMiddleware())
			router.Handle(http.MethodGet, "/api/tags", tc.endpoint)

			req, err := http.NewRequest(http.MethodGet, "/api/tags", nil)
			if err != nil {
				t.Fatal(err)
			}

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var expected, actual map[string]any
			if err := json.Unmarshal([]byte(tc.resp), &expected); err != nil {
				t.Fatalf("failed to unmarshal expected response: %v", err)
			}
			if err := json.Unmarshal(resp.Body.Bytes(), &actual); err != nil {
				t.Fatalf("failed to unmarshal actual response: %v", err)
			}

			if !reflect.DeepEqual(expected, actual) {
				t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
			}
		})
	}
}
// TestRetrieveMiddleware checks that RetrieveMiddleware rewrites the
// downstream /api/show response into an OpenAI-style model object, and that
// downstream errors are reshaped into an OpenAI-style error body.
//
// Responses are compared after decoding both sides into map[string]any, so
// JSON whitespace and key order do not matter.
func TestRetrieveMiddleware(t *testing.T) {
	type testCase struct {
		name     string
		endpoint func(c *gin.Context)
		resp     string
	}

	testCases := []testCase{
		{
			name: "retrieve handler",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ShowResponse{
					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
				})
			},
			resp: `{
				"id": "test-model",
				"object": "model",
				"created": 1686935002,
				"owned_by": "library"
			}`,
		},
		{
			name: "retrieve handler error forwarding",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
			},
			resp: `{
				"error": {
					"code": null,
					"message": "model not found",
					"param": null,
					"type": "api_error"
				}
			}`,
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tc := range testCases {
		// Run each case as a named subtest so failures identify the case and
		// one failing case does not stop the others.
		t.Run(tc.name, func(t *testing.T) {
			router := gin.New()
			router.Use(RetrieveMiddleware())
			router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)

			req, err := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
			if err != nil {
				t.Fatal(err)
			}

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var expected, actual map[string]any
			if err := json.Unmarshal([]byte(tc.resp), &expected); err != nil {
				t.Fatalf("failed to unmarshal expected response: %v", err)
			}
			if err := json.Unmarshal(resp.Body.Bytes(), &actual); err != nil {
				t.Fatalf("failed to unmarshal actual response: %v", err)
			}

			if !reflect.DeepEqual(expected, actual) {
				t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
			}
		})
	}
}