# -*- coding: utf-8 -*-
"""
沙箱配置

定义工具执行的安全限制
"""

import os
import re
from typing import List, Set, Optional, Tuple
from pathlib import Path


class SandboxConfig:
    """沙箱配置类
    
    定义工具执行的安全边界
    """

    # 默认允许的命令
    DEFAULT_ALLOWED_COMMANDS = {
        'python3', 'python',
        'docker', 'docker-compose',
        'curl', 'wget',
        'cat', 'head', 'tail', 'less', 'more',
        'ls', 'find', 'grep', 'awk', 'sed',
        'mkdir', 'touch', 'cp', 'mv', 'rm',
        'echo', 'printf',
        'chmod',
        'cd',  # 允许但实际上在子进程中无效
        'pwd',
        'which', 'whereis',
        'pip3', 'pip',
        'npm', 'node',
        'git',
        'sleep',  # 等待命令
        'open',   # macOS 打开应用
        'true', 'false',  # shell 内置
        'test', '[',  # 条件测试
        'date', 'env',  # 环境信息
        'lsof', 'netstat', 'ss',  # 网络/端口检查
        'kill', 'pkill',  # 进程管理
        'ps', 'top',  # 进程查看
        'tar', 'zip', 'unzip', 'gzip', 'gunzip',  # 压缩解压
        'wc', 'sort', 'uniq', 'cut', 'tr',  # 文本处理
        'diff', 'patch',  # 文件比较
        'xargs',  # 命令构建
        'tee',  # 输出重定向
        'timeout',  # 超时控制
        'bash', 'sh',  # shell 解释器
        'perl', 'ruby',  # 其他脚本解释器
    }
    
    # 已知的解释器（用于识别脚本执行命令）
    KNOWN_INTERPRETERS = {
        'python3', 'python', 'python2',
        'bash', 'sh', 'zsh', 'csh', 'fish',
        'node', 'nodejs',
        'perl', 'ruby', 'php',
        'java', 'go', 'rustc',
    }

    # 禁止的命令（即使在允许列表中也不能执行）
    FORBIDDEN_COMMANDS = {
        'rm -rf /',
        'rm -rf /*',
        'sudo',
        'su',
        'passwd',
        'chown',
        'shutdown',
        'reboot',
        'init',
        'systemctl',
        'service',
        'kill -9 1',
        'dd if=',
        'mkfs',
        'fdisk',
        '> /dev/',
        'chmod 777 /',
    }

    # 禁止的命令模式（正则）
    FORBIDDEN_PATTERNS = [
        r'rm\s+-rf\s+/',
        r'rm\s+-rf\s+\*',
        # 移除对 /dev/ 重定向的限制，允许正常的重定向操作
        # r'>\s*/dev/',
        r'sudo\s+',
        r'chmod\s+777\s+/',
    ]

    def __init__(
        self,
        allowed_directories: List[str] = None,
        allowed_write_directories: List[str] = None,
        allowed_commands: Set[str] = None,
        blocked_commands: Set[str] = None,
        knowledge_executable_dirs: List[str] = None,
        max_file_size: int = 10 * 1024 * 1024,  # 10MB
        max_output_size: int = 8000,  # 8KB，防止撑爆上下文
        command_timeout: int = 300,  # 5 分钟
        docker_timeout: int = 600,  # 10 分钟
    ):
        """初始化沙箱配置
        
        Args:
            allowed_directories: 允许读取/访问的目录列表（用于读取、列表等操作）
            allowed_write_directories: 允许写入的目录列表（用于写入操作，如果为None则使用allowed_directories）
            allowed_commands: 允许执行的命令集合
            blocked_commands: 显式禁止的命令集合
            knowledge_executable_dirs: 知识库可执行目录范围（相对 data/）
            max_file_size: 最大文件大小（字节）
            max_output_size: 最大输出大小（字节）
            command_timeout: 普通命令超时时间（秒）
            docker_timeout: Docker 命令超时时间（秒）
        """
        self.allowed_directories = allowed_directories or []
        # 如果没有指定写入目录，默认使用读取目录（向后兼容）
        self.allowed_write_directories = allowed_write_directories if allowed_write_directories is not None else self.allowed_directories
        self.allowed_commands = allowed_commands or self.DEFAULT_ALLOWED_COMMANDS.copy()
        self.blocked_commands = blocked_commands or set()
        self.knowledge_executable_dirs = knowledge_executable_dirs or []
        self.max_file_size = max_file_size
        self.max_output_size = max_output_size
        self.command_timeout = command_timeout
        self.docker_timeout = docker_timeout

    @staticmethod
    def load_policy(default: Optional[dict] = None) -> dict:
        """从系统配置中加载沙箱策略"""
        try:
            # 延迟导入避免循环依赖
            from app.models.database.config import SystemConfig
            return SystemConfig.get_config('ai_sandbox_policy', default or {}) or (default or {})
        except Exception:
            return default or {}

    @staticmethod
    def _resolve_paths(paths: List[str], base_dir: Optional[Path] = None) -> List[str]:
        """将相对路径解析为绝对路径"""
        resolved = []
        for p in paths or []:
            try:
                path_obj = Path(p)
                if not path_obj.is_absolute() and base_dir:
                    path_obj = (base_dir / path_obj).resolve()
                else:
                    path_obj = path_obj.resolve()
                resolved.append(str(path_obj))
            except Exception:
                continue
        return resolved

    def apply_policy(self, policy: Optional[dict], base_dir: Optional[Path] = None):
        """应用策略到当前沙箱实例"""
        if not policy:
            return

        # 命令白名单
        policy_allowed = policy.get('allowed_commands')
        if isinstance(policy_allowed, list):
            cleaned = [cmd.strip() for cmd in policy_allowed if cmd]
            if '*' in cleaned:
                # "*" 表示允许所有命令（仍受禁止列表和安全检查限制）
                self.allowed_commands = {'*'}
            else:
                self.allowed_commands = set(self.allowed_commands) | set(cleaned)

        # 命令黑名单
        policy_blocked = policy.get('blocked_commands')
        if isinstance(policy_blocked, list):
            self.blocked_commands = set([cmd.strip() for cmd in policy_blocked if cmd])

        # 目录限制
        read_dirs = policy.get('allowed_read_dirs')
        if isinstance(read_dirs, list) and read_dirs:
            extra = self._resolve_paths(read_dirs, base_dir)
            self.allowed_directories = list(dict.fromkeys((self.allowed_directories or []) + extra))

        write_dirs = policy.get('allowed_write_dirs')
        if isinstance(write_dirs, list) and write_dirs:
            extra = self._resolve_paths(write_dirs, base_dir)
            self.allowed_write_directories = list(dict.fromkeys((self.allowed_write_directories or []) + extra))

        # 限制项
        self.max_file_size = int(policy.get('max_file_size', self.max_file_size) or self.max_file_size)
        self.max_output_size = int(policy.get('max_output_size', self.max_output_size) or self.max_output_size)
        self.command_timeout = int(policy.get('command_timeout', self.command_timeout) or self.command_timeout)
        self.docker_timeout = int(policy.get('docker_timeout', self.docker_timeout) or self.docker_timeout)

    def is_path_allowed(self, path: str) -> bool:
        """检查路径是否在允许的目录内（用于读取、列表等操作）
        
        Args:
            path: 要检查的路径
            
        Returns:
            是否允许
        """
        if not self.allowed_directories:
            return True  # 未配置限制则允许所有
        
        try:
            # 解析为绝对路径
            abs_path = Path(path).resolve()
            
            for allowed_dir in self.allowed_directories:
                allowed_abs = Path(allowed_dir).resolve()
                
                # 检查是否在允许目录内
                try:
                    abs_path.relative_to(allowed_abs)
                    return True
                except ValueError:
                    continue
            
            return False
            
        except Exception:
            return False
    
    def is_write_path_allowed(self, path: str) -> bool:
        """检查路径是否在允许写入的目录内（用于写入操作）
        
        Args:
            path: 要检查的路径
            
        Returns:
            是否允许写入
        """
        if not self.allowed_write_directories:
            return True  # 未配置限制则允许所有
        
        try:
            # 解析为绝对路径
            abs_path = Path(path).resolve()
            
            for allowed_dir in self.allowed_write_directories:
                allowed_abs = Path(allowed_dir).resolve()
                
                # 检查是否在允许目录内
                try:
                    abs_path.relative_to(allowed_abs)
                    return True
                except ValueError:
                    continue
            
            return False
            
        except Exception:
            return False
    
    def _extract_executable_paths(self, command: str, working_dir: str) -> List[str]:
        """从命令中提取所有可能的可执行文件路径
        
        处理多种情况：
        - python3 data/scripts/choice.py → ['data/scripts/choice.py']
        - ./data/utils/tool → ['data/utils/tool']
        - bash data/scripts/helper.sh arg1 → ['data/scripts/helper.sh']
        - node data/scripts/process.js → ['data/scripts/process.js']
        - ./data/scripts/executable → ['data/scripts/executable']
        
        Args:
            command: 命令字符串
            working_dir: 工作目录（用于解析相对路径）
            
        Returns:
            可执行文件路径列表
        """
        paths = []
        parts = command.strip().split()
        
        if not parts:
            return paths
        
        # 情况1: 第一个参数是解释器，第二个参数是脚本路径
        # python3 data/scripts/choice.py
        if len(parts) >= 2:
            first = parts[0]
            second = parts[1]
            
            # 检查第一个是否是解释器（去除路径，只保留命令名）
            first_basename = os.path.basename(first) if '/' in first else first
            if first_basename in self.KNOWN_INTERPRETERS:
                # 第二个参数可能是脚本路径
                if '/' in second or second.startswith('./') or second.startswith('../'):
                    paths.append(second)
        
        # 情况2: 第一个参数直接是可执行文件路径
        # ./data/utils/tool 或 data/utils/tool
        first = parts[0]
        if '/' in first or first.startswith('./') or first.startswith('../'):
            paths.append(first)
        
        return paths
    
    def is_data_executable_path_allowed(self, executable_path: str, working_dir: str) -> bool:
        """检查可执行文件路径是否在允许的目录内
        
        检查可执行文件是否在配置的允许目录范围内，防止路径遍历攻击。
        
        Args:
            executable_path: 可执行文件路径（相对或绝对）
            working_dir: 工作目录（用于解析相对路径）
            
        Returns:
            是否允许执行
        """
        try:
            # 解析为绝对路径
            if os.path.isabs(executable_path):
                abs_path = Path(executable_path).resolve()
            else:
                abs_path = (Path(working_dir) / executable_path).resolve()
            
            # 转换为相对于工作目录的路径
            working_path = Path(working_dir).resolve()
            try:
                rel_path = abs_path.relative_to(working_path)
            except ValueError:
                # 不在工作目录内
                return False
            
            # 检查路径中是否包含 .. 符号（防止路径遍历）
            path_parts = str(rel_path).split('/')
            if '..' in path_parts:
                return False
            
            # 获取配置的允许目录列表（从策略中）
            allowed_dirs = getattr(self, 'knowledge_executable_dirs', None)
            if allowed_dirs is None:
                # 从策略中加载
                policy = self.load_policy()
                allowed_dirs = policy.get('knowledge_executable_dirs', [])
            
            # 如果配置为空，允许所有 data/ 下的目录
            if not allowed_dirs:
                return True
            
            # 检查路径是否在允许的目录范围内
            path_str = str(rel_path)
            for allowed_dir in allowed_dirs:
                allowed_dir = allowed_dir.strip().strip('/')
                if not allowed_dir:
                    continue
                
                # 特殊处理：如果配置为 "." 或空字符串，只允许 data 根目录（不包括子目录）
                if allowed_dir == '.' or allowed_dir == '':
                    # 只允许路径不包含 "/" 的文件（即在 data 根目录）
                    if '/' not in path_str:
                        return True
                    continue
                
                # 检查路径是否以允许的目录开头
                if path_str.startswith(allowed_dir + '/') or path_str == allowed_dir:
                    return True
            
            return False
            
        except Exception:
            # 任何解析错误都拒绝
            return False

    def is_command_allowed(self, command: str, working_dir: Optional[str] = None) -> tuple:
        """检查命令是否允许执行
        
        Args:
            command: 要检查的命令
            working_dir: 工作目录（可选，用于验证 data/ 目录内的可执行文件）
            
        Returns:
            (是否允许, 原因)
        """
        allow_all = '*' in self.allowed_commands

        # 显式禁止的命令
        if self.blocked_commands:
            # 提取主命令名
            main = command.strip().split()
            main_cmd = os.path.basename(main[0]) if main else ''
            if main_cmd in self.blocked_commands:
                return False, f"命令 '{main_cmd}' 被沙箱策略禁止"

        # 检查禁止的命令模式
        for pattern in self.FORBIDDEN_PATTERNS:
            if re.search(pattern, command, re.IGNORECASE):
                return False, f"命令包含禁止的模式: {pattern}"
        
        # 检查禁止的命令
        for forbidden in self.FORBIDDEN_COMMANDS:
            if forbidden in command:
                return False, f"命令包含禁止的操作: {forbidden}"
        
        # 如果提供了工作目录，检查命令参数中的路径是否在允许范围内（知识库 Shell 专用）
        if working_dir:
            # 检查命令参数中的绝对路径
            parts = command.strip().split()
            for part in parts[1:]:  # 跳过命令本身，检查参数
                part = part.strip()
                # 跳过选项参数（以 - 开头）
                if part.startswith('-'):
                    continue
                # 检查是否是绝对路径
                if os.path.isabs(part):
                    # 检查是否在允许的读取目录内
                    if not self.is_path_allowed(part):
                        return False, f"命令参数中的路径 '{part}' 不在允许范围内（只允许访问: {', '.join(self.allowed_directories) if self.allowed_directories else '无限制'}）"
                # 检查是否包含 .. 路径遍历
                elif '..' in part:
                    return False, f"命令参数包含路径遍历: '{part}'"
        
        # 如果提供了工作目录，检查 data/ 目录内的可执行文件
        has_data_executable = False
        if working_dir:
            executable_paths = self._extract_executable_paths(command, working_dir)
            for exec_path in executable_paths:
                # 如果路径在 data/ 目录内，验证路径是否合法
                if self.is_data_executable_path_allowed(exec_path, working_dir):
                    # data/ 目录内的可执行文件允许执行（路径验证已通过）
                    has_data_executable = True
                elif exec_path and ('data/' in exec_path or exec_path.startswith('./data/') or exec_path.startswith('../data/')):
                    # 路径看起来像是 data/ 目录，但验证失败（可能是路径遍历攻击）
                    return False, f"可执行文件路径不在允许的 data/ 目录内或包含路径遍历: {exec_path}"
                elif exec_path and ('/' in exec_path or exec_path.startswith('./') or exec_path.startswith('../')):
                    # 如果提取到了路径，但路径不在 data/ 目录内，拒绝执行
                    # 这是为了防止执行 data/ 目录外的脚本（如 /tmp/evil.py, ../malicious.py）
                    return False, f"可执行文件路径不在允许的 data/ 目录内: {exec_path}"
        
        # 提取所有需要检查的命令
        # 处理复合命令（&&, ||, ;, |）和变量赋值
        
        # 分割复合命令
        compound_separators = ['&&', '||', ';', '|']
        commands_to_check = [command]
        
        for sep in compound_separators:
            new_commands = []
            for cmd in commands_to_check:
                new_commands.extend(cmd.split(sep))
            commands_to_check = new_commands
        
        for single_cmd in commands_to_check:
            single_cmd = single_cmd.strip()
            if not single_cmd:
                continue
            
            # 跳过纯变量赋值（如 timestamp=$(date ...)）
            # 变量赋值格式：VAR=value 或 VAR=$(...)
            if '=' in single_cmd:
                # 检查是否是变量赋值（等号前没有空格，且等号前是有效变量名）
                eq_pos = single_cmd.find('=')
                var_part = single_cmd[:eq_pos]
                # 变量名只能包含字母、数字、下划线，且不能以数字开头
                if var_part and re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', var_part):
                    # 这是变量赋值，检查赋值值中的命令替换
                    value_part = single_cmd[eq_pos + 1:]
                    # 提取 $(...) 或 `...` 中的命令
                    subst_matches = re.findall(r'\$\(([^)]+)\)', value_part)
                    subst_matches += re.findall(r'`([^`]+)`', value_part)
                    for subst_cmd in subst_matches:
                        subst_parts = subst_cmd.strip().split()
                        if subst_parts:
                            sub_main_cmd = subst_parts[0]
                            if '/' in sub_main_cmd:
                                sub_main_cmd = os.path.basename(sub_main_cmd)
                            if sub_main_cmd in self.blocked_commands:
                                return False, f"命令替换中的命令 '{sub_main_cmd}' 被禁止"
                            if not allow_all and sub_main_cmd not in self.allowed_commands:
                                return False, f"命令替换中的命令 '{sub_main_cmd}' 不在允许列表中"
                    continue  # 变量赋值本身不需要检查主命令
            
            parts = single_cmd.split()
            if not parts:
                continue
            
            main_cmd = parts[0]
            
            # 如果执行的是 data/ 目录内的可执行文件，且是直接执行（如 ./data/utils/tool），
            # 跳过命令白名单检查（因为路径验证已经通过）
            if has_data_executable and main_cmd.startswith('./') and 'data/' in main_cmd:
                # 直接执行的 data/ 目录内的可执行文件，跳过命令白名单检查
                continue
            
            # 处理路径形式的命令
            if '/' in main_cmd:
                main_cmd = os.path.basename(main_cmd)
            
            # 检查是否在允许列表中
            if main_cmd in self.blocked_commands:
                return False, f"命令 '{main_cmd}' 被沙箱策略禁止"
            if not allow_all and main_cmd not in self.allowed_commands:
                return False, f"命令 '{main_cmd}' 不在允许列表中"
        
        return True, ""

    def get_timeout(self, command: str) -> int:
        """获取命令的超时时间
        
        Args:
            command: 命令字符串
            
        Returns:
            超时时间（秒）
        """
        # Docker 相关命令使用更长的超时
        if 'docker' in command.lower():
            return self.docker_timeout
        return self.command_timeout

    @classmethod
    def create_for_ctf_generation(cls, output_base_dir: str = None, category_id: str = None, policy: Optional[dict] = None) -> 'SandboxConfig':
        """创建用于 CTF 题目生成的沙箱配置"""
        # 获取项目根目录
        project_root = Path(__file__).parent.parent.parent.parent.parent
        ge10_dir = project_root / 'ge10'

        # 允许读取的目录：整个 ge10 目录（读取权限不变）
        allowed_read_dirs = [
            str(ge10_dir),  # 整个 ge10 目录（包含各方向的目录）
            '/tmp',  # 临时目录（读取也需要）
        ]

        # 允许写入的目录：只允许对应方向的 output 目录和临时目录
        allowed_write_dirs = ['/tmp']  # 临时目录始终允许写入

        if output_base_dir:
            allowed_write_dirs.append(str(output_base_dir))
        elif category_id:
            category_output_dir = ge10_dir / category_id / 'output'
            allowed_write_dirs.append(str(category_output_dir))

        sandbox = cls(
            allowed_directories=allowed_read_dirs,
            allowed_write_directories=allowed_write_dirs,
            command_timeout=300,
            docker_timeout=600,
        )

        # 应用策略
        sandbox.apply_policy(policy or cls.load_policy(), base_dir=project_root)
        return sandbox

    @classmethod
    def create_for_knowledge_base(cls, category_id: str, policy: Optional[dict] = None) -> 'SandboxConfig':
        """创建用于知识库管理/命令执行的沙箱配置"""
        project_root = Path(__file__).parent.parent.parent.parent.parent
        ge10_dir = project_root / 'ge10'
        data_dir = ge10_dir / category_id / 'data'

        allowed_read_dirs = [str(data_dir), '/tmp']
        allowed_write_dirs = [str(data_dir), '/tmp']

        sandbox = cls(
            allowed_directories=allowed_read_dirs,
            allowed_write_directories=allowed_write_dirs,
            command_timeout=300,
            docker_timeout=600,
        )
        sandbox.apply_policy(policy or cls.load_policy(), base_dir=project_root)
        return sandbox
