package service

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"infogenie-backend/internal/database"
	"infogenie-backend/internal/model"
)

// ChatMessage is a single message in an OpenAI-compatible chat conversation.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// chatRequest is the JSON body sent to an OpenAI-compatible chat completions endpoint.
type chatRequest struct {
	Model       string        `json:"model"`
	Messages    []ChatMessage `json:"messages"`
	Temperature float64       `json:"temperature"`
	MaxTokens   int           `json:"max_tokens"`
}

// chatResponse is the subset of a chat completions response read here:
// the message content of each returned choice.
type chatResponse struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}

// loadAIConfig reads the enabled AIConfig row for provider from the database.
//
// ok is false when the database is not initialised, no enabled row exists for
// provider, or the stored API key or base URL is blank after trimming
// whitespace. The Models column is decoded as a JSON array of model names; if
// it cannot be decoded the returned list is empty, which callers treat as
// "no allow-list".
func loadAIConfig(provider string) (apiKey, apiBase, defaultModel string, models []string, ok bool) {
	if database.DB == nil {
		return "", "", "", nil, false
	}
	var config model.AIConfig
	if err := database.DB.Where("provider = ? AND is_enabled = ?", provider, true).First(&config).Error; err != nil {
		return "", "", "", nil, false
	}

	key := strings.TrimSpace(config.APIKey)
	base := strings.TrimSpace(config.APIBase)
	if key == "" || base == "" {
		// An enabled row without credentials is unusable; report it as
		// unconfigured instead of later building a request to a host-less URL.
		return "", "", "", nil, false
	}

	var modelList []string
	if config.Models != "" {
		if err := json.Unmarshal([]byte(config.Models), &modelList); err != nil {
			// Deliberately best-effort: a malformed list disables the allow-list
			// check. Reset because a type error can leave partially decoded entries.
			modelList = nil
		}
	}
	return key, base, strings.TrimSpace(config.DefaultModel), modelList, true
}

// loadRuntimeDeepSeek reads the admin-configured DeepSeek-compatible
// (OpenAI-format) endpoint from the SiteAIRuntime row with primary key 1.
// ok is true only when both the trimmed base URL and API key are non-empty.
func loadRuntimeDeepSeek() (apiBase, apiKey, defModel string, ok bool) {
	if database.DB == nil {
		return "", "", "", false
	}
	var row model.SiteAIRuntime
	if err := database.DB.First(&row, 1).Error; err != nil {
		return "", "", "", false
	}
	base := strings.TrimSpace(row.APIBase)
	key := strings.TrimSpace(row.APIKey)
	dm := strings.TrimSpace(row.DefaultModel)
	if base != "" && key != "" {
		return base, key, dm, true
	}
	return "", "", "", false
}

// resolveAIModel picks the model name to request: requested if non-empty,
// otherwise configured, otherwise fallback. When allowed is non-empty and the
// chosen name is not in it, the first allowed model is returned instead.
func resolveAIModel(requested, configured, fallback string, allowed []string) string {
	name := requested
	if name == "" {
		name = configured
	}
	if name == "" {
		name = fallback
	}
	if len(allowed) == 0 {
		return name
	}
	for _, m := range allowed {
		if m == name {
			return name
		}
	}
	return allowed[0]
}

// CallDeepSeek sends messages to a DeepSeek (OpenAI-compatible) chat endpoint
// and returns the content of the first choice.
//
// Configuration comes from the SiteAIRuntime row when it has both a base URL
// and an API key; otherwise from the enabled "deepseek" AIConfig row. An empty
// model falls back to the configured default, then to "deepseek-chat". With an
// AIConfig allow-list, a model outside the list is replaced by the first
// allowed entry. maxRetries is the total number of attempts (values below 1
// are treated as 1); each attempt times out after 90 seconds.
func CallDeepSeek(messages []ChatMessage, model string, maxRetries int) (string, error) {
	// Admin runtime settings take precedence (kept for backward compatibility).
	if base, key, defModel, ok := loadRuntimeDeepSeek(); ok {
		url := strings.TrimSuffix(base, "/") + "/chat/completions"
		chosen := resolveAIModel(model, defModel, "deepseek-chat", nil)
		return callOpenAICompatible(url, key, chosen, messages, maxRetries, 90*time.Second)
	}

	if apiKey, apiBase, defaultModel, models, ok := loadAIConfig("deepseek"); ok {
		url := strings.TrimSuffix(apiBase, "/") + "/chat/completions"
		chosen := resolveAIModel(model, defaultModel, "deepseek-chat", models)
		return callOpenAICompatible(url, apiKey, chosen, messages, maxRetries, 90*time.Second)
	}

	return "", fmt.Errorf("DeepSeek配置未设置,请在管理员后台配置API Key和Base URL")
}

// CallKimi sends messages to a Kimi (OpenAI-compatible) chat endpoint using
// the enabled "kimi" AIConfig row and returns the content of the first choice.
//
// An empty model falls back to the configured default, then to
// "kimi-k2-0905-preview"; a model outside a non-empty allow-list is replaced by
// the first allowed entry. A single attempt is made with a 30-second timeout.
func CallKimi(messages []ChatMessage, model string) (string, error) {
	if apiKey, apiBase, defaultModel, models, ok := loadAIConfig("kimi"); ok {
		// NOTE(review): "/v1" is appended here (unlike DeepSeek), so the stored
		// base URL presumably must not already end in /v1 — confirm with admins.
		url := strings.TrimSuffix(apiBase, "/") + "/v1/chat/completions"
		chosen := resolveAIModel(model, defaultModel, "kimi-k2-0905-preview", models)
		return callOpenAICompatible(url, apiKey, chosen, messages, 1, 30*time.Second)
	}
	return "", fmt.Errorf("Kimi配置未设置,请在管理员后台配置API Key和Base URL")
}

// isRetryableAIStatus reports whether an HTTP status signals a transient
// failure worth retrying: request timeout, rate limiting, or a server error.
func isRetryableAIStatus(code int) bool {
	return code == http.StatusRequestTimeout ||
		code == http.StatusTooManyRequests ||
		code >= http.StatusInternalServerError
}

// callOpenAICompatible POSTs an OpenAI-style chat completion request to url
// and returns the content of the first choice.
//
// maxRetries is the total number of attempts. Values below 1 are treated as 1,
// so the function never returns ("", nil). Transport errors, body-read errors,
// and HTTP 408/429/5xx responses are retried with exponential backoff
// (1s, 2s, 4s, ...). Other non-200 responses are returned immediately because
// repeating the same request cannot succeed. timeout bounds each attempt.
func callOpenAICompatible(url, apiKey, model string, messages []ChatMessage, maxRetries int, timeout time.Duration) (string, error) {
	if maxRetries < 1 {
		maxRetries = 1
	}

	bodyBytes, err := json.Marshal(chatRequest{
		Model:       model,
		Messages:    messages,
		Temperature: 0.7,
		MaxTokens:   2000,
	})
	if err != nil {
		return "", fmt.Errorf("序列化请求失败: %w", err)
	}

	client := &http.Client{Timeout: timeout}
	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			// Exponential backoff before each retry: 1s, 2s, 4s, ...
			time.Sleep(time.Second << (attempt - 1))
		}

		req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(bodyBytes))
		if err != nil {
			// A malformed URL will not fix itself; fail fast rather than
			// dereferencing a nil request.
			return "", fmt.Errorf("构造请求失败: %w", err)
		}
		req.Header.Set("Authorization", "Bearer "+apiKey)
		req.Header.Set("Content-Type", "application/json")

		resp, err := client.Do(req)
		if err != nil {
			lastErr = fmt.Errorf("API调用异常(已重试%d次): %w", maxRetries, err)
			continue
		}

		respBody, readErr := io.ReadAll(resp.Body)
		resp.Body.Close()
		if readErr != nil {
			// A truncated body is a transient transport failure; retry it.
			lastErr = fmt.Errorf("读取响应失败: %w", readErr)
			continue
		}

		if resp.StatusCode == http.StatusOK {
			var result chatResponse
			if err := json.Unmarshal(respBody, &result); err != nil {
				return "", fmt.Errorf("解析响应失败: %w", err)
			}
			if len(result.Choices) == 0 {
				return "", fmt.Errorf("AI未返回有效内容")
			}
			return result.Choices[0].Message.Content, nil
		}

		lastErr = fmt.Errorf("API调用失败: %d - %s", resp.StatusCode, string(respBody))
		if !isRetryableAIStatus(resp.StatusCode) {
			return "", lastErr
		}
	}
	return "", lastErr
}

// CallAI dispatches a chat request to the named provider ("deepseek" or
// "kimi"). DeepSeek calls are attempted up to 3 times; Kimi calls once.
// An unknown provider yields an error listing the supported ones.
func CallAI(provider, model string, messages []ChatMessage) (string, error) {
	switch provider {
	case "deepseek":
		return CallDeepSeek(messages, model, 3)
	case "kimi":
		return CallKimi(messages, model)
	default:
		return "", fmt.Errorf("不支持的AI提供商: %s,目前支持的提供商: deepseek, kimi", provider)
	}
}