# -*- coding: utf-8 -*-
"""
工具执行器

执行 AI 调用的工具，包括命令执行、文件操作等
"""

import os
import subprocess
import json
import logging
from typing import Dict, Any, Callable, Optional
from pathlib import Path

from .sandbox import SandboxConfig

logger = logging.getLogger(__name__)


class ToolExecutor:
    """工具执行器
    
    负责执行 AI 调用的各种工具，并确保安全性
    """

    def __init__(self, sandbox: SandboxConfig = None, log_callback: Callable = None, working_dir: str = None):
        """初始化工具执行器
        
        Args:
            sandbox: 沙箱配置
            log_callback: 日志回调函数，接受 (level, message) 参数
            working_dir: 默认工作目录
        """
        self.sandbox = sandbox or SandboxConfig()
        self.log_callback = log_callback

        # 设置默认工作目录
        if working_dir:
            self.working_dir = working_dir
        else:
            # 如果没有指定工作目录，使用 ctf/ge10 作为默认目录
            # 从 executor.py 向上查找：executor.py -> tools -> ai_driver -> services -> app -> ctf
            # 路径: ctf/app/services/ai_driver/tools/executor.py
            # parent(1): tools, parent(2): ai_driver, parent(3): services, parent(4): app, parent(5): ctf
            ctf_dir = Path(__file__).parent.parent.parent.parent.parent
            self.working_dir = str((ctf_dir / 'ge10').resolve())
        
        # 工具名称到处理函数的映射
        self._handlers = {
            'run_command': self._handle_run_command,
            'read_file': self._handle_read_file,
            'write_file': self._handle_write_file,
            'list_directory': self._handle_list_directory,
        }

    def execute(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """执行工具
        
        Args:
            tool_name: 工具名称
            arguments: 工具参数
            
        Returns:
            执行结果字典:
            - success: 是否成功
            - result: 执行结果或错误信息
        """
        handler = self._handlers.get(tool_name)
        
        if not handler:
            return {
                'success': False,
                'result': f"未知的工具: {tool_name}"
            }
        
        try:
            self._log('info', f"执行工具: {tool_name}")
            result = handler(arguments)
            return result
        except Exception as e:
            self._log('error', f"工具执行失败: {tool_name} - {str(e)}")
            return {
                'success': False,
                'result': f"执行失败: {str(e)}"
            }

    def execute_tool_call(self, tool_call: Dict) -> str:
        """执行 AI 返回的工具调用
        
        Args:
            tool_call: AI 返回的工具调用字典，包含:
                - id: 工具调用 ID
                - function: {name, arguments}
                
        Returns:
            工具执行结果字符串（用于返回给 AI）
        """
        func = tool_call.get('function', {})
        tool_name = func.get('name', '')
        
        # 解析参数
        try:
            arguments = json.loads(func.get('arguments', '{}'))
        except json.JSONDecodeError:
            arguments = {}
        
        # 执行工具
        result = self.execute(tool_name, arguments)
        
        # 格式化结果
        if result['success']:
            output = result['result']
        else:
            output = f"错误: {result['result']}"
        
        # 最终截断保护（防止超长输出撑爆上下文）
        max_len = self.sandbox.max_output_size if self.sandbox else 8000
        if len(output) > max_len:
            output = output[:max_len] + f"\n... [输出已截断，原长度 {len(output)} 字符]"
        return output

    def _log(self, level: str, message: str):
        """记录日志"""
        if self.log_callback:
            self.log_callback(level, message)
        
        log_func = getattr(logger, level, logger.info)
        log_func(f"[ToolExecutor] {message}")

    def _handle_run_command(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """处理命令执行
        
        Args:
            args: {command, cwd, timeout}
        """
        command = args.get('command', '')
        cwd = args.get('cwd') or self.working_dir  # 使用默认工作目录
        timeout = args.get('timeout', self.sandbox.command_timeout)
        
        # 安全检查（传递工作目录用于 data/ 路径验证）
        allowed, reason = self.sandbox.is_command_allowed(command, working_dir=cwd)
        if not allowed:
            return {
                'success': False,
                'result': f"命令被拒绝: {reason}"
            }
        
        # 检查工作目录
        if cwd and not self.sandbox.is_path_allowed(cwd):
            return {
                'success': False,
                'result': f"工作目录不在允许范围内: {cwd}"
            }
        
        # 获取超时时间
        actual_timeout = self.sandbox.get_timeout(command)
        if timeout:
            actual_timeout = min(timeout, actual_timeout)
        
        try:
            self._log('info', f"执行命令: {command[:100]}...")
            
            # 检查是否是执行 data/ 目录内的可执行文件
            executable_paths = self.sandbox._extract_executable_paths(command, cwd)
            is_data_executable = any(
                self.sandbox.is_data_executable_path_allowed(path, cwd)
                for path in executable_paths
            )
            
            # 设置环境变量
            env = os.environ.copy()
            env['PAGER'] = 'cat'  # 禁用分页
            
            # 如果是 data/ 目录内的可执行文件，可以设置额外的环境变量限制
            if is_data_executable:
                # 设置受限环境变量（如果脚本支持的话）
                env['RESTRICTED_MODE'] = 'true'
                env['ALLOWED_WRITE_DIRS'] = '/tmp,output'
                self._log('info', f"执行 data/ 目录内的可执行文件，应用受限环境")
            
            # 执行命令
            result = subprocess.run(
                command,
                shell=True,
                cwd=cwd,
                capture_output=True,
                text=True,
                timeout=actual_timeout,
                env=env
            )
            
            # 组合输出
            output = result.stdout
            if result.stderr:
                output += f"\n[stderr]\n{result.stderr}"
            
            # 截断过长的输出
            if len(output) > self.sandbox.max_output_size:
                output = output[:self.sandbox.max_output_size] + "\n... [输出已截断]"
            
            return {
                'success': result.returncode == 0,
                'result': output or "(无输出)",
                'return_code': result.returncode
            }
            
        except subprocess.TimeoutExpired:
            return {
                'success': False,
                'result': f"命令执行超时 ({actual_timeout}秒)"
            }
        except Exception as e:
            return {
                'success': False,
                'result': f"命令执行错误: {str(e)}"
            }

    def _handle_read_file(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """处理文件读取
        
        Args:
            args: {path}
        """
        path = args.get('path', '')
        
        if not path:
            return {
                'success': False,
                'result': "未指定文件路径"
            }
        
        # 处理相对路径：基于工作目录
        if not os.path.isabs(path):
            path = os.path.join(self.working_dir, path)
        
        # 安全检查
        if not self.sandbox.is_path_allowed(path):
            return {
                'success': False,
                'result': f"文件路径不在允许范围内: {path}"
            }
        
        try:
            file_path = Path(path)
            
            if not file_path.exists():
                return {
                    'success': False,
                    'result': f"文件不存在: {path}"
                }
            
            if not file_path.is_file():
                return {
                    'success': False,
                    'result': f"不是文件: {path}"
                }
            
            # 检查文件大小
            file_size = file_path.stat().st_size
            if file_size > self.sandbox.max_file_size:
                return {
                    'success': False,
                    'result': f"文件过大: {file_size} 字节 (最大 {self.sandbox.max_file_size})"
                }
            
            # 读取文件
            content = file_path.read_text(encoding='utf-8')
            
            self._log('info', f"读取文件: {path} ({len(content)} 字符)")
            
            return {
                'success': True,
                'result': content
            }
            
        except UnicodeDecodeError:
            return {
                'success': False,
                'result': f"文件不是有效的 UTF-8 文本: {path}"
            }
        except Exception as e:
            return {
                'success': False,
                'result': f"读取文件失败: {str(e)}"
            }

    def _handle_write_file(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """处理文件写入
        
        Args:
            args: {path, content}
        """
        path = args.get('path', '')
        content = args.get('content', '')
        
        if not path:
            return {
                'success': False,
                'result': "未指定文件路径"
            }
        
        # 处理相对路径：基于工作目录
        if not os.path.isabs(path):
            path = os.path.join(self.working_dir, path)
        
        # 写入安全检查（使用专门的写入路径验证）
        if not self.sandbox.is_write_path_allowed(path):
            return {
                'success': False,
                'result': f"文件路径不在允许写入范围内: {path}"
            }
        
        # 检查内容大小
        if len(content.encode('utf-8')) > self.sandbox.max_file_size:
            return {
                'success': False,
                'result': f"内容过大 (最大 {self.sandbox.max_file_size} 字节)"
            }
        
        try:
            file_path = Path(path)
            
            # 自动创建目录
            file_path.parent.mkdir(parents=True, exist_ok=True)
            
            # 写入文件
            file_path.write_text(content, encoding='utf-8')
            
            self._log('info', f"写入文件: {path} ({len(content)} 字符)")
            
            return {
                'success': True,
                'result': f"文件已保存: {path}"
            }
            
        except Exception as e:
            return {
                'success': False,
                'result': f"写入文件失败: {str(e)}"
            }

    def _handle_list_directory(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """处理目录列表
        
        Args:
            args: {path}
        """
        path = args.get('path', '')
        
        if not path:
            return {
                'success': False,
                'result': "未指定目录路径"
            }
        
        # 处理相对路径：基于工作目录
        if not os.path.isabs(path):
            path = os.path.join(self.working_dir, path)
        
        # 安全检查
        if not self.sandbox.is_path_allowed(path):
            return {
                'success': False,
                'result': f"目录路径不在允许范围内: {path}"
            }
        
        try:
            dir_path = Path(path)
            
            if not dir_path.exists():
                return {
                    'success': False,
                    'result': f"目录不存在: {path}"
                }
            
            if not dir_path.is_dir():
                return {
                    'success': False,
                    'result': f"不是目录: {path}"
                }
            
            # 列出内容
            items = []
            for item in sorted(dir_path.iterdir()):
                if item.is_dir():
                    items.append(f"📁 {item.name}/")
                else:
                    size = item.stat().st_size
                    items.append(f"📄 {item.name} ({size} bytes)")
            
            result = f"目录: {path}\n" + "\n".join(items) if items else f"目录为空: {path}"
            
            self._log('info', f"列出目录: {path} ({len(items)} 项)")
            
            return {
                'success': True,
                'result': result
            }
            
        except Exception as e:
            return {
                'success': False,
                'result': f"列出目录失败: {str(e)}"
            }

    def register_handler(self, tool_name: str, handler: Callable):
        """注册自定义工具处理器
        
        Args:
            tool_name: 工具名称
            handler: 处理函数，接受 args 字典，返回 {success, result} 字典
        """
        self._handlers[tool_name] = handler
        self._log('info', f"注册工具处理器: {tool_name}")
