# -*- coding: utf-8 -*-
"""
AI 提供商工厂

根据配置创建对应的 AI 提供商实例
"""

from typing import Optional
import logging

from .base import BaseAIProvider
from .openai import OpenAIProvider
from .claude import ClaudeProvider

logger = logging.getLogger(__name__)


class AIProviderFactory:
    """AI 提供商工厂类"""

    # 提供商类型到类的映射
    # 注意：大部分提供商都兼容 OpenAI API 格式，使用 OpenAIProvider
    # base_url 会根据 provider_type 自动从 AIProviderType 获取
    _providers = {
        'openai': OpenAIProvider,
        'deepseek': OpenAIProvider,  # DeepSeek 兼容 OpenAI 格式
        'claude': ClaudeProvider,
        'qwen': OpenAIProvider,  # 通义千问兼容 OpenAI 格式
        'zhipu': OpenAIProvider,  # 智谱 GLM 兼容 OpenAI 格式
        'groq': OpenAIProvider,  # Groq 兼容 OpenAI 格式
        'together': OpenAIProvider,  # Together AI 兼容 OpenAI 格式
        'openrouter': OpenAIProvider,  # OpenRouter 兼容 OpenAI 格式
        'gemini': OpenAIProvider,  # Gemini 兼容 OpenAI 格式
        'mistral': OpenAIProvider,  # Mistral 兼容 OpenAI 格式
        'siliconflow': OpenAIProvider,  # SiliconFlow 兼容 OpenAI 格式
        # 'anyrouter': 使用 CLI 模式，不通过 API 工厂
        'custom': OpenAIProvider,  # 自定义也使用 OpenAI 兼容格式
        'custom_openai': OpenAIProvider,  # 自定义 OpenAI 兼容 API
        'custom_anthropic': ClaudeProvider,  # 自定义 Anthropic 兼容 API
    }

    @classmethod
    def create(
        cls,
        provider_type: str,
        api_key: str,
        base_url: str = None,
        model: str = None,
        **kwargs
    ) -> BaseAIProvider:
        """创建 AI 提供商实例
        
        Args:
            provider_type: 提供商类型 (openai, deepseek, claude, qwen, zhipu, custom)
            api_key: API 密钥
            base_url: API Base URL（可选）
            model: 模型名称（可选）
            **kwargs: 其他配置参数
            
        Returns:
            AI 提供商实例
            
        Raises:
            ValueError: 不支持的提供商类型
        """
        # AnyRouter、AgentRouter 和 Augment 使用 CLI 模式，需要特殊处理
        if provider_type.lower() in ('anyrouter', 'agentrouter'):
            from ..core.claude_router_service import ClaudeRouterService
            return ClaudeRouterService(
                router_type=provider_type.lower(),
                user_id=kwargs.get('user_id'),
                model=model
            )
        
        if provider_type.lower() == 'augment':
            from ..core.augment_service import AugmentService
            return AugmentService(
                user_id=kwargs.get('user_id'),
                verbose=kwargs.get('verbose', False)
            )
        
        provider_class = cls._providers.get(provider_type.lower())
        
        if not provider_class:
            raise ValueError(f"不支持的 AI 提供商类型: {provider_type}")
        
        logger.info(f"创建 AI 提供商: {provider_type}")
        
        # 对于使用 OpenAI 兼容格式的提供商，如果 base_url 为空，使用该提供商的默认 URL
        # 避免错误地使用 OpenAI 的默认 URL
        if not base_url:
            from app.models.database import AIProviderType
            default_url = AIProviderType.get_default_base_url(provider_type.lower())
            if default_url:
                base_url = default_url
                logger.info(f"使用 {provider_type} 默认 Base URL: {base_url}")
        
        return provider_class(
            api_key=api_key,
            base_url=base_url,
            model=model,
            **kwargs
        )

    @classmethod
    def create_from_config(cls, config) -> BaseAIProvider:
        """从数据库配置创建 AI 提供商实例
        
        Args:
            config: AIProviderConfig 数据库模型实例
            
        Returns:
            AI 提供商实例
        """
        # AnyRouter/AgentRouter/Augment 使用 CLI 模式，需要特殊处理
        if config.provider_type in ('anyrouter', 'agentrouter'):
            from ..core.claude_router_service import ClaudeRouterService
            return ClaudeRouterService(
                router_type=config.provider_type,
                user_id=config.user_id,
                model=config.model
            )
        
        if config.provider_type == 'augment':
            from ..core.augment_service import AugmentService
            return AugmentService(
                user_id=config.user_id,
                verbose=False,
                model=config.model  # 传递模型参数
            )
        
        return cls.create(
            provider_type=config.provider_type,
            api_key=config.get_api_key(),
            base_url=config.base_url,
            model=config.model
        )

    @classmethod
    def create_for_user(cls, user_id: int, ai_config_id: Optional[int] = None) -> Optional[BaseAIProvider]:
        """为用户创建 AI 提供商实例
        
        优先使用用户配置，否则使用系统默认配置
        
        Args:
            user_id: 用户 ID
            ai_config_id: 指定的 AI 配置 ID（可选，如果提供则使用该配置）
            
        Returns:
            AI 提供商实例，如果没有可用配置则返回 None
        """
        from app.models.database import AIProviderConfig, SystemAIConfig
        
        # 如果指定了 ai_config_id，使用该配置
        if ai_config_id:
            config = AIProviderConfig.query.get(ai_config_id)
            if config:
                # 验证配置是否属于当前用户或是系统配置
                if config.user_id == user_id or config.user_id is None:
                    config.mark_used()
                    logger.info(f"使用指定的 AI 配置 ID {ai_config_id}: {config.name} ({config.provider_type})")
                    return cls.create_from_config(config)
                else:
                    logger.warning(f"AI 配置 ID {ai_config_id} 不属于用户 {user_id}，使用默认配置")
            else:
                logger.warning(f"未找到 AI 配置 ID {ai_config_id}，使用默认配置")
        
        # 检查是否强制使用系统 AI
        if SystemAIConfig.force_system_ai():
            config = AIProviderConfig.get_system_default()
        else:
            # 获取用户有效配置
            config = AIProviderConfig.get_effective_config(user_id)
        
        if config:
            # 标记使用
            config.mark_used()
            return cls.create_from_config(config)
        
        logger.warning(f"用户 {user_id} 没有可用的 AI 配置")
        return None

    @classmethod
    def get_supported_providers(cls) -> list:
        """获取支持的提供商列表"""
        return list(cls._providers.keys())

    @classmethod
    def register_provider(cls, provider_type: str, provider_class: type):
        """注册新的提供商类型
        
        Args:
            provider_type: 提供商类型标识
            provider_class: 提供商类（必须继承 BaseAIProvider）
        """
        if not issubclass(provider_class, BaseAIProvider):
            raise TypeError("provider_class 必须继承 BaseAIProvider")
        
        cls._providers[provider_type.lower()] = provider_class
        logger.info(f"注册新的 AI 提供商: {provider_type}")
