259 lines
9.0 KiB
Python
259 lines
9.0 KiB
Python
from flask import current_app
|
|
from models import db, User, ApiCall, Transaction
|
|
from datetime import datetime
|
|
import requests
|
|
import json
|
|
import logging
|
|
from modelapiservice import get_model_service
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
class ApiProxyService:
    """Forward chat-completion requests to an upstream model API and bill the user.

    Every public method is a static method that returns a ``(body, status)``
    tuple, except ``deduct_balance`` which returns ``None``.
    """

    # Maximum number of characters of prompt / response text stored on an ApiCall row.
    _STORED_TEXT_LIMIT = 500
    # Timeout, in seconds, for the upstream HTTP request.
    _UPSTREAM_TIMEOUT_SECONDS = 300

    @staticmethod
    def deduct_balance(user_id, api_call_id, cost, model):
        """Charge ``cost`` to a user and record the charge.

        Subtracts ``cost`` from the user's balance, marks the ApiCall row as
        ``'success'`` with the final cost, and adds a ``'consume'``
        Transaction row, all in a single commit.

        Args:
            user_id: Primary key of the user to charge.
            api_call_id: Primary key of the ApiCall being billed.
            cost: Amount to subtract from the balance.
            model: Model name, used only in the transaction description.

        Returns:
            None. Failures are logged and the session is rolled back; nothing
            is raised to the caller.
        """
        try:
            user = User.query.get(user_id)
            api_call = ApiCall.query.get(api_call_id)
            if not user or not api_call:
                # Nothing can be billed without both rows; leave a trace
                # instead of returning silently.
                logger.warning(
                    f'扣费跳过: 用户 {user_id} 或调用记录 {api_call_id} 不存在'
                )
                return

            # Reload the user row so the balance is read from the database.
            db.session.refresh(user)

            balance_before = user.balance
            user.balance -= cost
            balance_after = user.balance

            api_call.status = 'success'
            api_call.cost = cost
            db.session.add(api_call)

            transaction = Transaction(
                user_id=user.id,
                type='consume',
                amount=-cost,
                balance_before=balance_before,
                balance_after=balance_after,
                description=f'API调用 - {model}',
                api_call_id=api_call.id,
            )
            db.session.add(transaction)
            db.session.commit()
            # Log the caller-supplied id: it avoids touching the ORM object
            # again after the commit.
            logger.info(
                f"扣费成功: 用户 {user_id}, 消费 {cost}, "
                f"余额 {balance_before} -> {balance_after}"
            )
        except Exception as e:
            logger.error(f'扣费失败: {e}', exc_info=True)
            db.session.rollback()

    @staticmethod
    def _extract_prompt(messages):
        """Return the text of the first message, for storage on the ApiCall row.

        Falls back to ``"Empty prompt"`` when the first message is not a dict
        or has no (truthy) ``content``. Non-string content is converted with
        ``str()`` so that it can still be truncated and stored.
        """
        first = messages[0] if messages else None
        content = first.get('content', '') if isinstance(first, dict) else ''
        if not content:
            return "Empty prompt"
        return content if isinstance(content, str) else str(content)

    @staticmethod
    def _mark_call_failed(api_call_id, message):
        """Best-effort: set an ApiCall row to ``'failed'`` with ``message``.

        Errors are logged and rolled back, never raised.
        """
        try:
            api_call = ApiCall.query.get(api_call_id)
            if api_call:
                api_call.status = 'failed'
                api_call.error_message = message
                db.session.commit()
        except Exception as e:
            logger.error(f'更新调用记录失败: {e}', exc_info=True)
            db.session.rollback()

    @staticmethod
    def _stream_and_bill(app, response, service, user_id, api_call_id,
                         estimated_cost, model):
        """Yield the upstream response body chunk by chunk, then bill the user.

        Only plain values (ids, the concrete app object) are taken as
        arguments, so the generator does not depend on request-bound proxies
        or ORM instances.

        Args:
            app: The concrete Flask application object (not the
                ``current_app`` proxy).
            response: The streaming ``requests`` response to relay.
            service: Model service; ``parse_stream_usage`` is used when the
                service defines it, ``calculate_cost`` always.
            user_id: Primary key of the user to charge.
            api_call_id: Primary key of the ApiCall being billed.
            estimated_cost: Fallback cost when the computed cost is 0.
            model: Model name, passed through to ``deduct_balance``.
        """
        final_usage = None
        can_parse_usage = hasattr(service, 'parse_stream_usage')
        try:
            for chunk in response.iter_content(chunk_size=1024):
                if not chunk:
                    continue
                if can_parse_usage:
                    # Usage extraction is best-effort: a chunk that cannot be
                    # parsed must never interrupt the relay to the client.
                    # NOTE(review): chunks are fixed-size byte slices, so a
                    # usage payload may be split across two chunks; confirm
                    # that parse_stream_usage tolerates partial input.
                    try:
                        text_chunk = chunk.decode('utf-8', errors='ignore')
                        usage = service.parse_stream_usage(text_chunk)
                        if usage:
                            final_usage = usage
                    except Exception:
                        logger.debug('parse_stream_usage failed', exc_info=True)
                yield chunk

            # Compute the final cost; fall back to the estimate when the
            # computed cost is 0.
            actual_cost = service.calculate_cost(final_usage, stream=True)
            if actual_cost == 0 and estimated_cost > 0:
                actual_cost = estimated_cost

            with app.app_context():
                ApiProxyService.deduct_balance(
                    user_id, api_call_id, actual_cost, model
                )
        except Exception as e:
            logger.error(f'Stream error: {e}', exc_info=True)
            # Do not leave the record in 'processing' when the stream breaks.
            with app.app_context():
                ApiProxyService._mark_call_failed(api_call_id, str(e))
        finally:
            # Release the upstream connection, including on client disconnect.
            response.close()

    @staticmethod
    def handle_api_request(user_id, data):
        """Proxy one chat-completion request to the upstream API and bill it.

        Args:
            user_id: Primary key of the calling user.
            data: Request body as a dict. Reads ``model`` (required),
                ``messages`` (required, non-empty list) and ``stream``
                (optional, default ``False``); the whole dict is handed to the
                model service's ``prepare_payload``.

        Returns:
            A ``(body, status)`` tuple:

            * ``(dict, 200)`` with ``success``, ``api_call_id``, ``cost``,
              ``model`` and ``content`` for a non-streaming call;
            * ``(generator, 200)`` yielding raw upstream bytes when ``stream``
              is truthy (billing happens once the stream is exhausted);
            * ``(dict, 400 | 402 | 403 | 404 | 500 | 502)`` with an ``error``
              key on failure.
        """
        user = User.query.get(user_id)
        if not user:
            return {'error': '用户不存在'}, 404
        if not user.is_active:
            return {'error': '账户已被禁用'}, 403

        data = data or {}
        model = data.get('model')
        messages = data.get('messages', [])
        stream = data.get('stream', False)

        # Validate required fields.
        if not model:
            return {'error': 'model字段不能为空'}, 400
        if not messages:
            return {'error': 'messages不能为空'}, 400
        if not isinstance(messages, (list, tuple)):
            # Reject malformed input with a 400 instead of failing later on
            # messages[0].
            return {'error': 'messages必须是数组'}, 400

        # Resolve the model service and check the balance.
        try:
            service = get_model_service(model)
        except Exception:
            logger.warning(f'获取模型服务失败: {model}', exc_info=True)
            return {'error': f'不支持的模型: {model}'}, 400

        is_sufficient, estimated_cost, error_msg = service.check_balance(user.balance)
        if not is_sufficient:
            return {
                'error': '余额不足',
                'message': error_msg,
                'required': estimated_cost,
                'balance': user.balance,
            }, 402

        prompt = ApiProxyService._extract_prompt(messages)
        text_limit = ApiProxyService._STORED_TEXT_LIMIT

        # Create the API call record.
        api_call = ApiCall(
            user_id=user_id,
            api_type='chat_completion',
            prompt=prompt[:text_limit],  # truncated for storage
            parameters=json.dumps({
                'model': model,
                'stream': stream,
            }),
            status='processing',
            cost=estimated_cost,
            request_time=datetime.utcnow(),
        )

        try:
            db.session.add(api_call)
            db.session.flush()

            # Capture plain ids now; they are needed after commits and inside
            # the streaming generator, where ORM instances may be expired.
            user_pk = user.id
            api_call_pk = api_call.id

            # Prepare the upstream request.
            api_url, api_key = service.get_api_config()
            if not api_url or not api_key:
                raise ValueError(f'模型 {model} API 配置未完成')

            headers = {
                'Authorization': f'Bearer {api_key}',
                'Content-Type': 'application/json',
            }
            payload = service.prepare_payload(data)
            target_url = f'{api_url}/chat/completions'

            logger.info(f'API 转发: {target_url}, User: {user_pk}, Model: {model}')

            response = requests.post(
                target_url,
                headers=headers,
                json=payload,
                stream=stream,
                timeout=ApiProxyService._UPSTREAM_TIMEOUT_SECONDS,
            )

            if response.status_code != 200:
                error_msg = f'第三方 API 返回错误: {response.status_code}'
                try:
                    error_detail = response.json()
                    error_msg += f' - {error_detail}'
                except ValueError:
                    # Body is not JSON; fall back to a short text excerpt.
                    error_msg += f' - {response.text[:200]}'
                finally:
                    response.close()

                api_call.status = 'failed'
                api_call.error_message = error_msg
                db.session.commit()
                return {'error': 'API 调用失败', 'details': error_msg}, 502

            if stream:
                # The generator body runs while the response is being sent,
                # i.e. after this function has returned, so it must not rely
                # on the current_app proxy: hand it the concrete app object.
                app = current_app._get_current_object()
                # Persist the 'processing' record before streaming so that the
                # billing step, which runs in its own app context, can find it.
                db.session.commit()
                body = ApiProxyService._stream_and_bill(
                    app, response, service, user_pk, api_call_pk,
                    estimated_cost, model,
                )
                return body, 200  # special return shape for streaming

            result = response.json()
            api_call.status = 'success'
            api_call.response_time = datetime.utcnow()

            # Compute the cost; fall back to the estimate when it is 0.
            usage = result.get('usage')
            final_cost = service.calculate_cost(usage, stream=False)
            if final_cost == 0 and estimated_cost > 0:
                final_cost = estimated_cost

            # Simplified response format.
            simplified_result = {
                'success': True,
                'api_call_id': api_call_pk,
                'cost': final_cost,
                'model': model,
                'content': '',
            }

            choices = result.get('choices')
            if choices:
                # 'message' or 'content' may be present but null; coalesce to
                # an empty string so a successful call is not turned into a
                # 500 by slicing None.
                message = choices[0].get('message') or {}
                content = message.get('content') or ''
                simplified_result['content'] = content
                api_call.result_url = content[:text_limit]

            ApiProxyService.deduct_balance(user_pk, api_call_pk, final_cost, model)

            return simplified_result, 200

        except Exception as e:
            logger.error(f'API 调用异常: {str(e)}', exc_info=True)
            try:
                if api_call.id:
                    api_call.status = 'failed'
                    api_call.error_message = str(e)
                    db.session.commit()
                else:
                    # The record never reached the database (e.g. the flush
                    # failed); reset the session.
                    db.session.rollback()
            except Exception as record_error:
                logger.error(f'更新调用记录失败: {record_error}', exc_info=True)
                db.session.rollback()
            return {'error': '服务异常', 'message': str(e)}, 500

    @staticmethod
    def get_models():
        """Return the list of available models as ``(body, 200)``."""
        # Hard-coded for now; could later be aggregated from each model service.
        return {
            'object': 'list',
            'data': [
                {
                    'id': 'deepseek-chat',
                    'object': 'model',
                    'owned_by': 'deepseek',
                    'description': 'DeepSeek Chat V3',
                },
                {
                    'id': 'deepseek-reasoner',
                    'object': 'model',
                    'owned_by': 'deepseek',
                    'description': 'DeepSeek Reasoner (R1)',
                },
                # ... other models ...
            ],
        }, 200

    @staticmethod
    def get_pricing():
        """Return pricing information as ``(body, 200)``.

        The text-to-image price is read from the ``IMAGE_GENERATION_PRICE``
        app config key and defaults to 0 when the key is absent.
        """
        pricing = {
            'text_to_image': {
                'price': current_app.config.get('IMAGE_GENERATION_PRICE', 0),
                'currency': 'CNY',
                'unit': '每张图片',
            },
        }

        return {
            'pricing': pricing,
        }, 200

    @staticmethod
    def get_api_call(user_id, call_id):
        """Return one API call record owned by ``user_id``.

        Returns:
            ``(api_call.to_dict(), 200)`` when a record with ``call_id``
            belonging to ``user_id`` exists, otherwise an error dict with 404.
        """
        api_call = ApiCall.query.filter_by(id=call_id, user_id=user_id).first()

        if not api_call:
            return {'error': 'API调用记录不存在'}, 404

        return api_call.to_dict(), 200
|