update: 2026-03-28 20:59
This commit is contained in:
213
infogenie-backend-go/internal/service/ai.go
Normal file
213
infogenie-backend-go/internal/service/ai.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"infogenie-backend/internal/database"
|
||||
"infogenie-backend/internal/model"
|
||||
)
|
||||
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type chatResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
// loadAIConfig 从数据库读取AI配置
|
||||
func loadAIConfig(provider string) (apiKey, apiBase, defaultModel string, models []string, ok bool) {
|
||||
if database.DB == nil {
|
||||
return "", "", "", nil, false
|
||||
}
|
||||
|
||||
var config model.AIConfig
|
||||
if err := database.DB.Where("provider = ? AND is_enabled = ?", provider, true).First(&config).Error; err != nil {
|
||||
return "", "", "", nil, false
|
||||
}
|
||||
|
||||
// 解析models JSON
|
||||
var modelList []string
|
||||
if config.Models != "" {
|
||||
if err := json.Unmarshal([]byte(config.Models), &modelList); err != nil {
|
||||
// 如果解析失败,返回空的模型列表
|
||||
modelList = []string{}
|
||||
}
|
||||
}
|
||||
|
||||
return config.APIKey, config.APIBase, config.DefaultModel, modelList, true
|
||||
}
|
||||
|
||||
// loadRuntimeDeepSeek 读取管理员在后台配置的 DeepSeek 兼容接口(OpenAI 格式),优先于 ai_config.json
|
||||
func loadRuntimeDeepSeek() (apiBase, apiKey, defModel string, ok bool) {
|
||||
if database.DB == nil {
|
||||
return "", "", "", false
|
||||
}
|
||||
var row model.SiteAIRuntime
|
||||
if err := database.DB.First(&row, 1).Error; err != nil {
|
||||
return "", "", "", false
|
||||
}
|
||||
base := strings.TrimSpace(row.APIBase)
|
||||
key := strings.TrimSpace(row.APIKey)
|
||||
dm := strings.TrimSpace(row.DefaultModel)
|
||||
if base != "" && key != "" {
|
||||
return base, key, dm, true
|
||||
}
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
func CallDeepSeek(messages []ChatMessage, model string, maxRetries int) (string, error) {
|
||||
// 首先尝试从SiteAIRuntime读取配置(向后兼容)
|
||||
if base, key, defModel, ok := loadRuntimeDeepSeek(); ok {
|
||||
if model == "" {
|
||||
model = defModel
|
||||
}
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
url := strings.TrimSuffix(base, "/") + "/chat/completions"
|
||||
return callOpenAICompatible(url, key, model, messages, maxRetries, 90*time.Second)
|
||||
}
|
||||
|
||||
// 从新的AI配置表读取
|
||||
if apiKey, apiBase, defaultModel, models, ok := loadAIConfig("deepseek"); ok {
|
||||
if model == "" {
|
||||
model = defaultModel
|
||||
}
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
// 验证模型是否在允许列表中
|
||||
if len(models) > 0 {
|
||||
allowed := false
|
||||
for _, m := range models {
|
||||
if m == model {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
model = models[0] // 使用第一个允许的模型
|
||||
}
|
||||
}
|
||||
url := strings.TrimSuffix(apiBase, "/") + "/chat/completions"
|
||||
return callOpenAICompatible(url, apiKey, model, messages, maxRetries, 90*time.Second)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("DeepSeek配置未设置,请在管理员后台配置API Key和Base URL")
|
||||
}
|
||||
|
||||
func CallKimi(messages []ChatMessage, model string) (string, error) {
|
||||
// 从新的AI配置表读取
|
||||
if apiKey, apiBase, defaultModel, models, ok := loadAIConfig("kimi"); ok {
|
||||
if model == "" {
|
||||
model = defaultModel
|
||||
}
|
||||
if model == "" {
|
||||
model = "kimi-k2-0905-preview"
|
||||
}
|
||||
// 验证模型是否在允许列表中
|
||||
if len(models) > 0 {
|
||||
allowed := false
|
||||
for _, m := range models {
|
||||
if m == model {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
model = models[0] // 使用第一个允许的模型
|
||||
}
|
||||
}
|
||||
url := strings.TrimSuffix(apiBase, "/") + "/v1/chat/completions"
|
||||
return callOpenAICompatible(url, apiKey, model, messages, 1, 30*time.Second)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("Kimi配置未设置,请在管理员后台配置API Key和Base URL")
|
||||
}
|
||||
|
||||
func callOpenAICompatible(url, apiKey, model string, messages []ChatMessage, maxRetries int, timeout time.Duration) (string, error) {
|
||||
reqBody := chatRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Temperature: 0.7,
|
||||
MaxTokens: 2000,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if attempt < maxRetries-1 {
|
||||
backoff := time.Duration(math.Pow(2, float64(attempt))) * time.Second
|
||||
time.Sleep(backoff)
|
||||
continue
|
||||
}
|
||||
return "", fmt.Errorf("API调用异常(已重试%d次): %w", maxRetries, err)
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
var result chatResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
if len(result.Choices) == 0 {
|
||||
return "", fmt.Errorf("AI未返回有效内容")
|
||||
}
|
||||
return result.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
lastErr = fmt.Errorf("API调用失败: %d - %s", resp.StatusCode, string(respBody))
|
||||
if attempt < maxRetries-1 {
|
||||
backoff := time.Duration(math.Pow(2, float64(attempt))) * time.Second
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
}
|
||||
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
func CallAI(provider, model string, messages []ChatMessage) (string, error) {
|
||||
switch provider {
|
||||
case "deepseek":
|
||||
return CallDeepSeek(messages, model, 3)
|
||||
case "kimi":
|
||||
return CallKimi(messages, model)
|
||||
default:
|
||||
return "", fmt.Errorf("不支持的AI提供商: %s,目前支持的提供商: deepseek, kimi", provider)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user