mirror of https://github.com/chaitin/PandaWiki.git
Compare commits
2 Commits
552fd06c85
...
5231b61db4
| Author | SHA1 | Date |
|---|---|---|
|
|
5231b61db4 | |
|
|
3f9124c649 |
|
|
@ -1,10 +1,16 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// OpenAI API 请求结构体
|
||||
type OpenAICompletionsRequest struct {
|
||||
Model string `json:"model" validate:"required"`
|
||||
Messages []OpenAIMessage `json:"messages" validate:"required"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *OpenAIStreamOptions `json:"stream_options,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
|
|
@ -17,9 +23,70 @@ type OpenAICompletionsRequest struct {
|
|||
ResponseFormat *OpenAIResponseFormat `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIStreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
// MessageContent 支持字符串或内容数组
|
||||
type MessageContent struct {
|
||||
isString bool
|
||||
strValue string
|
||||
arrValue []OpenAIContentPart
|
||||
}
|
||||
|
||||
// OpenAIContentPart 表示内容数组中的单个元素
|
||||
type OpenAIContentPart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON 自定义解析,支持 string 或 array 格式
|
||||
func (mc *MessageContent) UnmarshalJSON(data []byte) error {
|
||||
// 尝试解析为字符串
|
||||
var str string
|
||||
if err := json.Unmarshal(data, &str); err == nil {
|
||||
mc.isString = true
|
||||
mc.strValue = str
|
||||
return nil
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []OpenAIContentPart
|
||||
if err := json.Unmarshal(data, &arr); err == nil {
|
||||
mc.isString = false
|
||||
mc.arrValue = arr
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("content must be string or array")
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义序列化
|
||||
func (mc MessageContent) MarshalJSON() ([]byte, error) {
|
||||
if mc.isString {
|
||||
return json.Marshal(mc.strValue)
|
||||
}
|
||||
return json.Marshal(mc.arrValue)
|
||||
}
|
||||
|
||||
// String 获取文本内容
|
||||
func (mc *MessageContent) String() string {
|
||||
if mc.isString {
|
||||
return mc.strValue
|
||||
}
|
||||
// 从数组中提取文本
|
||||
var result string
|
||||
for _, part := range mc.arrValue {
|
||||
if part.Type == "text" {
|
||||
result += part.Text
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type OpenAIMessage struct {
|
||||
Role string `json:"role" validate:"required"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Content *MessageContent `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
|
|
@ -90,12 +157,13 @@ type OpenAIStreamResponse struct {
|
|||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []OpenAIStreamChoice `json:"choices"`
|
||||
Usage *OpenAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIStreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta OpenAIMessage `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
Index int `json:"index"`
|
||||
Delta OpenAIMessage `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAI 错误响应结构体
|
||||
|
|
|
|||
|
|
@ -268,7 +268,9 @@ func (h *ShareChatHandler) ChatCompletions(c echo.Context) error {
|
|||
var lastUserMessage string
|
||||
for i := len(req.Messages) - 1; i >= 0; i-- {
|
||||
if req.Messages[i].Role == "user" {
|
||||
lastUserMessage = req.Messages[i].Content
|
||||
if req.Messages[i].Content != nil {
|
||||
lastUserMessage = req.Messages[i].Content.String()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -345,11 +347,18 @@ func (h *ShareChatHandler) handleOpenAIStreamResponse(c echo.Context, eventCh <-
|
|||
Index: 0,
|
||||
Delta: domain.OpenAIMessage{
|
||||
Role: "assistant",
|
||||
Content: event.Content,
|
||||
Content: &domain.MessageContent{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// 使用临时变量设置 Content
|
||||
content := &domain.MessageContent{}
|
||||
*content = domain.MessageContent{}
|
||||
// 手动设置为字符串类型
|
||||
json.Unmarshal([]byte(`"`+event.Content+`"`), content)
|
||||
streamResp.Choices[0].Delta.Content = content
|
||||
|
||||
if err := h.writeOpenAIStreamEvent(c, streamResp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -387,6 +396,10 @@ func (h *ShareChatHandler) handleOpenAINonStreamResponse(c echo.Context, eventCh
|
|||
content += event.Content
|
||||
case "done":
|
||||
// send complete response
|
||||
// 构造 MessageContent
|
||||
messageContent := &domain.MessageContent{}
|
||||
json.Unmarshal([]byte(`"`+content+`"`), messageContent)
|
||||
|
||||
resp := domain.OpenAICompletionsResponse{
|
||||
ID: responseID,
|
||||
Object: "chat.completion",
|
||||
|
|
@ -397,7 +410,7 @@ func (h *ShareChatHandler) handleOpenAINonStreamResponse(c echo.Context, eventCh
|
|||
Index: 0,
|
||||
Message: domain.OpenAIMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Content: messageContent,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in New Issue