This commit is contained in:
Cold
2025-11-27 18:03:01 +08:00
committed by 张斌
parent ad1ccc2bc1
commit d8410fab37
2 changed files with 27 additions and 29 deletions

View File

@@ -24,7 +24,7 @@ func NewClient(baseURL, apiKey string) *Client {
client := gclient.New() client := gclient.New()
client.SetHeader("Authorization", fmt.Sprintf("Bearer %s", apiKey)) client.SetHeader("Authorization", fmt.Sprintf("Bearer %s", apiKey))
client.SetHeader("Content-Type", "application/json") client.SetHeader("Content-Type", "application/json")
return &Client{ return &Client{
BaseURL: strings.TrimSuffix(baseURL, "/"), BaseURL: strings.TrimSuffix(baseURL, "/"),
APIKey: apiKey, APIKey: apiKey,
@@ -34,8 +34,8 @@ func NewClient(baseURL, apiKey string) *Client {
// CommonResponse 通用响应结构 // CommonResponse 通用响应结构
type CommonResponse struct { type CommonResponse struct {
Code int `json:"code"` Code int `json:"code"`
Message string `json:"message"` Message string `json:"message"`
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
} }
@@ -47,7 +47,7 @@ func (r *CommonResponse) IsSuccess() bool {
// request 发送 HTTP 请求 // request 发送 HTTP 请求
func (c *Client) request(ctx context.Context, method, path string, body interface{}, result interface{}) error { func (c *Client) request(ctx context.Context, method, path string, body interface{}, result interface{}) error {
fullURL := c.BaseURL + path fullURL := c.BaseURL + path
var reqBody io.Reader var reqBody io.Reader
if body != nil { if body != nil {
jsonData, err := json.Marshal(body) jsonData, err := json.Marshal(body)
@@ -56,10 +56,10 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac
} }
reqBody = strings.NewReader(string(jsonData)) reqBody = strings.NewReader(string(jsonData))
} }
var resp *gclient.Response var resp *gclient.Response
var err error var err error
switch method { switch method {
case "GET": case "GET":
resp, err = c.HTTPClient.Get(ctx, fullURL) resp, err = c.HTTPClient.Get(ctx, fullURL)
@@ -72,25 +72,25 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac
default: default:
return fmt.Errorf("unsupported method: %s", method) return fmt.Errorf("unsupported method: %s", method)
} }
if err != nil { if err != nil {
return fmt.Errorf("http request failed: %w", err) return fmt.Errorf("http request failed: %w", err)
} }
defer resp.Close() defer resp.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("http request failed with status: %d", resp.StatusCode) return fmt.Errorf("http request failed with status: %d", resp.StatusCode)
} }
respBody, err := resp.ReadAll() respBody := resp.ReadAll()
if err != nil { if err != nil {
return fmt.Errorf("read response body failed: %w", err) return fmt.Errorf("read response body failed: %w", err)
} }
if err := json.Unmarshal(respBody, result); err != nil { if err := json.Unmarshal(respBody, result); err != nil {
return fmt.Errorf("unmarshal response failed: %w", err) return fmt.Errorf("unmarshal response failed: %w", err)
} }
return nil return nil
} }
@@ -99,11 +99,10 @@ func buildQueryString(params map[string]interface{}) string {
if len(params) == 0 { if len(params) == 0 {
return "" return ""
} }
var parts []string var parts []string
for k, v := range params { for k, v := range params {
parts = append(parts, fmt.Sprintf("%s=%v", url.QueryEscape(k), url.QueryEscape(fmt.Sprintf("%v", v)))) parts = append(parts, fmt.Sprintf("%s=%v", url.QueryEscape(k), url.QueryEscape(fmt.Sprintf("%v", v))))
} }
return strings.Join(parts, "&") return strings.Join(parts, "&")
} }

View File

@@ -11,15 +11,15 @@ import (
// ChatCompletionMessage OpenAI 格式的消息 // ChatCompletionMessage OpenAI 格式的消息
type ChatCompletionMessage struct { type ChatCompletionMessage struct {
Role string `json:"role"` // "user", "assistant", "system" Role string `json:"role"` // "user", "assistant", "system"
Content string `json:"content"` Content string `json:"content"`
} }
// ChatCompletionRequest OpenAI 格式的聊天补全请求 // ChatCompletionRequest OpenAI 格式的聊天补全请求
type ChatCompletionRequest struct { type ChatCompletionRequest struct {
Model string `json:"model"` // 模型名称(服务器会自动解析,可设置为任意值) Model string `json:"model"` // 模型名称(服务器会自动解析,可设置为任意值)
Messages []ChatCompletionMessage `json:"messages"` // 消息列表,必须至少包含一条 user 消息 Messages []ChatCompletionMessage `json:"messages"` // 消息列表,必须至少包含一条 user 消息
Stream bool `json:"stream,omitempty"` // 是否流式返回,默认 false Stream bool `json:"stream,omitempty"` // 是否流式返回,默认 false
} }
// ChatCompletionResponse OpenAI 格式的聊天补全响应(非流式) // ChatCompletionResponse OpenAI 格式的聊天补全响应(非流式)
@@ -29,9 +29,9 @@ type ChatCompletionResponse struct {
Created int64 `json:"created"` Created int64 `json:"created"`
Model string `json:"model"` Model string `json:"model"`
Choices []struct { Choices []struct {
Index int `json:"index"` Index int `json:"index"`
Message ChatCompletionMessage `json:"message"` Message ChatCompletionMessage `json:"message"`
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
} `json:"choices"` } `json:"choices"`
Usage struct { Usage struct {
PromptTokens int `json:"prompt_tokens"` PromptTokens int `json:"prompt_tokens"`
@@ -47,8 +47,8 @@ type ChatCompletionChunk struct {
Created int64 `json:"created"` Created int64 `json:"created"`
Model string `json:"model"` Model string `json:"model"`
Choices []struct { Choices []struct {
Index int `json:"index"` Index int `json:"index"`
Delta struct { Delta struct {
Content string `json:"content"` Content string `json:"content"`
Role string `json:"role"` Role string `json:"role"`
} `json:"delta"` } `json:"delta"`
@@ -65,12 +65,12 @@ type ChatCompletionChunk struct {
// POST /api/v1/chats_openai/{chat_id}/chat/completions // POST /api/v1/chats_openai/{chat_id}/chat/completions
func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) { func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
path := fmt.Sprintf("/api/v1/chats_openai/%s/chat/completions", chatID) path := fmt.Sprintf("/api/v1/chats_openai/%s/chat/completions", chatID)
var resp ChatCompletionResponse var resp ChatCompletionResponse
if err := c.request(ctx, "POST", path, req, &resp); err != nil { if err := c.request(ctx, "POST", path, req, &resp); err != nil {
return nil, fmt.Errorf("create chat completion failed: %w", err) return nil, fmt.Errorf("create chat completion failed: %w", err)
} }
return &resp, nil return &resp, nil
} }
@@ -78,12 +78,12 @@ func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *C
// POST /api/v1/agents_openai/{agent_id}/chat/completions // POST /api/v1/agents_openai/{agent_id}/chat/completions
func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) { func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
path := fmt.Sprintf("/api/v1/agents_openai/%s/chat/completions", agentID) path := fmt.Sprintf("/api/v1/agents_openai/%s/chat/completions", agentID)
var resp ChatCompletionResponse var resp ChatCompletionResponse
if err := c.request(ctx, "POST", path, req, &resp); err != nil { if err := c.request(ctx, "POST", path, req, &resp); err != nil {
return nil, fmt.Errorf("create agent completion failed: %w", err) return nil, fmt.Errorf("create agent completion failed: %w", err)
} }
return &resp, nil return &resp, nil
} }
@@ -91,8 +91,8 @@ func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req
// 注意:流式响应需要特殊处理,这里返回一个可用于读取流的接口 // 注意:流式响应需要特殊处理,这里返回一个可用于读取流的接口
func (c *Client) CreateChatCompletionStream(ctx context.Context, chatID string, req *ChatCompletionRequest) (*StreamReader, error) { func (c *Client) CreateChatCompletionStream(ctx context.Context, chatID string, req *ChatCompletionRequest) (*StreamReader, error) {
req.Stream = true req.Stream = true
apiPath := fmt.Sprintf("/api/v1/chats_openai/%s/chat/completions", chatID) _ = fmt.Sprintf("/api/v1/chats_openai/%s/chat/completions", chatID)
// TODO: 实现流式读取逻辑 // TODO: 实现流式读取逻辑
return nil, fmt.Errorf("stream mode not implemented yet") return nil, fmt.Errorf("stream mode not implemented yet")
} }
@@ -119,4 +119,3 @@ func (sr *StreamReader) Close() error {
} }
return nil return nil
} }