2024-07-03 02:50:56 +08:00
package openai
import (
"bytes"
2024-07-14 13:07:45 +08:00
"encoding/base64"
2024-07-03 02:50:56 +08:00
"encoding/json"
"io"
"net/http"
"net/http/httptest"
2024-08-13 01:33:34 +08:00
"reflect"
2024-07-03 07:01:45 +08:00
"strings"
2024-07-03 02:50:56 +08:00
"testing"
"time"
"github.com/gin-gonic/gin"
2024-12-05 08:31:19 +08:00
"github.com/google/go-cmp/cmp"
2024-08-02 05:52:15 +08:00
"github.com/ollama/ollama/api"
2024-07-03 02:50:56 +08:00
)
2024-08-02 05:52:15 +08:00
// Fixtures for the image-content chat test: the request body embeds
// prefix+image as a data URL, and the expected api.Message carries the
// base64-decoded bytes of image.
//
// NOTE(review): prefix advertises image/jpeg, but the payload starts with the
// PNG signature ("iVBORw0KGgo"). The chat test only compares decoded bytes,
// so this looks harmless — confirm if the middleware starts checking MIME types.
const (
	image  = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
	prefix = `data:image/jpeg;base64,`
)
2024-07-14 13:07:45 +08:00
2024-09-06 16:16:28 +08:00
// False and True are addressable booleans so test cases can populate *bool
// fields such as Stream with &False or &True.
var False, True = false, true
2024-07-20 02:37:12 +08:00
// captureRequestMiddleware returns a gin middleware that JSON-decodes the
// incoming request body into capturedRequest, which must be a pointer (the
// tests pass e.g. the address of a *api.ChatRequest). This lets a test inspect
// the request exactly as the middleware under test rewrote it.
//
// The body is restored after reading so later handlers can read it again.
// If the body cannot be read or decoded, the request is aborted with HTTP 500
// and the remaining handlers are not run.
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
	return func(c *gin.Context) {
		bodyBytes, err := io.ReadAll(c.Request.Body)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to read request body")
			return
		}
		// ReadAll consumed the body; put it back for downstream handlers.
		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

		if err := json.Unmarshal(bodyBytes, capturedRequest); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
			return
		}
		c.Next()
	}
}
func TestChatMiddleware ( t * testing . T ) {
2024-07-03 02:50:56 +08:00
type testCase struct {
2024-08-13 01:33:34 +08:00
name string
body string
req api . ChatRequest
err ErrorResponse
2024-07-03 02:50:56 +08:00
}
2024-07-20 02:37:12 +08:00
var capturedRequest * api . ChatRequest
2024-07-03 02:50:56 +08:00
2024-07-10 04:48:31 +08:00
testCases := [ ] testCase {
{
2024-08-13 01:33:34 +08:00
name : "chat handler" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "Hello" }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "Hello" ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
2024-07-03 02:50:56 +08:00
} ,
2024-07-20 02:37:12 +08:00
} ,
2024-09-06 16:16:28 +08:00
{
name : "chat handler with options" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "Hello" }
] ,
"stream" : true ,
"max_tokens" : 999 ,
"seed" : 123 ,
"stop" : [ "\n" , "stop" ] ,
"temperature" : 3.0 ,
"frequency_penalty" : 4.0 ,
"presence_penalty" : 5.0 ,
"top_p" : 6.0 ,
"response_format" : { "type" : "json_object" }
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "Hello" ,
} ,
} ,
Options : map [ string ] any {
"num_predict" : 999.0 , // float because JSON doesn't distinguish between float and int
"seed" : 123.0 ,
"stop" : [ ] any { "\n" , "stop" } ,
2024-09-08 00:08:08 +08:00
"temperature" : 3.0 ,
"frequency_penalty" : 4.0 ,
"presence_penalty" : 5.0 ,
2024-09-06 16:16:28 +08:00
"top_p" : 6.0 ,
} ,
2024-12-05 08:31:19 +08:00
Format : json . RawMessage ( ` "json" ` ) ,
2024-09-06 16:16:28 +08:00
Stream : & True ,
} ,
} ,
2024-12-13 09:09:30 +08:00
{
name : "chat handler with streaming usage" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "Hello" }
] ,
"stream" : true ,
"stream_options" : { "include_usage" : true } ,
"max_tokens" : 999 ,
"seed" : 123 ,
"stop" : [ "\n" , "stop" ] ,
"temperature" : 3.0 ,
"frequency_penalty" : 4.0 ,
"presence_penalty" : 5.0 ,
"top_p" : 6.0 ,
"response_format" : { "type" : "json_object" }
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "Hello" ,
} ,
} ,
Options : map [ string ] any {
"num_predict" : 999.0 , // float because JSON doesn't distinguish between float and int
"seed" : 123.0 ,
"stop" : [ ] any { "\n" , "stop" } ,
"temperature" : 3.0 ,
"frequency_penalty" : 4.0 ,
"presence_penalty" : 5.0 ,
"top_p" : 6.0 ,
} ,
Format : json . RawMessage ( ` "json" ` ) ,
Stream : & True ,
} ,
} ,
2024-07-20 02:37:12 +08:00
{
2024-08-13 01:33:34 +08:00
name : "chat handler with image content" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{
"role" : "user" ,
"content" : [
{
"type" : "text" ,
"text" : "Hello"
2024-07-20 02:37:12 +08:00
} ,
2024-08-13 01:33:34 +08:00
{
"type" : "image_url" ,
"image_url" : {
"url" : "` + prefix + image + `"
}
}
]
}
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "Hello" ,
} ,
{
Role : "user" ,
Images : [ ] api . ImageData {
func ( ) [ ] byte {
img , _ := base64 . StdEncoding . DecodeString ( image )
return img
} ( ) ,
2024-07-20 02:37:12 +08:00
} ,
} ,
2024-08-13 01:33:34 +08:00
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
2024-07-20 02:37:12 +08:00
} ,
} ,
{
2024-08-13 01:33:34 +08:00
name : "chat handler with tools" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris Today?" } ,
{ "role" : "assistant" , "tool_calls" : [ { "id" : "id" , "type" : "function" , "function" : { "name" : "get_current_weather" , "arguments" : "{\"location\": \"Paris, France\", \"format\": \"celsius\"}" } } ] }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris Today?" ,
} ,
{
Role : "assistant" ,
ToolCalls : [ ] api . ToolCall {
{
Function : api . ToolCallFunction {
Name : "get_current_weather" ,
2025-04-03 00:44:27 +08:00
Arguments : map [ string ] any {
2024-08-13 01:33:34 +08:00
"location" : "Paris, France" ,
"format" : "celsius" ,
} ,
} ,
2024-07-20 02:37:12 +08:00
} ,
2024-08-13 01:33:34 +08:00
} ,
2024-07-20 02:37:12 +08:00
} ,
2024-08-13 01:33:34 +08:00
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
2024-07-03 02:50:56 +08:00
} ,
} ,
2025-08-07 06:50:30 +08:00
{
name : "chat handler with tools and content" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris Today?" } ,
{ "role" : "assistant" , "content" : "Let's see what the weather is like in Paris" , "tool_calls" : [ { "id" : "id" , "type" : "function" , "function" : { "name" : "get_current_weather" , "arguments" : "{\"location\": \"Paris, France\", \"format\": \"celsius\"}" } } ] }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris Today?" ,
} ,
{
Role : "assistant" ,
Content : "Let's see what the weather is like in Paris" ,
ToolCalls : [ ] api . ToolCall {
{
Function : api . ToolCallFunction {
Name : "get_current_weather" ,
Arguments : map [ string ] any {
"location" : "Paris, France" ,
"format" : "celsius" ,
} ,
} ,
} ,
} ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
2025-08-07 08:00:24 +08:00
{
name : "tool response with call ID" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris Today?" } ,
{ "role" : "assistant" , "tool_calls" : [ { "id" : "id_abc" , "type" : "function" , "function" : { "name" : "get_current_weather" , "arguments" : "{\"location\": \"Paris, France\", \"format\": \"celsius\"}" } } ] } ,
{ "role" : "tool" , "tool_call_id" : "id_abc" , "content" : "The weather in Paris is 20 degrees Celsius" }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris Today?" ,
} ,
{
Role : "assistant" ,
ToolCalls : [ ] api . ToolCall {
{
Function : api . ToolCallFunction {
Name : "get_current_weather" ,
Arguments : map [ string ] any {
"location" : "Paris, France" ,
"format" : "celsius" ,
} ,
} ,
} ,
} ,
} ,
{
Role : "tool" ,
Content : "The weather in Paris is 20 degrees Celsius" ,
ToolName : "get_current_weather" ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
{
name : "tool response with name" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris Today?" } ,
{ "role" : "assistant" , "tool_calls" : [ { "id" : "id" , "type" : "function" , "function" : { "name" : "get_current_weather" , "arguments" : "{\"location\": \"Paris, France\", \"format\": \"celsius\"}" } } ] } ,
{ "role" : "tool" , "name" : "get_current_weather" , "content" : "The weather in Paris is 20 degrees Celsius" }
]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris Today?" ,
} ,
{
Role : "assistant" ,
ToolCalls : [ ] api . ToolCall {
{
Function : api . ToolCallFunction {
Name : "get_current_weather" ,
Arguments : map [ string ] any {
"location" : "Paris, France" ,
"format" : "celsius" ,
} ,
} ,
} ,
} ,
} ,
{
Role : "tool" ,
Content : "The weather in Paris is 20 degrees Celsius" ,
ToolName : "get_current_weather" ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & False ,
} ,
} ,
2024-11-30 12:00:09 +08:00
{
name : "chat handler with streaming tools" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : "What's the weather like in Paris?" }
] ,
"stream" : true ,
"tools" : [ {
"type" : "function" ,
"function" : {
"name" : "get_weather" ,
"description" : "Get the current weather" ,
"parameters" : {
"type" : "object" ,
"required" : [ "location" ] ,
"properties" : {
"location" : {
"type" : "string" ,
"description" : "The city and state"
} ,
"unit" : {
"type" : "string" ,
"enum" : [ "celsius" , "fahrenheit" ]
}
}
}
}
} ]
} ` ,
req : api . ChatRequest {
Model : "test-model" ,
Messages : [ ] api . Message {
{
Role : "user" ,
Content : "What's the weather like in Paris?" ,
} ,
} ,
Tools : [ ] api . Tool {
{
Type : "function" ,
Function : api . ToolFunction {
Name : "get_weather" ,
Description : "Get the current weather" ,
Parameters : struct {
2025-08-06 07:46:24 +08:00
Type string ` json:"type" `
Defs any ` json:"$defs,omitempty" `
Items any ` json:"items,omitempty" `
Required [ ] string ` json:"required" `
Properties map [ string ] api . ToolProperty ` json:"properties" `
2024-11-30 12:00:09 +08:00
} {
Type : "object" ,
Required : [ ] string { "location" } ,
2025-08-06 07:46:24 +08:00
Properties : map [ string ] api . ToolProperty {
2024-11-30 12:00:09 +08:00
"location" : {
2025-04-08 05:27:01 +08:00
Type : api . PropertyType { "string" } ,
2024-11-30 12:00:09 +08:00
Description : "The city and state" ,
} ,
"unit" : {
2025-04-08 05:27:01 +08:00
Type : api . PropertyType { "string" } ,
2025-04-09 06:05:38 +08:00
Enum : [ ] any { "celsius" , "fahrenheit" } ,
2024-11-30 12:00:09 +08:00
} ,
} ,
} ,
} ,
} ,
} ,
Options : map [ string ] any {
"temperature" : 1.0 ,
"top_p" : 1.0 ,
} ,
Stream : & True ,
} ,
} ,
2024-08-13 01:33:34 +08:00
{
name : "chat handler error forwarding" ,
body : ` {
"model" : "test-model" ,
"messages" : [
{ "role" : "user" , "content" : 2 }
]
} ` ,
err : ErrorResponse {
Error : Error {
Message : "invalid message content type: float64" ,
Type : "invalid_request_error" ,
} ,
2024-07-20 02:37:12 +08:00
} ,
} ,
}
endpoint := func ( c * gin . Context ) {
c . Status ( http . StatusOK )
}
gin . SetMode ( gin . TestMode )
router := gin . New ( )
router . Use ( ChatMiddleware ( ) , captureRequestMiddleware ( & capturedRequest ) )
router . Handle ( http . MethodPost , "/api/chat" , endpoint )
for _ , tc := range testCases {
2024-08-13 01:33:34 +08:00
t . Run ( tc . name , func ( t * testing . T ) {
req , _ := http . NewRequest ( http . MethodPost , "/api/chat" , strings . NewReader ( tc . body ) )
req . Header . Set ( "Content-Type" , "application/json" )
2024-07-20 02:37:12 +08:00
2024-09-06 16:16:28 +08:00
defer func ( ) { capturedRequest = nil } ( )
2024-07-20 02:37:12 +08:00
resp := httptest . NewRecorder ( )
router . ServeHTTP ( resp , req )
2024-08-13 01:33:34 +08:00
var errResp ErrorResponse
if resp . Code != http . StatusOK {
if err := json . Unmarshal ( resp . Body . Bytes ( ) , & errResp ) ; err != nil {
t . Fatal ( err )
}
2024-12-05 08:31:19 +08:00
return
2024-08-13 01:33:34 +08:00
}
2024-12-05 08:31:19 +08:00
if diff := cmp . Diff ( & tc . req , capturedRequest ) ; diff != "" {
t . Fatalf ( "requests did not match: %+v" , diff )
2024-08-13 01:33:34 +08:00
}
2024-12-05 08:31:19 +08:00
if diff := cmp . Diff ( tc . err , errResp ) ; diff != "" {
t . Fatalf ( "errors did not match for %s:\n%s" , tc . name , diff )
2024-08-13 01:33:34 +08:00
}
2024-07-20 02:37:12 +08:00
} )
}
}
func TestCompletionsMiddleware ( t * testing . T ) {
type testCase struct {
2024-08-13 01:33:34 +08:00
name string
body string
req api . GenerateRequest
err ErrorResponse
2024-07-20 02:37:12 +08:00
}
var capturedRequest * api . GenerateRequest
testCases := [ ] testCase {
{
2024-08-13 01:33:34 +08:00
name : "completions handler" ,
body : ` {
"model" : "test-model" ,
"prompt" : "Hello" ,
"temperature" : 0.8 ,
"stop" : [ "\n" , "stop" ] ,
"suffix" : "suffix"
} ` ,
req : api . GenerateRequest {
Model : "test-model" ,
Prompt : "Hello" ,
Options : map [ string ] any {
"frequency_penalty" : 0.0 ,
"presence_penalty" : 0.0 ,
2024-09-07 08:45:45 +08:00
"temperature" : 0.8 ,
2024-08-13 01:33:34 +08:00
"top_p" : 1.0 ,
"stop" : [ ] any { "\n" , "stop" } ,
} ,
Suffix : "suffix" ,
Stream : & False ,
2024-07-03 07:01:45 +08:00
} ,
} ,
2024-12-13 09:09:30 +08:00
{
name : "completions handler stream" ,
body : ` {
"model" : "test-model" ,
"prompt" : "Hello" ,
"stream" : true ,
"temperature" : 0.8 ,
"stop" : [ "\n" , "stop" ] ,
"suffix" : "suffix"
} ` ,
req : api . GenerateRequest {
Model : "test-model" ,
Prompt : "Hello" ,
Options : map [ string ] any {
"frequency_penalty" : 0.0 ,
"presence_penalty" : 0.0 ,
"temperature" : 0.8 ,
"top_p" : 1.0 ,
"stop" : [ ] any { "\n" , "stop" } ,
} ,
Suffix : "suffix" ,
Stream : & True ,
} ,
} ,
{
name : "completions handler stream with usage" ,
body : ` {
"model" : "test-model" ,
"prompt" : "Hello" ,
"stream" : true ,
"stream_options" : { "include_usage" : true } ,
"temperature" : 0.8 ,
"stop" : [ "\n" , "stop" ] ,
"suffix" : "suffix"
} ` ,
req : api . GenerateRequest {
Model : "test-model" ,
Prompt : "Hello" ,
Options : map [ string ] any {
"frequency_penalty" : 0.0 ,
"presence_penalty" : 0.0 ,
"temperature" : 0.8 ,
"top_p" : 1.0 ,
"stop" : [ ] any { "\n" , "stop" } ,
} ,
Suffix : "suffix" ,
Stream : & True ,
} ,
} ,
2024-07-14 13:07:45 +08:00
{
2024-08-13 01:33:34 +08:00
name : "completions handler error forwarding" ,
body : ` {
"model" : "test-model" ,
"prompt" : "Hello" ,
"temperature" : null ,
"stop" : [ 1 , 2 ] ,
"suffix" : "suffix"
} ` ,
err : ErrorResponse {
Error : Error {
Message : "invalid type for 'stop' field: float64" ,
Type : "invalid_request_error" ,
} ,
2024-07-20 02:37:12 +08:00
} ,
} ,
}
2024-07-14 13:07:45 +08:00
2024-07-20 02:37:12 +08:00
endpoint := func ( c * gin . Context ) {
c . Status ( http . StatusOK )
}
2024-07-14 13:07:45 +08:00
2024-07-20 02:37:12 +08:00
gin . SetMode ( gin . TestMode )
router := gin . New ( )
router . Use ( CompletionsMiddleware ( ) , captureRequestMiddleware ( & capturedRequest ) )
router . Handle ( http . MethodPost , "/api/generate" , endpoint )
2024-07-14 13:07:45 +08:00
2024-07-20 02:37:12 +08:00
for _ , tc := range testCases {
2024-08-13 01:33:34 +08:00
t . Run ( tc . name , func ( t * testing . T ) {
req , _ := http . NewRequest ( http . MethodPost , "/api/generate" , strings . NewReader ( tc . body ) )
req . Header . Set ( "Content-Type" , "application/json" )
2024-07-20 02:37:12 +08:00
resp := httptest . NewRecorder ( )
router . ServeHTTP ( resp , req )
2024-08-13 01:33:34 +08:00
var errResp ErrorResponse
if resp . Code != http . StatusOK {
if err := json . Unmarshal ( resp . Body . Bytes ( ) , & errResp ) ; err != nil {
t . Fatal ( err )
}
}
if capturedRequest != nil && ! reflect . DeepEqual ( tc . req , * capturedRequest ) {
t . Fatal ( "requests did not match" )
}
if ! reflect . DeepEqual ( tc . err , errResp ) {
t . Fatal ( "errors did not match" )
}
2024-07-20 02:37:12 +08:00
capturedRequest = nil
} )
}
}
func TestEmbeddingsMiddleware ( t * testing . T ) {
type testCase struct {
2024-08-13 01:33:34 +08:00
name string
body string
req api . EmbedRequest
err ErrorResponse
2024-07-20 02:37:12 +08:00
}
var capturedRequest * api . EmbedRequest
testCases := [ ] testCase {
2024-07-17 04:36:08 +08:00
{
2024-08-13 01:33:34 +08:00
name : "embed handler single input" ,
body : ` {
"input" : "Hello" ,
"model" : "test-model"
} ` ,
req : api . EmbedRequest {
Input : "Hello" ,
Model : "test-model" ,
2024-07-17 04:36:08 +08:00
} ,
} ,
{
2024-08-13 01:33:34 +08:00
name : "embed handler batch input" ,
body : ` {
"input" : [ "Hello" , "World" ] ,
"model" : "test-model"
} ` ,
req : api . EmbedRequest {
Input : [ ] any { "Hello" , "World" } ,
Model : "test-model" ,
2024-07-17 04:36:08 +08:00
} ,
} ,
2024-07-20 02:37:12 +08:00
{
2024-08-13 01:33:34 +08:00
name : "embed handler error forwarding" ,
body : ` {
"model" : "test-model"
} ` ,
err : ErrorResponse {
Error : Error {
Message : "invalid input" ,
Type : "invalid_request_error" ,
} ,
2024-07-20 02:37:12 +08:00
} ,
} ,
}
2024-07-03 07:01:45 +08:00
2024-07-10 04:48:31 +08:00
endpoint := func ( c * gin . Context ) {
c . Status ( http . StatusOK )
}
2024-07-03 07:01:45 +08:00
2024-07-20 02:37:12 +08:00
gin . SetMode ( gin . TestMode )
router := gin . New ( )
router . Use ( EmbeddingsMiddleware ( ) , captureRequestMiddleware ( & capturedRequest ) )
router . Handle ( http . MethodPost , "/api/embed" , endpoint )
2024-07-10 04:48:31 +08:00
for _ , tc := range testCases {
2024-08-13 01:33:34 +08:00
t . Run ( tc . name , func ( t * testing . T ) {
req , _ := http . NewRequest ( http . MethodPost , "/api/embed" , strings . NewReader ( tc . body ) )
req . Header . Set ( "Content-Type" , "application/json" )
2024-07-03 07:01:45 +08:00
2024-07-10 04:48:31 +08:00
resp := httptest . NewRecorder ( )
router . ServeHTTP ( resp , req )
2024-07-03 07:01:45 +08:00
2024-08-13 01:33:34 +08:00
var errResp ErrorResponse
if resp . Code != http . StatusOK {
if err := json . Unmarshal ( resp . Body . Bytes ( ) , & errResp ) ; err != nil {
t . Fatal ( err )
}
}
if capturedRequest != nil && ! reflect . DeepEqual ( tc . req , * capturedRequest ) {
t . Fatal ( "requests did not match" )
}
if ! reflect . DeepEqual ( tc . err , errResp ) {
t . Fatal ( "errors did not match" )
}
2024-07-20 02:37:12 +08:00
capturedRequest = nil
2024-07-10 04:48:31 +08:00
} )
}
}
2024-07-03 07:01:45 +08:00
2024-08-13 01:33:34 +08:00
func TestListMiddleware ( t * testing . T ) {
2024-07-10 04:48:31 +08:00
type testCase struct {
2024-08-13 01:33:34 +08:00
name string
endpoint func ( c * gin . Context )
resp string
2024-07-10 04:48:31 +08:00
}
testCases := [ ] testCase {
2024-07-03 02:50:56 +08:00
{
2024-08-13 01:33:34 +08:00
name : "list handler" ,
endpoint : func ( c * gin . Context ) {
2024-07-03 02:50:56 +08:00
c . JSON ( http . StatusOK , api . ListResponse {
Models : [ ] api . ListModelResponse {
{
2024-08-13 01:33:34 +08:00
Name : "test-model" ,
ModifiedAt : time . Unix ( int64 ( 1686935002 ) , 0 ) . UTC ( ) ,
2024-07-03 02:50:56 +08:00
} ,
} ,
} )
} ,
2024-08-13 01:33:34 +08:00
resp : ` {
"object" : "list" ,
"data" : [
{
"id" : "test-model" ,
"object" : "model" ,
"created" : 1686935002 ,
"owned_by" : "library"
}
]
} ` ,
} ,
{
name : "list handler empty output" ,
endpoint : func ( c * gin . Context ) {
c . JSON ( http . StatusOK , api . ListResponse { } )
} ,
resp : ` {
"object" : "list" ,
"data" : null
} ` ,
} ,
}
2024-07-03 02:50:56 +08:00
2024-08-13 01:33:34 +08:00
gin . SetMode ( gin . TestMode )
2024-07-03 02:50:56 +08:00
2024-08-13 01:33:34 +08:00
for _ , tc := range testCases {
router := gin . New ( )
router . Use ( ListMiddleware ( ) )
router . Handle ( http . MethodGet , "/api/tags" , tc . endpoint )
req , _ := http . NewRequest ( http . MethodGet , "/api/tags" , nil )
2024-07-03 02:50:56 +08:00
2024-08-13 01:33:34 +08:00
resp := httptest . NewRecorder ( )
router . ServeHTTP ( resp , req )
var expected , actual map [ string ] any
err := json . Unmarshal ( [ ] byte ( tc . resp ) , & expected )
if err != nil {
t . Fatalf ( "failed to unmarshal expected response: %v" , err )
}
err = json . Unmarshal ( resp . Body . Bytes ( ) , & actual )
if err != nil {
t . Fatalf ( "failed to unmarshal actual response: %v" , err )
}
if ! reflect . DeepEqual ( expected , actual ) {
t . Errorf ( "responses did not match\nExpected: %+v\nActual: %+v" , expected , actual )
}
}
}
func TestRetrieveMiddleware ( t * testing . T ) {
type testCase struct {
name string
endpoint func ( c * gin . Context )
resp string
}
testCases := [ ] testCase {
2024-07-03 02:50:56 +08:00
{
2024-08-13 01:33:34 +08:00
name : "retrieve handler" ,
endpoint : func ( c * gin . Context ) {
2024-07-03 02:50:56 +08:00
c . JSON ( http . StatusOK , api . ShowResponse {
2024-08-13 01:33:34 +08:00
ModifiedAt : time . Unix ( int64 ( 1686935002 ) , 0 ) . UTC ( ) ,
2024-07-03 02:50:56 +08:00
} )
} ,
2024-08-13 01:33:34 +08:00
resp : ` {
"id" : "test-model" ,
"object" : "model" ,
"created" : 1686935002 ,
"owned_by" : "library" }
` ,
} ,
{
name : "retrieve handler error forwarding" ,
endpoint : func ( c * gin . Context ) {
c . JSON ( http . StatusBadRequest , gin . H { "error" : "model not found" } )
2024-07-03 02:50:56 +08:00
} ,
2024-08-13 01:33:34 +08:00
resp : ` {
"error" : {
"code" : null ,
"message" : "model not found" ,
"param" : null ,
"type" : "api_error"
}
} ` ,
2024-07-03 02:50:56 +08:00
} ,
}
gin . SetMode ( gin . TestMode )
for _ , tc := range testCases {
2024-08-13 01:33:34 +08:00
router := gin . New ( )
router . Use ( RetrieveMiddleware ( ) )
router . Handle ( http . MethodGet , "/api/show/:model" , tc . endpoint )
req , _ := http . NewRequest ( http . MethodGet , "/api/show/test-model" , nil )
2024-07-03 02:50:56 +08:00
2024-08-13 01:33:34 +08:00
resp := httptest . NewRecorder ( )
router . ServeHTTP ( resp , req )
2024-07-03 02:50:56 +08:00
2024-08-13 01:33:34 +08:00
var expected , actual map [ string ] any
err := json . Unmarshal ( [ ] byte ( tc . resp ) , & expected )
if err != nil {
t . Fatalf ( "failed to unmarshal expected response: %v" , err )
}
2024-07-20 02:37:12 +08:00
2024-08-13 01:33:34 +08:00
err = json . Unmarshal ( resp . Body . Bytes ( ) , & actual )
if err != nil {
t . Fatalf ( "failed to unmarshal actual response: %v" , err )
}
if ! reflect . DeepEqual ( expected , actual ) {
t . Errorf ( "responses did not match\nExpected: %+v\nActual: %+v" , expected , actual )
}
2024-07-03 02:50:56 +08:00
}
}