# -*- coding: utf-8 -*-
"""
OpenAI 提供商实现

支持 GPT-4o, GPT-4, GPT-3.5-turbo 等模型
"""

import json
import logging
import time
from typing import Any, Dict, Generator, List, Optional

from .base import BaseAIProvider

logger = logging.getLogger(__name__)


class OpenAIProvider(BaseAIProvider):
    """OpenAI API provider.

    Implements blocking and streaming chat completion against the
    OpenAI ``/chat/completions`` endpoint, including function (tool)
    calling and retry logic for transient failures.
    """

    @property
    def provider_type(self) -> str:
        # Stable machine-readable identifier used for provider lookup.
        return 'openai'

    @property
    def provider_name(self) -> str:
        # Human-readable display name.
        return 'OpenAI'

    def get_default_base_url(self) -> str:
        return 'https://api.openai.com/v1'

    def get_default_model(self) -> str:
        return 'gpt-4o'

    def get_available_models(self) -> List[str]:
        return ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-4', 'gpt-3.5-turbo']

    def _get_client(self):
        """Build an OpenAI SDK client from this provider's settings.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        # Keep the try narrow: only the import can legitimately raise
        # ImportError.  The original also wrapped client construction,
        # which would have masked unrelated failures as "install openai".
        try:
            from openai import OpenAI
        except ImportError:
            raise ImportError("请安装 openai 库: pip install openai")
        return OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=self.timeout
        )

    def _build_request_params(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]],
        temperature: float,
        max_tokens: Optional[int],
        stream: bool = False,
    ) -> Dict[str, Any]:
        """Assemble the request kwargs shared by chat() and chat_stream()."""
        params: Dict[str, Any] = {
            'model': self.model,
            'messages': self._format_messages(messages),
            'temperature': temperature,
        }
        if stream:
            params['stream'] = True
        # Falsy max_tokens (None or 0) is deliberately omitted so the API
        # default applies — this mirrors the original behavior.
        if max_tokens:
            params['max_tokens'] = max_tokens
        if tools:
            params['tools'] = self._format_tools(tools)
            params['tool_choice'] = 'auto'
        return params

    def chat(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Run a blocking chat completion with transient-error retries.

        Args:
            messages: Conversation history in OpenAI message format.
            tools: Optional tool/function schemas; enables tool calling.
            temperature: Sampling temperature.
            max_tokens: Optional completion-token cap (omitted when falsy).

        Returns:
            Parsed response dict (see ``_parse_response``).

        Raises:
            Exception: re-raised SDK error — immediately for non-retryable
                errors, or after ``self.max_retries`` attempts otherwise.
        """
        client = self._get_client()
        request_params = self._build_request_params(
            messages, tools, temperature, max_tokens
        )

        last_error = None
        for attempt in range(self.max_retries):
            try:
                self._log('info', f'发送请求到 {self.model}' + (f' (重试 {attempt})' if attempt > 0 else ''))
                response = client.chat.completions.create(**request_params)
                return self._parse_response(response)

            except Exception as e:
                last_error = e
                error_str = str(e).lower()
                # Retry only errors that look transient: connectivity,
                # timeouts, rate limits, and upstream 5xx responses.
                should_retry = any(keyword in error_str for keyword in [
                    'connection', 'timeout', 'rate', '502', '503', '504',
                    'bad gateway', 'service unavailable', 'gateway timeout'
                ])
                if should_retry and attempt < self.max_retries - 1:
                    wait_time = (attempt + 1) * 5  # linear backoff: 5s, 10s, ...
                    self._log('warning', f'API 调用失败 ({type(e).__name__})，{wait_time}秒后重试...')
                    time.sleep(wait_time)
                    continue
                self._log('error', f'API 调用失败: {str(e)}')
                raise

        # Only reachable if max_retries is 0 or every attempt was consumed
        # without returning (the loop body otherwise returns or raises).
        self._log('error', f'API 调用失败（已重试 {self.max_retries} 次）: {str(last_error)}')
        raise last_error

    def chat_stream(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> Generator[Dict[str, Any], None, None]:
        """Stream a chat completion, yielding incremental events.

        Yields:
            ``{'type': 'content', 'content': <delta>, 'accumulated': <text>}``
            for each text delta, then
            ``{'type': 'finish', 'finish_reason': ..., 'content': ...,
            'tool_calls': ...}`` when the model signals completion.

        Raises:
            Exception: re-raised SDK error (no retry in streaming mode).
        """
        client = self._get_client()
        request_params = self._build_request_params(
            messages, tools, temperature, max_tokens, stream=True
        )

        try:
            self._log('info', f'开始流式请求到 {self.model}')
            stream = client.chat.completions.create(**request_params)

            collected_content = ""
            collected_tool_calls = []

            for chunk in stream:
                if not chunk.choices:
                    continue
                choice = chunk.choices[0]
                delta = choice.delta

                # Text content: forward the delta and the running total.
                if delta.content:
                    collected_content += delta.content
                    yield {
                        'type': 'content',
                        'content': delta.content,
                        'accumulated': collected_content
                    }

                # Tool calls arrive as indexed fragments across chunks.
                if delta.tool_calls:
                    for tool_call in delta.tool_calls:
                        # Pad with placeholders up to the reported index so
                        # a fragment for index N always lands in slot N.
                        # (The original single conditional append could
                        # misplace fragments if indexes arrived out of
                        # order or with gaps.)
                        while tool_call.index >= len(collected_tool_calls):
                            collected_tool_calls.append({
                                'id': '',
                                'type': 'function',
                                'function': {
                                    'name': '',
                                    'arguments': ''
                                }
                            })

                        tc = collected_tool_calls[tool_call.index]
                        if tool_call.id:
                            tc['id'] = tool_call.id
                        if tool_call.function:
                            if tool_call.function.name:
                                tc['function']['name'] = tool_call.function.name
                            if tool_call.function.arguments:
                                # Arguments stream as JSON text fragments;
                                # concatenate until the stream finishes.
                                tc['function']['arguments'] += tool_call.function.arguments

                # Completion signal: emit the aggregated result.
                if choice.finish_reason:
                    yield {
                        'type': 'finish',
                        'finish_reason': choice.finish_reason,
                        'content': collected_content,
                        'tool_calls': collected_tool_calls if collected_tool_calls else None
                    }

        except Exception as e:
            self._log('error', f'流式 API 调用失败: {str(e)}')
            raise

    def _parse_response(self, response) -> Dict[str, Any]:
        """Normalize an OpenAI SDK response object into a plain dict.

        Returns:
            Dict with keys ``content``, ``finish_reason``, ``tool_calls``
            (list of tool-call dicts or None) and ``usage`` (token counts
            or None).
        """
        choice = response.choices[0]
        message = choice.message

        result = {
            'content': message.content or '',
            'finish_reason': choice.finish_reason,
            'tool_calls': None,
            'usage': None
        }

        if message.tool_calls:
            result['tool_calls'] = [
                {
                    'id': tc.id,
                    'type': tc.type,
                    'function': {
                        'name': tc.function.name,
                        'arguments': tc.function.arguments
                    }
                }
                for tc in message.tool_calls
            ]

        if response.usage:
            result['usage'] = {
                'prompt_tokens': response.usage.prompt_tokens,
                'completion_tokens': response.usage.completion_tokens,
                'total_tokens': response.usage.total_tokens
            }

        return result

    def continue_with_tool_result(
        self,
        messages: List[Dict[str, str]],
        tool_call_id: str,
        tool_name: str,
        tool_result: str,
        **kwargs
    ) -> Dict[str, Any]:
        """Continue the conversation after executing a tool.

        Appends a ``role: tool`` message carrying the tool's output and
        re-invokes ``chat`` with the same tool schemas.

        Args:
            messages: Prior message list.
            tool_call_id: ID of the tool call being answered.
            tool_name: Name of the executed tool.
            tool_result: Tool execution result (string payload).

        Returns:
            The next AI response (see ``chat``).
        """
        messages_with_result = messages + [{
            'role': 'tool',
            'tool_call_id': tool_call_id,
            'name': tool_name,
            'content': tool_result
        }]

        # NOTE(review): self._tools is presumably populated by the base
        # class when tools are registered — confirm against BaseAIProvider.
        return self.chat(messages_with_result, tools=self._tools, **kwargs)