# -*- coding: utf-8 -*-
"""
Claude (Anthropic) 提供商实现

Claude API 格式与 OpenAI 略有不同，需要适配
"""

from typing import Dict, List, Any, Generator
import logging
import json

from .base import BaseAIProvider

logger = logging.getLogger(__name__)


class ClaudeProvider(BaseAIProvider):
    """Anthropic Claude API provider.

    Translates OpenAI-style dictionaries (``role``/``content`` messages,
    ``function`` tool definitions, ``tool_calls`` results) to the Anthropic
    Messages API format and converts the responses back.
    """

    # Output-token cap sent when the caller does not supply ``max_tokens``.
    _DEFAULT_MAX_TOKENS = 4096

    @property
    def provider_type(self) -> str:
        """Short machine-readable identifier of this provider."""
        return 'claude'

    @property
    def provider_name(self) -> str:
        """Human-readable provider name."""
        return 'Claude (Anthropic)'

    def get_default_base_url(self) -> str:
        """Return the base URL that is treated as "no custom endpoint"."""
        return 'https://api.anthropic.com/v1'

    def get_default_model(self) -> str:
        """Return the model name used by default."""
        return 'claude-3-5-sonnet-20241022'

    def get_available_models(self) -> List[str]:
        """Return the list of model names offered by this provider."""
        return [
            'claude-3-5-sonnet-20241022',
            'claude-3-opus-20240229',
            'claude-3-sonnet-20240229',
            'claude-3-haiku-20240307'
        ]

    def _get_client(self):
        """Create an Anthropic SDK client.

        A custom ``base_url`` (e.g. a proxy such as AnyRouter) is forwarded to
        the SDK only when it differs from :meth:`get_default_base_url`;
        otherwise the SDK's own default endpoint is used.

        Returns:
            An ``anthropic.Anthropic`` client instance.

        Raises:
            ImportError: If the ``anthropic`` package is not installed.
        """
        # Only the import itself is guarded, so an ImportError raised while
        # constructing the client is not mislabelled as a missing package.
        try:
            from anthropic import Anthropic
        except ImportError as exc:
            raise ImportError("请安装 anthropic 库: pip install anthropic") from exc

        custom_base_url = None
        # Compare without trailing slashes so that "<default>/" is not
        # mistaken for a custom endpoint.
        default_url = self.get_default_base_url().rstrip('/')
        if self.base_url and self.base_url.rstrip('/') != default_url:
            custom_base_url = self.base_url
            self._log('info', f'使用自定义 Base URL: {custom_base_url}')

        return Anthropic(
            api_key=self.api_key,
            base_url=custom_base_url,
            timeout=self.timeout
        )

    def _format_assistant_content(self, msg: Dict[str, Any]):
        """Convert one assistant message's payload to Claude content.

        Args:
            msg: An assistant message, optionally carrying OpenAI-style
                ``tool_calls``.

        Returns:
            The original ``content`` when the message has no tool calls,
            otherwise a list of content blocks: the text (if any) followed by
            one ``tool_use`` block per tool call.
        """
        tool_calls = msg.get('tool_calls')
        if not tool_calls:
            return msg.get('content', '')

        blocks = []
        content = msg.get('content')
        if isinstance(content, str):
            # Skip empty text: a tool-calling turn may have no text at all.
            if content:
                blocks.append({'type': 'text', 'text': content})
        elif content:
            # Content that is already a list of blocks is passed through.
            blocks.extend(content)

        for call in tool_calls:
            func = call.get('function', {})
            name = func.get('name', '')
            arguments = func.get('arguments')
            if isinstance(arguments, str):
                # Arguments arrive as a JSON string; Claude wants an object.
                try:
                    arguments = json.loads(arguments) if arguments.strip() else {}
                except ValueError:
                    self._log('error', f'工具调用参数不是合法 JSON，已使用空参数: {name}')
                    arguments = {}
            if not isinstance(arguments, dict):
                arguments = {}
            blocks.append({
                'type': 'tool_use',
                'id': call.get('id', ''),
                'name': name,
                'input': arguments
            })
        return blocks

    def _format_messages(self, messages: List[Dict[str, str]]) -> tuple:
        """Convert messages to the Claude format.

        Claude takes the system prompt as a separate request field, so
        ``system`` messages are pulled out of the conversation. Several system
        messages are joined with a blank line instead of keeping only the
        last one. Assistant ``tool_calls`` become ``tool_use`` blocks so that
        a following ``tool`` message (sent as a ``tool_result`` block in a
        user turn) refers to a tool use that is present in the conversation.

        Args:
            messages: Messages with ``role`` and ``content`` keys; a missing
                role is treated as ``user``.

        Returns:
            ``(system_prompt, messages_list)`` where ``system_prompt`` is
            ``None`` when no non-empty system message was found.
        """
        system_parts = []
        formatted_messages = []

        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')

            if role == 'system':
                if content:
                    system_parts.append(content)
            elif role == 'assistant':
                formatted_messages.append({
                    'role': 'assistant',
                    'content': self._format_assistant_content(msg)
                })
            elif role == 'tool':
                # Tool results travel inside a user turn.
                formatted_messages.append({
                    'role': 'user',
                    'content': [{
                        'type': 'tool_result',
                        'tool_use_id': msg.get('tool_call_id', ''),
                        'content': content
                    }]
                })
            else:  # user (and any unrecognised role)
                formatted_messages.append({
                    'role': 'user',
                    'content': content
                })

        if not system_parts:
            system_prompt = None
        elif len(system_parts) == 1:
            system_prompt = system_parts[0]
        else:
            system_prompt = '\n\n'.join(str(part) for part in system_parts)

        return system_prompt, formatted_messages

    def _format_tools(self, tools: List[Dict]) -> List[Dict]:
        """Convert ``function`` tool definitions to the Claude tool format.

        Entries whose ``type`` is not ``'function'`` are skipped.

        Args:
            tools: Tool definitions with ``type`` and ``function`` keys.

        Returns:
            A list of dicts with ``name``, ``description`` and
            ``input_schema`` keys.
        """
        claude_tools = []
        for tool in tools:
            if tool.get('type') != 'function':
                continue
            func = tool.get('function', {})
            claude_tools.append({
                'name': func.get('name', ''),
                'description': func.get('description', ''),
                'input_schema': func.get('parameters', {})
            })
        return claude_tools

    def _build_request_params(
        self,
        messages: List[Dict[str, str]],
        tools,
        temperature: float,
        max_tokens
    ) -> Dict[str, Any]:
        """Build the keyword arguments shared by :meth:`chat` and :meth:`chat_stream`.

        ``system`` and ``tools`` are included only when they are non-empty.
        """
        system_prompt, formatted_messages = self._format_messages(messages)

        request_params = {
            'model': self.model,
            'messages': formatted_messages,
            'max_tokens': max_tokens or self._DEFAULT_MAX_TOKENS,
            'temperature': temperature,
        }

        if system_prompt:
            request_params['system'] = system_prompt

        if tools:
            request_params['tools'] = self._format_tools(tools)

        return request_params

    def chat(
        self,
        messages: List[Dict[str, str]],
        tools: List[Dict] = None,
        temperature: float = 0.7,
        max_tokens: int = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Run a (non-streaming) chat completion.

        Args:
            messages: Conversation messages.
            tools: Optional ``function`` tool definitions.
            temperature: Sampling temperature.
            max_tokens: Output-token cap; 4096 is used when falsy.
            **kwargs: Accepted but not forwarded to the API.

        Returns:
            The dict produced by :meth:`_parse_response`.

        Raises:
            Exception: Any error raised by the SDK call is logged and
                re-raised unchanged.
        """
        client = self._get_client()
        request_params = self._build_request_params(
            messages, tools, temperature, max_tokens
        )

        try:
            self._log('info', f'发送请求到 {self.model}')
            response = client.messages.create(**request_params)
            return self._parse_response(response)

        except Exception as e:
            self._log('error', f'API 调用失败: {str(e)}')
            raise

    def chat_stream(
        self,
        messages: List[Dict[str, str]],
        tools: List[Dict] = None,
        temperature: float = 0.7,
        max_tokens: int = None,
        **kwargs
    ) -> Generator[Dict[str, Any], None, None]:
        """Run a streaming chat completion.

        Args:
            messages: Conversation messages.
            tools: Optional ``function`` tool definitions.
            temperature: Sampling temperature.
            max_tokens: Output-token cap; 4096 is used when falsy.
            **kwargs: Accepted but not forwarded to the API.

        Yields:
            ``{'type': 'content', 'content': <delta>, 'accumulated': <text so
            far>}`` for every text delta, then one ``{'type': 'finish', ...}``
            dict on ``message_stop`` carrying the full text and the collected
            tool calls (``None`` when there were none).

        Raises:
            Exception: Any error raised by the SDK is logged and re-raised
                unchanged.
        """
        client = self._get_client()

        # NOTE: no 'stream' flag here. ``messages.stream()`` is itself the
        # streaming entry point and does not take that keyword.
        request_params = self._build_request_params(
            messages, tools, temperature, max_tokens
        )

        try:
            self._log('info', f'开始流式请求到 {self.model}')

            with client.messages.stream(**request_params) as stream:
                collected_content = ""
                collected_tool_calls = []
                # Content-block index -> tool call, so argument fragments are
                # appended to the call they belong to.
                tool_calls_by_index = {}

                for event in stream:
                    if event.type == 'content_block_delta':
                        delta = event.delta
                        if hasattr(delta, 'text'):
                            collected_content += delta.text
                            yield {
                                'type': 'content',
                                'content': delta.text,
                                'accumulated': collected_content
                            }
                        elif hasattr(delta, 'partial_json'):
                            # Incremental JSON of a tool call's arguments.
                            call = tool_calls_by_index.get(
                                getattr(event, 'index', None)
                            )
                            if call is None and collected_tool_calls:
                                call = collected_tool_calls[-1]
                            if call is not None:
                                call['function']['arguments'] += delta.partial_json

                    elif event.type == 'content_block_start':
                        block = event.content_block
                        if block.type == 'tool_use':
                            call = {
                                'id': block.id,
                                'type': 'function',
                                'function': {
                                    'name': block.name,
                                    'arguments': ''
                                }
                            }
                            collected_tool_calls.append(call)
                            tool_calls_by_index[getattr(event, 'index', None)] = call

                    elif event.type == 'message_stop':
                        # A tool called without arguments streams no JSON;
                        # report an empty object so the value stays parseable.
                        for call in collected_tool_calls:
                            if not call['function']['arguments']:
                                call['function']['arguments'] = '{}'
                        # NOTE(review): finish_reason is always 'stop' here,
                        # whereas chat() reports Claude's stop_reason.
                        yield {
                            'type': 'finish',
                            'finish_reason': 'stop',
                            'content': collected_content,
                            'tool_calls': collected_tool_calls if collected_tool_calls else None
                        }

        except Exception as e:
            self._log('error', f'流式 API 调用失败: {str(e)}')
            raise

    def _parse_response(self, response) -> Dict[str, Any]:
        """Convert a Claude response object to the provider result dict.

        Args:
            response: Object exposing ``stop_reason``, ``content`` (blocks
                with a ``type``) and ``usage``.

        Returns:
            A dict with ``content`` (concatenated text blocks),
            ``finish_reason`` (the response's ``stop_reason``), ``tool_calls``
            (list of ``function`` calls or ``None``) and ``usage`` (token
            counts or ``None``).
        """
        result = {
            'content': '',
            'finish_reason': response.stop_reason,
            'tool_calls': None,
            'usage': None
        }

        # Walk the content blocks: text is concatenated, tool uses collected.
        for block in response.content:
            if block.type == 'text':
                result['content'] += block.text
            elif block.type == 'tool_use':
                if result['tool_calls'] is None:
                    result['tool_calls'] = []
                # Dict inputs are serialised to a JSON string; anything else
                # is passed through untouched.
                arguments = block.input
                if isinstance(arguments, dict):
                    arguments = json.dumps(arguments)
                result['tool_calls'].append({
                    'id': block.id,
                    'type': 'function',
                    'function': {
                        'name': block.name,
                        'arguments': arguments
                    }
                })

        # Usage statistics.
        usage = response.usage
        if usage:
            result['usage'] = {
                'prompt_tokens': usage.input_tokens,
                'completion_tokens': usage.output_tokens,
                'total_tokens': usage.input_tokens + usage.output_tokens
            }

        return result

    def continue_with_tool_result(
        self,
        messages: List[Dict[str, str]],
        tool_call_id: str,
        tool_name: str,
        tool_result: str,
        **kwargs
    ) -> Dict[str, Any]:
        """Continue the conversation with the result of a tool execution.

        Appends a ``tool`` message to a copy of ``messages`` (the input list
        is not modified) and calls :meth:`chat`.

        Args:
            messages: Conversation so far.
            tool_call_id: Id of the tool call being answered.
            tool_name: Name of the executed tool.
            tool_result: Output of the tool.
            **kwargs: Forwarded to :meth:`chat`. An explicit ``tools`` entry
                takes precedence over ``self._tools``.

        Returns:
            The dict returned by :meth:`chat`.
        """
        messages_with_result = messages + [{
            'role': 'tool',
            'tool_call_id': tool_call_id,
            'name': tool_name,
            'content': tool_result
        }]

        # Pop 'tools' so it is not passed twice, which would be a TypeError.
        # Falls back to None when the instance has no ``_tools`` attribute.
        tools = kwargs.pop('tools', getattr(self, '_tools', None))

        return self.chat(messages_with_result, tools=tools, **kwargs)