# -*- coding: utf-8 -*-
"""
AI 提供商基类

定义所有 AI 提供商必须实现的接口
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Generator
import logging

logger = logging.getLogger(__name__)


class BaseAIProvider(ABC):
    """AI 提供商抽象基类
    
    所有 AI 提供商都必须继承此类并实现抽象方法
    """

    def __init__(self, api_key: str, base_url: str = None, model: str = None, **kwargs):
        """初始化 AI 提供商
        
        Args:
            api_key: API 密钥
            base_url: API 基础 URL（可选，使用默认值）
            model: 模型名称（可选，使用默认值）
            **kwargs: 其他配置参数
        """
        self.api_key = api_key
        self.base_url = base_url or self.get_default_base_url()
        self.model = model or self.get_default_model()
        self.timeout = kwargs.get('timeout', 300)  # 增加到 5 分钟
        self.max_retries = kwargs.get('max_retries', 5)  # 增加重试次数
        
        # 工具定义
        self._tools = []
        
        # 日志回调
        self.log_callback = kwargs.get('log_callback')
        
    @property
    @abstractmethod
    def provider_type(self) -> str:
        """返回提供商类型标识"""
        pass

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """返回提供商显示名称"""
        pass

    @abstractmethod
    def get_default_base_url(self) -> str:
        """返回默认 API Base URL"""
        pass

    @abstractmethod
    def get_default_model(self) -> str:
        """返回默认模型名称"""
        pass

    @abstractmethod
    def get_available_models(self) -> List[str]:
        """返回可用模型列表"""
        pass

    @abstractmethod
    def chat(
        self,
        messages: List[Dict[str, str]],
        tools: List[Dict] = None,
        temperature: float = 0.7,
        max_tokens: int = None,
        **kwargs
    ) -> Dict[str, Any]:
        """执行聊天补全
        
        Args:
            messages: 消息列表，格式为 [{"role": "user/assistant/system", "content": "..."}]
            tools: 工具定义列表（用于 Function Calling）
            temperature: 温度参数
            max_tokens: 最大 token 数
            **kwargs: 其他参数
            
        Returns:
            响应字典，包含:
            - content: 文本响应内容
            - tool_calls: 工具调用列表（如果有）
            - usage: token 使用统计
            - finish_reason: 结束原因
        """
        pass

    @abstractmethod
    def chat_stream(
        self,
        messages: List[Dict[str, str]],
        tools: List[Dict] = None,
        temperature: float = 0.7,
        max_tokens: int = None,
        **kwargs
    ) -> Generator[Dict[str, Any], None, None]:
        """流式聊天补全
        
        Args:
            同 chat 方法
            
        Yields:
            响应片段字典
        """
        pass

    def test_connection(self) -> Dict[str, Any]:
        """测试 API 连接
        
        Returns:
            测试结果字典:
            - success: 是否成功
            - message: 结果消息
            - latency: 响应延迟（毫秒）
        """
        import time
        
        try:
            start_time = time.time()
            
            # 发送简单测试消息
            response = self.chat(
                messages=[{"role": "user", "content": "Hi, just testing. Reply with 'OK'."}],
                max_tokens=10
            )
            
            latency = int((time.time() - start_time) * 1000)
            
            if response.get('content'):
                return {
                    'success': True,
                    'message': f'连接成功，模型: {self.model}',
                    'latency': latency,
                    'model': self.model
                }
            else:
                return {
                    'success': False,
                    'message': '连接成功但响应为空',
                    'latency': latency
                }
                
        except Exception as e:
            logger.error(f"测试连接失败: {str(e)}")
            return {
                'success': False,
                'message': f'连接失败: {str(e)}',
                'latency': None
            }

    def set_tools(self, tools: List[Dict]):
        """设置可用工具
        
        Args:
            tools: 工具定义列表
        """
        self._tools = tools

    def get_tools(self) -> List[Dict]:
        """获取当前工具列表"""
        return self._tools

    def set_log_callback(self, callback):
        """设置日志回调函数
        
        Args:
            callback: 回调函数，接受 (level, message) 参数
        """
        self.log_callback = callback

    def _log(self, level: str, message: str):
        """记录日志
        
        Args:
            level: 日志级别 (info, warning, error)
            message: 日志消息
        """
        if self.log_callback:
            self.log_callback(level, message)
        
        log_func = getattr(logger, level, logger.info)
        log_func(f"[{self.provider_name}] {message}")

    def _format_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """格式化消息（子类可重写以适配不同格式）
        
        Args:
            messages: 原始消息列表
            
        Returns:
            格式化后的消息列表
        """
        return messages

    def _format_tools(self, tools: List[Dict]) -> List[Dict]:
        """格式化工具定义（子类可重写以适配不同格式）
        
        Args:
            tools: 原始工具定义列表
            
        Returns:
            格式化后的工具定义列表
        """
        return tools

    def _parse_response(self, response: Any) -> Dict[str, Any]:
        """解析 API 响应（子类可重写）
        
        Args:
            response: 原始 API 响应
            
        Returns:
            标准化的响应字典
        """
        return response


class ToolDefinition:
    """工具定义辅助类"""
    
    @staticmethod
    def create(
        name: str,
        description: str,
        parameters: Dict[str, Any],
        required: List[str] = None
    ) -> Dict:
        """创建工具定义
        
        Args:
            name: 工具名称
            description: 工具描述
            parameters: 参数定义
            required: 必需参数列表
            
        Returns:
            工具定义字典（OpenAI 格式）
        """
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": parameters,
                    "required": required or []
                }
            }
        }

    @staticmethod
    def run_command() -> Dict:
        """运行命令工具定义"""
        return ToolDefinition.create(
            name="run_command",
            description="在指定目录执行 shell 命令",
            parameters={
                "command": {
                    "type": "string",
                    "description": "要执行的 shell 命令"
                },
                "cwd": {
                    "type": "string",
                    "description": "命令执行的工作目录（可选）"
                },
                "timeout": {
                    "type": "integer",
                    "description": "命令超时时间（秒），默认 60"
                }
            },
            required=["command"]
        )

    @staticmethod
    def read_file() -> Dict:
        """读取文件工具定义"""
        return ToolDefinition.create(
            name="read_file",
            description="读取指定路径的文件内容",
            parameters={
                "path": {
                    "type": "string",
                    "description": "文件的绝对路径"
                }
            },
            required=["path"]
        )

    @staticmethod
    def write_file() -> Dict:
        """写入文件工具定义"""
        return ToolDefinition.create(
            name="write_file",
            description="将内容写入指定路径的文件（自动创建目录）",
            parameters={
                "path": {
                    "type": "string",
                    "description": "文件的绝对路径"
                },
                "content": {
                    "type": "string",
                    "description": "要写入的文件内容"
                }
            },
            required=["path", "content"]
        )

    @staticmethod
    def list_directory() -> Dict:
        """列出目录工具定义"""
        return ToolDefinition.create(
            name="list_directory",
            description="列出指定目录下的文件和子目录",
            parameters={
                "path": {
                    "type": "string",
                    "description": "目录的绝对路径"
                }
            },
            required=["path"]
        )

    @staticmethod
    def save_memory() -> Dict:
        """保存记忆工具定义"""
        return ToolDefinition.create(
            name="save_memory",
            description="保存关键上下文信息到记忆存储（用于跨阶段保持一致性）",
            parameters={
                "key": {
                    "type": "string",
                    "description": "记忆键名（建议使用固定key，如 challenge_name, exploit_chain, output_dir 等）"
                },
                "content": {
                    "type": "string",
                    "description": "要保存的内容（建议结构化、简短、可复用）"
                }
            },
            required=["key", "content"]
        )

    @staticmethod
    def read_memory() -> Dict:
        """读取记忆工具定义"""
        return ToolDefinition.create(
            name="read_memory",
            description="读取指定 key 的记忆内容",
            parameters={
                "key": {
                    "type": "string",
                    "description": "要读取的记忆键名"
                }
            },
            required=["key"]
        )

    @staticmethod
    def list_memories() -> Dict:
        """列出记忆工具定义"""
        return ToolDefinition.create(
            name="list_memories",
            description="列出当前已保存的所有记忆 key",
            parameters={},
            required=[]
        )

    @staticmethod
    def clear_memories() -> Dict:
        """清空记忆工具定义"""
        return ToolDefinition.create(
            name="clear_memories",
            description="清空所有记忆或按 key 前缀清空（谨慎使用）",
            parameters={
                "prefix": {
                    "type": "string",
                    "description": "可选：仅清空以该前缀开头的记忆 key"
                }
            },
            required=[]
        )

    @staticmethod
    def get_ctf_tools() -> List[Dict]:
        """获取 CTF 题目生成所需的所有工具"""
        return [
            ToolDefinition.run_command(),
            ToolDefinition.read_file(),
            ToolDefinition.write_file(),
            ToolDefinition.list_directory(),
        ]
