Compare commits

...

2 Commits

Author SHA1 Message Date
xiaomakuaiz cfc53da267 优化 OpenAI API 兼容性实现
根据 PR #1512 代码审查意见,进行以下优化:

**安全性修复 (P0)**
- 修复 JSON 注入安全漏洞:移除不安全的字符串拼接构造 JSON 的方式
- 添加 NewStringContent 和 NewArrayContent 构造函数,直接构造对象而非通过 JSON 序列化

**性能优化**
- String() 方法使用 strings.Builder 替代字符串拼接,提升性能
- 多个 text 部分之间添加空格分隔符,避免语义错误

**测试覆盖**
- 添加完整的单元测试覆盖 MessageContent 类型
- 测试字符串格式解析(包括特殊字符、Unicode、换行符等)
- 测试数组格式解析(单个/多个 text 部分、混合类型)
- 测试无效输入处理
- 测试序列化/反序列化往返

**代码改进**
- 更新 handler/share/chat.go 使用安全的构造函数
- 所有测试通过验证

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-authored-by: MonkeyCode-AI <monkeycode-ai@chaitin.com>

Co-authored-by: MonkeyCode-AI <monkeycode-ai@chaitin.com>
2025-11-14 15:43:26 +08:00
xiaomakuaiz 3f9124c649 修复 /share/v1/chat/completions 接口 OpenAI 兼容性问题
现有的 OpenAI API 兼容接口不支持标准的 OpenAI messages 格式,特别是当 content 字段为数组格式时会解析失败。

1. **扩展 MessageContent 类型**:实现自定义的 JSON 序列化/反序列化,支持 content 既可以是字符串,也可以是包含 text/type 的对象数组
2. **添加 stream_options 支持**:支持 OpenAI 标准的 stream_options 参数(如 include_usage)
3. **更新响应格式**:在流式响应中添加 usage 字段支持,符合 OpenAI 标准

- `domain/openai.go`:
  - 新增 `MessageContent` 类型及其 JSON 序列化方法
  - 新增 `OpenAIStreamOptions` 结构体
  - 更新 `OpenAIMessage.Content` 类型从 string 改为 *MessageContent
  - 在流式响应中添加 usage 字段

- `handler/share/chat.go`:
  - 更新消息内容提取逻辑,使用 MessageContent.String() 方法
  - 修复流式和非流式响应中的 content 序列化

- 已通过单元测试验证 MessageContent 可以正确解析字符串和数组格式
- 编译通过,无语法错误

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-authored-by: monkeycode-ai <monkeycode-ai@chaitin.com>

Co-authored-by: monkeycode-ai <monkeycode-ai@chaitin.com>
2025-11-12 19:51:42 +08:00
4 changed files with 287 additions and 7 deletions

View File

@ -1,10 +1,17 @@
package domain
import (
"encoding/json"
"fmt"
"strings"
)
// OpenAI API 请求结构体
type OpenAICompletionsRequest struct {
Model string `json:"model" validate:"required"`
Messages []OpenAIMessage `json:"messages" validate:"required"`
Stream bool `json:"stream,omitempty"`
StreamOptions *OpenAIStreamOptions `json:"stream_options,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
@ -17,9 +24,89 @@ type OpenAICompletionsRequest struct {
ResponseFormat *OpenAIResponseFormat `json:"response_format,omitempty"`
}
type OpenAIStreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
// MessageContent 支持字符串或内容数组
type MessageContent struct {
isString bool
strValue string
arrValue []OpenAIContentPart
}
// OpenAIContentPart 表示内容数组中的单个元素
type OpenAIContentPart struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
}
// UnmarshalJSON 自定义解析,支持 string 或 array 格式
func (mc *MessageContent) UnmarshalJSON(data []byte) error {
// 尝试解析为字符串
var str string
if err := json.Unmarshal(data, &str); err == nil {
mc.isString = true
mc.strValue = str
return nil
}
// 尝试解析为数组
var arr []OpenAIContentPart
if err := json.Unmarshal(data, &arr); err == nil {
mc.isString = false
mc.arrValue = arr
return nil
}
return fmt.Errorf("content must be string or array")
}
// MarshalJSON 自定义序列化
func (mc MessageContent) MarshalJSON() ([]byte, error) {
if mc.isString {
return json.Marshal(mc.strValue)
}
return json.Marshal(mc.arrValue)
}
// NewStringContent 创建字符串类型的 MessageContent
func NewStringContent(s string) *MessageContent {
return &MessageContent{
isString: true,
strValue: s,
}
}
// NewArrayContent 创建数组类型的 MessageContent
func NewArrayContent(parts []OpenAIContentPart) *MessageContent {
return &MessageContent{
isString: false,
arrValue: parts,
}
}
// String 获取文本内容
func (mc *MessageContent) String() string {
if mc.isString {
return mc.strValue
}
// 从数组中提取文本
var builder strings.Builder
for i, part := range mc.arrValue {
if part.Type == "text" {
if i > 0 && part.Text != "" {
builder.WriteString(" ")
}
builder.WriteString(part.Text)
}
}
return builder.String()
}
type OpenAIMessage struct {
Role string `json:"role" validate:"required"`
Content string `json:"content,omitempty"`
Content *MessageContent `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
@ -90,12 +177,13 @@ type OpenAIStreamResponse struct {
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAIStreamChoice `json:"choices"`
Usage *OpenAIUsage `json:"usage,omitempty"`
}
type OpenAIStreamChoice struct {
Index int `json:"index"`
Delta OpenAIMessage `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
Index int `json:"index"`
Delta OpenAIMessage `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
// OpenAI 错误响应结构体

View File

@ -0,0 +1,186 @@
package domain
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMessageContent_UnmarshalJSON_String(t *testing.T) {
tests := []struct {
name string
json string
expected string
}{
{"simple string", `"hello"`, "hello"},
{"with quotes", `"say \"hello\""`, `say "hello"`},
{"with newline", `"line1\nline2"`, "line1\nline2"},
{"empty string", `""`, ""},
{"unicode", `"你好 🌍"`, "你好 🌍"},
{"special chars", `"Hello \"World\"\nNew Line\tTab"`, "Hello \"World\"\nNew Line\tTab"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var mc MessageContent
err := json.Unmarshal([]byte(tt.json), &mc)
require.NoError(t, err)
assert.Equal(t, tt.expected, mc.String())
assert.True(t, mc.isString)
})
}
}
func TestMessageContent_UnmarshalJSON_Array(t *testing.T) {
tests := []struct {
name string
json string
expected string
}{
{
"single text part",
`[{"type":"text","text":"Hello"}]`,
"Hello",
},
{
"multiple text parts",
`[{"type":"text","text":"Hello"},{"type":"text","text":"World"}]`,
"Hello World",
},
{
"mixed types with image",
`[{"type":"text","text":"Look at this"},{"type":"image_url","image_url":{"url":"https://example.com/img.png"}},{"type":"text","text":"image"}]`,
"Look at this image",
},
{
"empty array",
`[]`,
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var mc MessageContent
err := json.Unmarshal([]byte(tt.json), &mc)
require.NoError(t, err)
assert.Equal(t, tt.expected, mc.String())
assert.False(t, mc.isString)
})
}
}
func TestMessageContent_UnmarshalJSON_Invalid(t *testing.T) {
tests := []struct {
name string
json string
}{
{"number", `123`},
{"boolean", `true`},
{"object", `{"key":"value"}`},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var mc MessageContent
err := json.Unmarshal([]byte(tt.json), &mc)
assert.Error(t, err)
assert.Contains(t, err.Error(), "content must be string or array")
})
}
}
func TestMessageContent_UnmarshalJSON_Null(t *testing.T) {
var mc *MessageContent
err := json.Unmarshal([]byte(`null`), &mc)
assert.NoError(t, err)
assert.Nil(t, mc)
}
func TestMessageContent_MarshalJSON_String(t *testing.T) {
mc := NewStringContent("Hello World")
data, err := json.Marshal(mc)
require.NoError(t, err)
assert.Equal(t, `"Hello World"`, string(data))
}
func TestMessageContent_MarshalJSON_Array(t *testing.T) {
mc := NewArrayContent([]OpenAIContentPart{
{Type: "text", Text: "Hello"},
{Type: "text", Text: "World"},
})
data, err := json.Marshal(mc)
require.NoError(t, err)
assert.JSONEq(t, `[{"type":"text","text":"Hello"},{"type":"text","text":"World"}]`, string(data))
}
func TestMessageContent_Roundtrip_String(t *testing.T) {
original := NewStringContent("Test message with \"quotes\" and \nnewlines")
// Marshal
data, err := json.Marshal(original)
require.NoError(t, err)
// Unmarshal
var decoded MessageContent
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
// Verify
assert.Equal(t, original.String(), decoded.String())
assert.Equal(t, original.isString, decoded.isString)
}
func TestMessageContent_Roundtrip_Array(t *testing.T) {
parts := []OpenAIContentPart{
{Type: "text", Text: "Part 1"},
{Type: "text", Text: "Part 2"},
}
original := NewArrayContent(parts)
// Marshal
data, err := json.Marshal(original)
require.NoError(t, err)
// Unmarshal
var decoded MessageContent
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
// Verify
assert.Equal(t, original.String(), decoded.String())
assert.Equal(t, original.isString, decoded.isString)
}
func TestNewStringContent(t *testing.T) {
mc := NewStringContent("test")
assert.NotNil(t, mc)
assert.True(t, mc.isString)
assert.Equal(t, "test", mc.strValue)
assert.Equal(t, "test", mc.String())
}
func TestNewArrayContent(t *testing.T) {
parts := []OpenAIContentPart{
{Type: "text", Text: "Hello"},
}
mc := NewArrayContent(parts)
assert.NotNil(t, mc)
assert.False(t, mc.isString)
assert.Equal(t, parts, mc.arrValue)
assert.Equal(t, "Hello", mc.String())
}
func TestMessageContent_String_EmptyArray(t *testing.T) {
mc := NewArrayContent([]OpenAIContentPart{})
assert.Equal(t, "", mc.String())
}
func TestMessageContent_String_NoTextParts(t *testing.T) {
mc := NewArrayContent([]OpenAIContentPart{
{Type: "image_url", Text: ""},
})
assert.Equal(t, "", mc.String())
}

View File

@ -49,6 +49,7 @@ require (
github.com/sbzhu/weworkapi_golang v0.0.0-20210525081115-1799804a7c8d
github.com/silenceper/wechat/v2 v2.1.9
github.com/spf13/viper v1.20.1
github.com/stretchr/testify v1.10.0
github.com/swaggo/echo-swagger v1.4.1
github.com/swaggo/swag v1.16.5
github.com/tidwall/gjson v1.14.1
@ -98,6 +99,7 @@ require (
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250710065240-482d48888f25 // indirect
github.com/cloudwego/eino-ext/libs/acl/openai v0.0.0-20250626133421-3c142631c961 // indirect
github.com/cohesion-org/deepseek-go v1.2.8 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
@ -165,6 +167,7 @@ require (
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rs/xid v1.6.0 // indirect
github.com/sagikazarmark/locafero v0.9.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect

View File

@ -268,7 +268,9 @@ func (h *ShareChatHandler) ChatCompletions(c echo.Context) error {
var lastUserMessage string
for i := len(req.Messages) - 1; i >= 0; i-- {
if req.Messages[i].Role == "user" {
lastUserMessage = req.Messages[i].Content
if req.Messages[i].Content != nil {
lastUserMessage = req.Messages[i].Content.String()
}
break
}
}
@ -345,11 +347,12 @@ func (h *ShareChatHandler) handleOpenAIStreamResponse(c echo.Context, eventCh <-
Index: 0,
Delta: domain.OpenAIMessage{
Role: "assistant",
Content: event.Content,
Content: domain.NewStringContent(event.Content),
},
},
},
}
if err := h.writeOpenAIStreamEvent(c, streamResp); err != nil {
return err
}
@ -397,7 +400,7 @@ func (h *ShareChatHandler) handleOpenAINonStreamResponse(c echo.Context, eventCh
Index: 0,
Message: domain.OpenAIMessage{
Role: "assistant",
Content: content,
Content: domain.NewStringContent(content),
},
FinishReason: "stop",
},