mirror of https://github.com/chaitin/PandaWiki.git
Compare commits
3 Commits
5231b61db4
...
552fd06c85
| Author | SHA1 | Date |
|---|---|---|
|
|
552fd06c85 | |
|
|
cfc53da267 | |
|
|
3f9124c649 |
|
|
@ -1,10 +1,17 @@
|
||||||
package domain
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
// OpenAI API 请求结构体
|
// OpenAI API 请求结构体
|
||||||
type OpenAICompletionsRequest struct {
|
type OpenAICompletionsRequest struct {
|
||||||
Model string `json:"model" validate:"required"`
|
Model string `json:"model" validate:"required"`
|
||||||
Messages []OpenAIMessage `json:"messages" validate:"required"`
|
Messages []OpenAIMessage `json:"messages" validate:"required"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
StreamOptions *OpenAIStreamOptions `json:"stream_options,omitempty"`
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||||
TopP *float64 `json:"top_p,omitempty"`
|
TopP *float64 `json:"top_p,omitempty"`
|
||||||
|
|
@ -17,9 +24,89 @@ type OpenAICompletionsRequest struct {
|
||||||
ResponseFormat *OpenAIResponseFormat `json:"response_format,omitempty"`
|
ResponseFormat *OpenAIResponseFormat `json:"response_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OpenAIStreamOptions struct {
|
||||||
|
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageContent 支持字符串或内容数组
|
||||||
|
type MessageContent struct {
|
||||||
|
isString bool
|
||||||
|
strValue string
|
||||||
|
arrValue []OpenAIContentPart
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAIContentPart 表示内容数组中的单个元素
|
||||||
|
type OpenAIContentPart struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON 自定义解析,支持 string 或 array 格式
|
||||||
|
func (mc *MessageContent) UnmarshalJSON(data []byte) error {
|
||||||
|
// 尝试解析为字符串
|
||||||
|
var str string
|
||||||
|
if err := json.Unmarshal(data, &str); err == nil {
|
||||||
|
mc.isString = true
|
||||||
|
mc.strValue = str
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试解析为数组
|
||||||
|
var arr []OpenAIContentPart
|
||||||
|
if err := json.Unmarshal(data, &arr); err == nil {
|
||||||
|
mc.isString = false
|
||||||
|
mc.arrValue = arr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("content must be string or array")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON 自定义序列化
|
||||||
|
func (mc MessageContent) MarshalJSON() ([]byte, error) {
|
||||||
|
if mc.isString {
|
||||||
|
return json.Marshal(mc.strValue)
|
||||||
|
}
|
||||||
|
return json.Marshal(mc.arrValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStringContent 创建字符串类型的 MessageContent
|
||||||
|
func NewStringContent(s string) *MessageContent {
|
||||||
|
return &MessageContent{
|
||||||
|
isString: true,
|
||||||
|
strValue: s,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewArrayContent 创建数组类型的 MessageContent
|
||||||
|
func NewArrayContent(parts []OpenAIContentPart) *MessageContent {
|
||||||
|
return &MessageContent{
|
||||||
|
isString: false,
|
||||||
|
arrValue: parts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String 获取文本内容
|
||||||
|
func (mc *MessageContent) String() string {
|
||||||
|
if mc.isString {
|
||||||
|
return mc.strValue
|
||||||
|
}
|
||||||
|
// 从数组中提取文本
|
||||||
|
var builder strings.Builder
|
||||||
|
for i, part := range mc.arrValue {
|
||||||
|
if part.Type == "text" {
|
||||||
|
if i > 0 && part.Text != "" {
|
||||||
|
builder.WriteString(" ")
|
||||||
|
}
|
||||||
|
builder.WriteString(part.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
type OpenAIMessage struct {
|
type OpenAIMessage struct {
|
||||||
Role string `json:"role" validate:"required"`
|
Role string `json:"role" validate:"required"`
|
||||||
Content string `json:"content,omitempty"`
|
Content *MessageContent `json:"content,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
|
@ -90,6 +177,7 @@ type OpenAIStreamResponse struct {
|
||||||
Created int64 `json:"created"`
|
Created int64 `json:"created"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Choices []OpenAIStreamChoice `json:"choices"`
|
Choices []OpenAIStreamChoice `json:"choices"`
|
||||||
|
Usage *OpenAIUsage `json:"usage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIStreamChoice struct {
|
type OpenAIStreamChoice struct {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,186 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMessageContent_UnmarshalJSON_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
json string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"simple string", `"hello"`, "hello"},
|
||||||
|
{"with quotes", `"say \"hello\""`, `say "hello"`},
|
||||||
|
{"with newline", `"line1\nline2"`, "line1\nline2"},
|
||||||
|
{"empty string", `""`, ""},
|
||||||
|
{"unicode", `"你好 🌍"`, "你好 🌍"},
|
||||||
|
{"special chars", `"Hello \"World\"\nNew Line\tTab"`, "Hello \"World\"\nNew Line\tTab"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var mc MessageContent
|
||||||
|
err := json.Unmarshal([]byte(tt.json), &mc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, mc.String())
|
||||||
|
assert.True(t, mc.isString)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageContent_UnmarshalJSON_Array(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
json string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"single text part",
|
||||||
|
`[{"type":"text","text":"Hello"}]`,
|
||||||
|
"Hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"multiple text parts",
|
||||||
|
`[{"type":"text","text":"Hello"},{"type":"text","text":"World"}]`,
|
||||||
|
"Hello World",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"mixed types with image",
|
||||||
|
`[{"type":"text","text":"Look at this"},{"type":"image_url","image_url":{"url":"https://example.com/img.png"}},{"type":"text","text":"image"}]`,
|
||||||
|
"Look at this image",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"empty array",
|
||||||
|
`[]`,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var mc MessageContent
|
||||||
|
err := json.Unmarshal([]byte(tt.json), &mc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, mc.String())
|
||||||
|
assert.False(t, mc.isString)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageContent_UnmarshalJSON_Invalid(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
json string
|
||||||
|
}{
|
||||||
|
{"number", `123`},
|
||||||
|
{"boolean", `true`},
|
||||||
|
{"object", `{"key":"value"}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var mc MessageContent
|
||||||
|
err := json.Unmarshal([]byte(tt.json), &mc)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "content must be string or array")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageContent_UnmarshalJSON_Null(t *testing.T) {
|
||||||
|
var mc *MessageContent
|
||||||
|
err := json.Unmarshal([]byte(`null`), &mc)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, mc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageContent_MarshalJSON_String(t *testing.T) {
|
||||||
|
mc := NewStringContent("Hello World")
|
||||||
|
data, err := json.Marshal(mc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, `"Hello World"`, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageContent_MarshalJSON_Array(t *testing.T) {
|
||||||
|
mc := NewArrayContent([]OpenAIContentPart{
|
||||||
|
{Type: "text", Text: "Hello"},
|
||||||
|
{Type: "text", Text: "World"},
|
||||||
|
})
|
||||||
|
data, err := json.Marshal(mc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.JSONEq(t, `[{"type":"text","text":"Hello"},{"type":"text","text":"World"}]`, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageContent_Roundtrip_String(t *testing.T) {
|
||||||
|
original := NewStringContent("Test message with \"quotes\" and \nnewlines")
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := json.Marshal(original)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
var decoded MessageContent
|
||||||
|
err = json.Unmarshal(data, &decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
assert.Equal(t, original.String(), decoded.String())
|
||||||
|
assert.Equal(t, original.isString, decoded.isString)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageContent_Roundtrip_Array(t *testing.T) {
|
||||||
|
parts := []OpenAIContentPart{
|
||||||
|
{Type: "text", Text: "Part 1"},
|
||||||
|
{Type: "text", Text: "Part 2"},
|
||||||
|
}
|
||||||
|
original := NewArrayContent(parts)
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := json.Marshal(original)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
var decoded MessageContent
|
||||||
|
err = json.Unmarshal(data, &decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
assert.Equal(t, original.String(), decoded.String())
|
||||||
|
assert.Equal(t, original.isString, decoded.isString)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStringContent(t *testing.T) {
|
||||||
|
mc := NewStringContent("test")
|
||||||
|
assert.NotNil(t, mc)
|
||||||
|
assert.True(t, mc.isString)
|
||||||
|
assert.Equal(t, "test", mc.strValue)
|
||||||
|
assert.Equal(t, "test", mc.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewArrayContent(t *testing.T) {
|
||||||
|
parts := []OpenAIContentPart{
|
||||||
|
{Type: "text", Text: "Hello"},
|
||||||
|
}
|
||||||
|
mc := NewArrayContent(parts)
|
||||||
|
assert.NotNil(t, mc)
|
||||||
|
assert.False(t, mc.isString)
|
||||||
|
assert.Equal(t, parts, mc.arrValue)
|
||||||
|
assert.Equal(t, "Hello", mc.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageContent_String_EmptyArray(t *testing.T) {
|
||||||
|
mc := NewArrayContent([]OpenAIContentPart{})
|
||||||
|
assert.Equal(t, "", mc.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageContent_String_NoTextParts(t *testing.T) {
|
||||||
|
mc := NewArrayContent([]OpenAIContentPart{
|
||||||
|
{Type: "image_url", Text: ""},
|
||||||
|
})
|
||||||
|
assert.Equal(t, "", mc.String())
|
||||||
|
}
|
||||||
|
|
@ -49,6 +49,7 @@ require (
|
||||||
github.com/sbzhu/weworkapi_golang v0.0.0-20210525081115-1799804a7c8d
|
github.com/sbzhu/weworkapi_golang v0.0.0-20210525081115-1799804a7c8d
|
||||||
github.com/silenceper/wechat/v2 v2.1.9
|
github.com/silenceper/wechat/v2 v2.1.9
|
||||||
github.com/spf13/viper v1.20.1
|
github.com/spf13/viper v1.20.1
|
||||||
|
github.com/stretchr/testify v1.10.0
|
||||||
github.com/swaggo/echo-swagger v1.4.1
|
github.com/swaggo/echo-swagger v1.4.1
|
||||||
github.com/swaggo/swag v1.16.5
|
github.com/swaggo/swag v1.16.5
|
||||||
github.com/tidwall/gjson v1.14.1
|
github.com/tidwall/gjson v1.14.1
|
||||||
|
|
@ -98,6 +99,7 @@ require (
|
||||||
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250710065240-482d48888f25 // indirect
|
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250710065240-482d48888f25 // indirect
|
||||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.0.0-20250626133421-3c142631c961 // indirect
|
github.com/cloudwego/eino-ext/libs/acl/openai v0.0.0-20250626133421-3c142631c961 // indirect
|
||||||
github.com/cohesion-org/deepseek-go v1.2.8 // indirect
|
github.com/cohesion-org/deepseek-go v1.2.8 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/dlclark/regexp2 v1.11.4 // indirect
|
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
|
@ -165,6 +167,7 @@ require (
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||||
github.com/perimeterx/marshmallow v1.1.5 // indirect
|
github.com/perimeterx/marshmallow v1.1.5 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/rs/xid v1.6.0 // indirect
|
github.com/rs/xid v1.6.0 // indirect
|
||||||
github.com/sagikazarmark/locafero v0.9.0 // indirect
|
github.com/sagikazarmark/locafero v0.9.0 // indirect
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
|
|
|
||||||
|
|
@ -268,7 +268,9 @@ func (h *ShareChatHandler) ChatCompletions(c echo.Context) error {
|
||||||
var lastUserMessage string
|
var lastUserMessage string
|
||||||
for i := len(req.Messages) - 1; i >= 0; i-- {
|
for i := len(req.Messages) - 1; i >= 0; i-- {
|
||||||
if req.Messages[i].Role == "user" {
|
if req.Messages[i].Role == "user" {
|
||||||
lastUserMessage = req.Messages[i].Content
|
if req.Messages[i].Content != nil {
|
||||||
|
lastUserMessage = req.Messages[i].Content.String()
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -345,11 +347,12 @@ func (h *ShareChatHandler) handleOpenAIStreamResponse(c echo.Context, eventCh <-
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Delta: domain.OpenAIMessage{
|
Delta: domain.OpenAIMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: event.Content,
|
Content: domain.NewStringContent(event.Content),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.writeOpenAIStreamEvent(c, streamResp); err != nil {
|
if err := h.writeOpenAIStreamEvent(c, streamResp); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -397,7 +400,7 @@ func (h *ShareChatHandler) handleOpenAINonStreamResponse(c echo.Context, eventCh
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Message: domain.OpenAIMessage{
|
Message: domain.OpenAIMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: content,
|
Content: domain.NewStringContent(content),
|
||||||
},
|
},
|
||||||
FinishReason: "stop",
|
FinishReason: "stop",
|
||||||
},
|
},
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue