# -*- coding: utf-8 -*-
"""
统一 AI 服务

为题目生成流程提供统一的 AI 调用接口，支持多种 AI 提供商
"""

import os
import json
import logging
import re
from typing import Dict, List, Any, Optional, Callable
from pathlib import Path

from .providers import AIProviderFactory, BaseAIProvider
from .providers.base import ToolDefinition
from .tools import ToolExecutor, SandboxConfig
from .core.stage_detector import StageDetector

logger = logging.getLogger(__name__)


class AIService:
    """统一 AI 服务
    
    封装 AI 提供商和工具执行，为题目生成提供完整的 AI 能力
    """

    def __init__(
        self,
        user_id: int,
        output_dir: str = None,
        log_callback: Callable = None,
        category_id: str = None,
        ai_config_id: Optional[int] = None
    ):
        """初始化 AI 服务
        
        Args:
            user_id: 用户 ID，用于获取用户的 AI 配置
            output_dir: 输出目录（用于沙箱限制）
            log_callback: 日志回调函数
            category_id: 方向 ID（用于设置工作目录）
            ai_config_id: 指定的 AI 配置 ID（可选，如果提供则使用该配置）
        """
        self.user_id = user_id
        self.output_dir = output_dir
        self.log_callback = log_callback
        self.category_id = category_id
        self.ai_config_id = ai_config_id
        
        # 初始化 AI 提供商
        self.provider: Optional[BaseAIProvider] = None
        self._init_provider()
        
        # 初始化工具执行器
        self.tool_executor: Optional[ToolExecutor] = None
        self._init_tool_executor()
        
        # 对话历史
        self.messages: List[Dict[str, str]] = []
        
        # 系统提示词
        self.system_prompt: Optional[str] = None
        
        # 上下文管理配置
        self.context_config = {
            'max_messages': 100,          # 保留的最大消息数
            'max_tool_results_length': 10000,  # 工具结果最大长度
            'compress_threshold': 99999,  # 禁用压缩（设置为极大值）
        }
        
        # 阶段摘要存储（关键信息不会丢失）
        self.stage_artifacts: Dict[int, Dict[str, Any]] = {}
        
        # 已生成的文件列表（用于上下文恢复）
        self.generated_files: List[str] = []
        
        # 过滤后的阶段列表（从 Prompt 配置中获取）
        self._filtered_stages: List[Dict[str, Any]] = []

    def _init_provider(self):
        """Create the AI provider for this user and route its logs through this service.

        Uses ``ai_config_id`` when set. If no usable configuration exists, a warning is
        logged; if creation or hooking fails, an error is logged and ``self.provider``
        is reset to None.
        """
        try:
            candidate = AIProviderFactory.create_for_user(self.user_id, ai_config_id=self.ai_config_id)
            self.provider = candidate
            if not candidate:
                self._log('warning', '未找到可用的 AI 配置')
                return
            candidate.set_log_callback(self._on_provider_log)
            self._log('info', f'已初始化 AI 提供商: {candidate.provider_name} ({candidate.model})')
        except Exception as e:
            self._log('error', f'初始化 AI 提供商失败: {str(e)}')
            self.provider = None

    def _init_tool_executor(self):
        """Build the sandboxed tool executor.

        The sandbox comes from ``SandboxConfig.create_for_ctf_generation``. Per the
        original notes, it allows reading the whole ge10 tree and writing only to the
        category's output directory and /tmp; that policy lives in SandboxConfig.

        When ``category_id`` is set, the working directory is ``<root>/ge10/<category_id>``,
        where ``<root>`` is four directory levels above this file. If that directory
        does not exist, ``<root>/ge10`` is used instead.
        """
        # An empty output_dir is normalised to None.
        sandbox = SandboxConfig.create_for_ctf_generation(
            output_base_dir=self.output_dir or None,
            category_id=self.category_id
        )

        working_dir = None
        if self.category_id:
            ge10_root = Path(__file__).parent.parent.parent.parent / 'ge10'
            category_path = (ge10_root / self.category_id).resolve()
            if category_path.exists():
                working_dir = str(category_path)
                self._log('info', f'工作目录设置为: {working_dir}')
            else:
                fallback_path = ge10_root.resolve()
                working_dir = str(fallback_path)
                self._log('warning', f'方向目录不存在: {category_path}，使用默认目录: {fallback_path}')

        self.tool_executor = ToolExecutor(
            sandbox=sandbox,
            log_callback=self._on_tool_log,
            working_dir=working_dir
        )

    def _log(self, level: str, message: str):
        """Send a log line to the optional callback, then to the module logger.

        The logger method is looked up by ``level`` name; names the logger lacks
        (e.g. 'success') fall back to ``logger.info``.
        """
        callback = self.log_callback
        if callback:
            callback(level, message)
        emit = getattr(logger, level, logger.info)
        emit(f'[AIService] {message}')

    def _on_provider_log(self, level: str, message: str):
        """AI 提供商日志回调"""
        self._log(level, f'[Provider] {message}')

    def _on_tool_log(self, level: str, message: str):
        """工具执行日志回调"""
        self._log(level, f'[Tool] {message}')


    def is_available(self) -> bool:
        """检查 AI 服务是否可用"""
        return self.provider is not None

    def get_provider_info(self) -> Dict[str, Any]:
        """获取当前 AI 提供商信息"""
        if self.provider:
            return {
                'type': self.provider.provider_type,
                'name': self.provider.provider_name,
                'model': self.provider.model,
                'base_url': self.provider.base_url
            }
        return None

    def set_system_prompt(self, prompt: str):
        """设置系统提示词
        
        Args:
            prompt: 系统提示词内容
        """
        self.system_prompt = prompt
        self._log('info', f'已设置系统提示词 ({len(prompt)} 字符)')

    def load_system_prompt_from_file(self, file_path: str):
        """从文件加载系统提示词
        
        Args:
            file_path: 提示词文件路径
        """
        try:
            path = Path(file_path)
            if path.exists():
                self.system_prompt = path.read_text(encoding='utf-8')
                self._log('info', f'已从文件加载系统提示词: {file_path}')
            else:
                self._log('warning', f'提示词文件不存在: {file_path}')
        except Exception as e:
            self._log('error', f'加载提示词文件失败: {str(e)}')

    def clear_history(self):
        """清空对话历史"""
        self.messages = []
        self.stage_artifacts = {}
        self.generated_files = []
        self._log('info', '已清空对话历史')

    # ==================== 上下文管理方法 ====================
    
    def _extract_stage_artifact(self, stage: int, content: str) -> Dict[str, Any]:
        """Build the summary record for a stage from the AI output.

        Starts from ``{'stage': stage, 'raw_summary': ''}`` and overlays whatever
        ``StageDetector.extract_stage_info`` returns, if it returns anything truthy.
        """
        extracted = StageDetector.extract_stage_info(stage, content)
        artifact: Dict[str, Any] = {'stage': stage, 'raw_summary': ''}
        if extracted:
            artifact.update(extracted)
        return artifact
    
    def _save_stage_artifact(self, stage: int, content: str):
        """保存阶段关键信息
        
        Args:
            stage: 阶段号
            content: AI 输出内容
        """
        artifact = self._extract_stage_artifact(stage, content)
        self.stage_artifacts[stage] = artifact
        self._log('debug', f'保存阶段 {stage} 摘要: {artifact.get("raw_summary", "")}')
    
    def _track_generated_file(self, file_path: str):
        """跟踪生成的文件
        
        Args:
            file_path: 文件路径
        """
        if file_path and file_path not in self.generated_files:
            self.generated_files.append(file_path)
    
    def _build_context_summary(self) -> str:
        """Return the context summary built by StageDetector from the stored stage
        records and the generated-file list."""
        summary = StageDetector.build_context_summary(self.stage_artifacts, self.generated_files)
        return summary
    
    def _compress_messages(self, messages: List[Dict], current_stage: int) -> List[Dict]:
        """压缩消息列表，保留关键信息
        
        策略：
        1. 保留系统提示词
        2. 保留阶段摘要
        3. 保留最近 N 条消息
        4. 截断过长的工具结果
        
        Args:
            messages: 原始消息列表
            current_stage: 当前阶段
            
        Returns:
            压缩后的消息列表
        """
        if len(messages) < self.context_config['compress_threshold']:
            return messages
        
        self._log('info', f'触发上下文压缩: {len(messages)} 条消息 → 压缩中...')
        
        compressed = []
        max_messages = self.context_config['max_messages']
        max_tool_len = self.context_config['max_tool_results_length']
        
        # 1. 保留系统提示词
        if messages and messages[0].get('role') == 'system':
            compressed.append(messages[0])
            messages = messages[1:]
        
        # 2. 插入上下文摘要（如果有之前阶段的信息）
        context_summary = self._build_context_summary()
        if context_summary:
            compressed.append({
                'role': 'user',
                'content': context_summary + "\n\n请基于以上历史信息继续当前任务。"
            })
            # 获取阶段ID用于获取阶段名称
            stage_id = self._get_stage_id_by_index(current_stage) if self._filtered_stages else str(current_stage)
            stage_name = self._get_stage_name(stage_id)
            compressed.append({
                'role': 'assistant',
                'content': f"好的，我已了解之前阶段的关键信息。当前在阶段 {current_stage}（{stage_name}），继续执行。"
            })
        
        # 3. 保留最近的消息，确保 tool 消息和对应的 tool_calls 成对保留
        recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
        
        # 检查第一条消息是否是 tool 消息，如果是则向前查找对应的 assistant 消息
        if recent_messages and recent_messages[0].get('role') == 'tool':
            start_idx = len(messages) - len(recent_messages)
            if start_idx > 0:
                for i in range(start_idx - 1, -1, -1):
                    msg = messages[i]
                    recent_messages = [msg] + recent_messages
                    if msg.get('role') == 'assistant' and msg.get('tool_calls'):
                        break
        
        for msg in recent_messages:
            new_msg = msg.copy()
            
            # 截断过长的工具结果
            if msg.get('role') == 'tool' and len(msg.get('content', '')) > max_tool_len:
                content = msg['content']
                new_msg['content'] = content[:max_tool_len] + f"\n... [已截断，原长度 {len(content)} 字符]"
            
            # 截断过长的助手消息（但保留关键结构）
            if msg.get('role') == 'assistant' and len(msg.get('content', '')) > 5000:
                content = msg['content']
                if '|' in content:
                    new_msg['content'] = content[:5000] + "\n... [已截断]"
                else:
                    new_msg['content'] = content[:3000] + "\n... [已截断]"
            
            compressed.append(new_msg)
        
        self._log('info', f'上下文压缩完成: {len(messages) + 1} → {len(compressed)} 条消息')
        return compressed
    
    def _extract_file_from_tool_call(self, tool_call: Dict) -> Optional[str]:
        """从工具调用中提取文件路径
        
        Args:
            tool_call: 工具调用信息
            
        Returns:
            文件路径或 None
        """
        try:
            func_name = tool_call.get('function', {}).get('name', '')
            args = tool_call.get('function', {}).get('arguments', '{}')
            if isinstance(args, str):
                args = json.loads(args)
            
            # 写文件工具
            if func_name in ['write_file', 'create_file']:
                return args.get('path') or args.get('file_path')
            
            # bash 命令中的重定向
            if func_name in ['bash', 'run_command']:
                cmd = args.get('command', '')
                # 匹配 > 或 >> 重定向
                match = re.search(r'>\s*([^\s]+)', cmd)
                if match:
                    return match.group(1)
        except:
            pass
        return None

    def chat(
        self,
        user_message: str,
        use_tools: bool = True,
        max_tool_iterations: int = 10
    ) -> Dict[str, Any]:
        """发送消息并获取响应
        
        支持工具调用的多轮对话
        
        Args:
            user_message: 用户消息
            use_tools: 是否启用工具调用
            max_tool_iterations: 最大工具调用迭代次数
            
        Returns:
            响应字典:
            - content: 最终文本响应
            - tool_calls_count: 工具调用次数
            - success: 是否成功
        """
        if not self.provider:
            return {
                'success': False,
                'content': '未配置 AI 提供商，请先在 AI 配置页面设置',
                'tool_calls_count': 0
            }

        # 构建消息列表
        messages = []
        
        # 添加系统提示词
        if self.system_prompt:
            messages.append({
                'role': 'system',
                'content': self.system_prompt
            })
        
        # 添加历史消息
        messages.extend(self.messages)
        
        # 添加当前用户消息
        messages.append({
            'role': 'user',
            'content': user_message
        })
        
        # 记录用户消息到历史
        self.messages.append({
            'role': 'user',
            'content': user_message
        })

        # 获取工具定义
        tools = ToolDefinition.get_ctf_tools() if use_tools else None
        
        tool_calls_count = 0
        iteration = 0
        
        try:
            while iteration < max_tool_iterations:
                iteration += 1
                
                # 调用 AI
                response = self.provider.chat(messages, tools=tools)
                
                # 检查是否有工具调用
                if response.get('tool_calls') and use_tools:
                    tool_calls_count += len(response['tool_calls'])
                    
                    # 添加助手消息（包含工具调用）
                    assistant_msg = {
                        'role': 'assistant',
                        'content': response.get('content', ''),
                        'tool_calls': response['tool_calls']
                    }
                    messages.append(assistant_msg)
                    
                    # 执行每个工具调用
                    for tool_call in response['tool_calls']:
                        tool_name = tool_call['function']['name']
                        self._log('info', f'执行工具: {tool_name}')
                        
                        # 执行工具
                        tool_result = self.tool_executor.execute_tool_call(tool_call)
                        
                        # 添加工具结果消息
                        messages.append({
                            'role': 'tool',
                            'tool_call_id': tool_call['id'],
                            'name': tool_name,
                            'content': tool_result
                        })
                    
                    # 继续循环，让 AI 处理工具结果
                    continue
                else:
                    # 没有工具调用，返回最终响应
                    final_content = response.get('content', '')
                    
                    # 记录助手响应到历史
                    self.messages.append({
                        'role': 'assistant',
                        'content': final_content
                    })
                    
                    return {
                        'success': True,
                        'content': final_content,
                        'tool_calls_count': tool_calls_count,
                        'usage': response.get('usage')
                    }
            
            # 达到最大迭代次数
            self._log('warning', f'达到最大工具调用迭代次数: {max_tool_iterations}')
            return {
                'success': True,
                'content': response.get('content', ''),
                'tool_calls_count': tool_calls_count,
                'warning': '达到最大工具调用次数限制'
            }
            
        except Exception as e:
            self._log('error', f'AI 调用失败: {str(e)}')
            return {
                'success': False,
                'content': f'AI 调用失败: {str(e)}',
                'tool_calls_count': tool_calls_count
            }

    def generate_ctf_challenge(self, language: str, vulnerabilities: List[str], scene: str, difficulty: str, extra_requirements: str = '', category_id: str = 'web', form_data: Dict[str, Any] = None, task_id: Optional[str] = None) -> Dict[str, Any]:
        """生成 CTF 题目（模拟 Augment 八轮流程）
        
        使用预设的提示词和工具，生成完整的 CTF 题目
        实现分阶段执行和验证，确保题目质量
        Args:
            language: 编程语言
            vulnerabilities: 漏洞类型列表
            scene: 场景
            difficulty: 难度（中文：入门/简单/中等/困难）
            extra_requirements: 用户额外要求
            category_id: 方向ID（如 'web'）
            form_data: 表单数据（可选）
            task_id: 任务ID（可选，用于创建独立工作目录避免多任务冲突）
            
        Returns:
            生成结果
        """
        import subprocess
        import datetime
        
        if not self.provider:
            return {
                'status': 'error',
                'message': '未配置 AI 提供商'
            }
        
        # 如果是 CLI 模式的服务（AnyRouter、AgentRouter、Augment），直接委托给它们的 generate_ctf_challenge 方法
        if hasattr(self.provider, 'provider_type') and self.provider.provider_type in ('anyrouter', 'agentrouter', 'augment'):
            provider_name = getattr(self.provider, 'provider_name', self.provider.provider_type)
            self._log('info', f'检测到 {provider_name} 提供商，使用 CLI 模式生成')
            # 设置日志回调
            if hasattr(self.provider, 'set_log_callback'):
                self.provider.set_log_callback(self.log_callback)
            # 直接调用 CLI 服务的 generate_ctf_challenge
            return self.provider.generate_ctf_challenge(
                language=language,
                vulnerabilities=vulnerabilities,
                scene=scene.get('name') if isinstance(scene, dict) else scene,
                difficulty=difficulty,
                extra_requirements=extra_requirements,
                category_id=category_id,
                form_data=form_data,
                task_id=task_id
            )

        # 设置日志文件，让前端能够显示日志
        project_root = Path(__file__).parent.parent.parent.parent
        ge10_dir = project_root / 'ge10'  # 在所有分支中都需要使用，提前定义
        
        log_dir = project_root / 'logs'
        log_dir.mkdir(exist_ok=True)
        
        # 日志文件名包含 task_id 以确保唯一性（如果提供了 task_id）
        if task_id:
            # 使用 task_id 的一部分（去掉时间戳前缀，只保留 UUID 部分）来创建唯一文件名
            task_suffix = task_id.split('-')[-1][:8] if '-' in task_id else task_id[:8]
            log_file = log_dir / f"ai_service_{task_suffix}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
        else:
            # 如果没有 task_id，使用时间戳和进程ID确保唯一性
            import os
            log_file = log_dir / f"ai_service_{os.getpid()}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
        
        # 设置日志文件路径到 generation_statuses（按 task_id 索引）
        try:
            from app.routes.generator.utils import generation_statuses, generation_lock
            if task_id:
                with generation_lock:
                    if task_id not in generation_statuses:
                        generation_statuses[task_id] = {}
                    generation_statuses[task_id]["log_file"] = str(log_file)
                    generation_statuses[task_id]["log_position"] = 0
            else:
                # 向后兼容：如果没有 task_id，使用全局 generation_status
                from app.routes.generator.utils import generation_status
                generation_status["log_file"] = str(log_file)
                generation_status["log_position"] = 0
            self._log('info', f'日志文件: {log_file} (task_id: {task_id})')
        except Exception as e:
            self._log('warning', f'设置日志文件失败: {e}')
        
        # 使用统一的日志格式器
        from app.services.ai.core.log_formatter import UnifiedLogWriter, LogLevel
        
        # 创建统一日志写入器
        log_writer = UnifiedLogWriter(str(log_file), "AIService")
        log_writer.__enter__()
        
        # 保存原始日志回调
        original_log_callback = self.log_callback
        
        # 包装日志回调，使用统一格式写入文件
        def enhanced_log_callback(level: str, message: str):
            # 映射日志级别
            level_map = {
                'debug': LogLevel.DEBUG,
                'info': LogLevel.INFO,
                'warning': LogLevel.WARNING,
                'warn': LogLevel.WARNING,
                'error': LogLevel.ERROR,
                'success': LogLevel.SUCCESS
            }
            log_level = level_map.get(level.lower(), LogLevel.INFO)
            log_writer.write(message, level=log_level)
            if original_log_callback:
                original_log_callback(level, message)
        
        self.log_callback = enhanced_log_callback
        self._log_writer = log_writer  # 保存引用以便后续使用
        
        # 创建兼容的 write_to_log 函数，使用统一格式
        def write_to_log(message: str):
            """写入日志到文件（使用统一格式）"""
            # 解析消息中的级别标记
            level = LogLevel.INFO
            clean_message = message
            
            # 检测消息中的级别标记
            if message.startswith('\n'):
                clean_message = message[1:]
            if '[上下文管理]' in message:
                level = LogLevel.SYSTEM
            elif '[系统]' in message:
                level = LogLevel.SYSTEM
            elif '[轮次' in message and '] 发送给 AI 的消息:' in message:
                level = LogLevel.INFO
                log_writer.write_separator()
            elif '[轮次' in message and '] AI 响应:' in message:
                level = LogLevel.INFO
                log_writer.write_separator()
            elif '[轮次' in message and '] 工具执行结果' in message:
                level = LogLevel.TOOL
            elif '工具:' in message or '参数:' in message or '结果:' in message:
                level = LogLevel.TOOL
            elif '=' * 60 in message:
                log_writer.write_separator()
                return
            
            log_writer.write(clean_message.strip(), level=level)
        
        # 将 write_to_log 保存为实例变量，供后续使用
        self._write_to_log = write_to_log

        # 加载数据库中的 Prompt 配置（这会设置 _filtered_stages）
        system_prompt = self._load_prompt_from_database(category_id, difficulty, language, vulnerabilities, scene, form_data or {})
        
        # 初始化阶段状态并保存总阶段数（使用 task_id 避免多任务冲突）
        try:
            from app.routes.generator.utils import generation_statuses, generation_lock
            total_stages = len(self._filtered_stages) if self._filtered_stages else 0
            if total_stages > 0:
                # 使用 task_id 更新对应任务的状态
                if task_id:
                    with generation_lock:
                        if task_id not in generation_statuses:
                            generation_statuses[task_id] = {}
                        generation_statuses[task_id]["total_stages"] = total_stages
                        
                        # 初始化 step_statuses（动态初始化，基于实际阶段数）
                        if "step_statuses" not in generation_statuses[task_id]:
                            generation_statuses[task_id]["step_statuses"] = {}
                        
                        # 初始化所有阶段状态为 waiting
                        for i in range(total_stages):
                            if i not in generation_statuses[task_id]["step_statuses"]:
                                generation_statuses[task_id]["step_statuses"][i] = "waiting"
                        
                        self._log('info', f'初始化 {total_stages} 个阶段 (task_id: {task_id})')
                else:
                    # 向后兼容：如果没有 task_id，使用全局 generation_status（不推荐）
                    from app.routes.generator.utils import generation_status
                    generation_status["total_stages"] = total_stages
                    if "step_statuses" not in generation_status:
                        generation_status["step_statuses"] = {}
                    for i in range(total_stages):
                        if i not in generation_status["step_statuses"]:
                            generation_status["step_statuses"][i] = "waiting"
                    self._log('info', f'初始化 {total_stages} 个阶段（向后兼容模式）')
            else:
                self._log('warning', '未找到阶段配置，阶段状态将由配置动态决定')
        except Exception as e:
            self._log('warning', f'初始化阶段状态失败: {e}')
        
        if system_prompt:
            self.set_system_prompt(system_prompt)
            self._log('info', f'已从数据库加载系统 Prompt ({len(system_prompt)} 字符)')
        else:
            # 不再回退到文件系统，必须从数据库配置
            self._log('error', f'数据库中未找到方向 {category_id} 难度 {difficulty} 的 Prompt 配置，请先在后台配置 Prompt')
            return {
                'status': 'error',
                'message': f'未找到方向 {category_id} 难度 {difficulty} 的 Prompt 配置，请先在后台配置 Prompt'
            }

        # 根据表单配置动态构建用户指令
        instruction_lines = []
        
        if form_data:
            try:
                from app.models.database.models import CategoryConfig
                category = CategoryConfig.query.get(category_id)
                if category:
                    form_fields = category.get_form_fields()
                    
                    # 按照表单字段配置顺序，动态生成指令
                    for field in form_fields:
                        field_id = field.get('id')
                        field_label = field.get('label', field_id)  # 使用字段标签或ID
                        field_type = field.get('type')
                        field_visible = field.get('visible', True)  # 默认可见
                        
                        # 跳过不可见的字段和特殊字段
                        if not field_visible or field_id in ['extra_requirements']:
                            continue
                        
                        # 获取字段值
                        field_value = form_data.get(field_id)
                        if field_value is None or field_value == '':
                            continue  # 跳过空值
                        
                        # 根据字段类型格式化值
                        formatted_value = None
                        
                        if field_type in ['multi_select', 'multi_select_categorized']:
                            # 多选字段：格式化为逗号分隔的字符串
                            if isinstance(field_value, list):
                                formatted_value = ', '.join([str(v.get('name', v) if isinstance(v, dict) else v) for v in field_value if v])
                            elif isinstance(field_value, str):
                                formatted_value = field_value
                            else:
                                formatted_value = str(field_value)
                        
                        elif field_type == 'select':
                            # 单选字段：可能是字符串、字典或选项值
                            if isinstance(field_value, dict):
                                formatted_value = field_value.get('name', str(field_value))
                            else:
                                # 如果是选项值，尝试从字段配置中找到对应标签
                                options = field.get('options', [])
                                for opt in options:
                                    opt_value = opt.get('value') if isinstance(opt, dict) else opt
                                    if opt_value == field_value:
                                        formatted_value = opt.get('label', opt.get('value', str(opt))) if isinstance(opt, dict) else str(opt)
                                        break
                                if not formatted_value:
                                    formatted_value = str(field_value)
                        
                        elif field_type == 'select_with_sub':
                            # 带子选项的选择：可能是字典
                            if isinstance(field_value, dict):
                                # 检查是否标记为无场景
                                if field_value.get('no_scene'):
                                    continue  # 跳过无场景的情况
                                scene_name = field_value.get('name', str(field_value))
                                # 如果场景名称为"无场景"或空，也跳过
                                if not scene_name or scene_name.strip() in ['无场景', '无特定场景', 'none', 'None', '']:
                                    continue
                                # 确保场景名称不是空字符串
                                if not scene_name or not scene_name.strip():
                                    continue
                                sub_scene = field_value.get('sub_scene')
                                if sub_scene:
                                    if isinstance(sub_scene, dict):
                                        sub_scene_name = sub_scene.get('name', '')
                                        if sub_scene_name:
                                            formatted_value = f"{scene_name} - {sub_scene_name}"
                                        else:
                                            formatted_value = scene_name
                                    else:
                                        formatted_value = f"{scene_name} - {sub_scene}"
                                else:
                                    formatted_value = scene_name
                            else:
                                # 如果是字符串，检查是否为"无场景"
                                if str(field_value).strip() in ['无场景', '无特定场景', 'none', 'None', '']:
                                    continue
                                formatted_value = str(field_value)
                        
                        else:
                            # 其他类型：直接转换为字符串
                            formatted_value = str(field_value)
                        
                        # 如果有格式化后的值，添加到指令中
                        if formatted_value and formatted_value.strip():
                            # 跳过"无场景"、"无特定场景"等无意义的值
                            if formatted_value.strip() in ['无场景', '无特定场景', 'none', 'None', '']:
                                continue
                            instruction_lines.append(f"- {field_label}：{formatted_value}")
                
            except Exception as e:
                # 如果出错，回退到简化的逻辑
                self._log('warning', f'根据表单配置构建指令时出错: {e}')
                instruction_lines = [
                    f"- 难度：{difficulty}",
                    f"- 漏洞类型：{', '.join(vulnerabilities) if isinstance(vulnerabilities, list) else vulnerabilities}",
                    f"- 编程语言：{language}"
                ]
        
        # 如果没有构建成功，使用默认格式
        if not instruction_lines:
            vuln_str = ', '.join(vulnerabilities) if isinstance(vulnerabilities, list) else vulnerabilities
            instruction_lines = [
                f"- 难度：{difficulty}",
                f"- 漏洞类型：{vuln_str}",
                f"- 编程语言：{language}"
            ]
        
        # 添加额外要求（如果有）
        if extra_requirements and extra_requirements.strip():
            instruction_lines.append(f"- 额外要求：{extra_requirements.strip()}")
        
        user_instruction = '\n'.join(instruction_lines) + '\n'

        # 记录日志
        vuln_str = ', '.join(vulnerabilities) if isinstance(vulnerabilities, list) else vulnerabilities
        self._log('info', f'开始生成 CTF 题目: {difficulty} / {language} / {vuln_str}')
        self._log('info', f'用户指令内容:\n{user_instruction}')
        
        # 第一个阶段: 用户输入需求（基于配置）
        first_stage_index = 0
        if self._filtered_stages:
            first_stage = self._filtered_stages[0]
            first_stage_name = first_stage.get('name', '用户输入需求')
            self._update_stage_progress(first_stage_index, 'processing', f'正在处理: {first_stage_name}...')
            self._update_stage_progress(first_stage_index, 'completed', f'{first_stage_name}已确认')
        else:
            # 向后兼容：如果没有配置，使用默认
            self._update_stage_progress(0, 'processing', '正在处理用户需求...')
            self._update_stage_progress(0, 'completed', '用户需求已确认')
        
        # 当前阶段跟踪
        current_stage = first_stage_index
        output_dir = None
        total_tool_calls = 0
        max_iterations = 100  # 最大迭代次数
        
        # 阶段迭代计数（用于检测卡住）
        stage_iteration_count = 0
        max_stage_iterations = 15  # 每个阶段最多 15 轮迭代
        
        # 跟踪最近的工具调用（用于检测重复/循环执行）
        recent_tool_calls = []  # 存储最近6次工具调用名称（用于检测4次重复）
        
        # 重置上下文管理状态
        self.stage_artifacts = {}
        self.generated_files = []
        
        # 构建消息列表
        messages = []
        if self.system_prompt:
            messages.append({'role': 'system', 'content': self.system_prompt})
        messages.append({'role': 'user', 'content': user_instruction})
        
        # 获取工具定义
        tools = ToolDefinition.get_ctf_tools()
        
        # 上次压缩时的阶段（避免频繁压缩）
        last_compress_stage = -1
        
        try:
            iteration = 0
            while iteration < max_iterations:
                iteration += 1
                
                # === 检查是否请求取消生成 ===
                from app.routes.generator.utils import is_generation_cancelled
                if is_generation_cancelled():
                    self._log('warning', '检测到生成取消请求，停止生成')
                    self._write_to_log("[系统] 检测到生成取消请求，停止生成")
                    # 标记所有阶段为取消状态
                    if self._filtered_stages:
                        for idx in range(len(self._filtered_stages)):
                            self._update_stage_progress(idx, 'cancelled', '用户取消')
                    return {
                        'status': 'error',
                        'message': '生成已取消',
                        'tool_calls_count': total_tool_calls,
                        'output_dir': output_dir
                    }
                
                # === 上下文管理：检查是否需要压缩 ===
                if (len(messages) >= self.context_config['compress_threshold'] and 
                    current_stage > last_compress_stage):
                    self._write_to_log(f"[上下文管理] 消息数 {len(messages)} 达到阈值，触发压缩...")
                    messages = self._compress_messages(messages, current_stage)
                    last_compress_stage = current_stage
                    self._write_to_log(f"[上下文管理] 压缩后消息数: {len(messages)}")
                
                # 记录发送给 AI 的最后一条消息
                if messages:
                    last_msg = messages[-1]
                    if last_msg['role'] == 'user':
                        self._write_to_log(f"[轮次 {iteration}] 发送给 AI 的消息:")
                        msg_content = last_msg['content']
                        # 截断过长的内容
                        if len(msg_content) > 2000:
                            self._write_to_log(msg_content[:2000] + f"\n... (截断，共 {len(msg_content)} 字符)")
                        else:
                            self._write_to_log(msg_content)
                    elif last_msg['role'] == 'tool':
                        tool_name = last_msg.get('name', 'unknown')
                        tool_content = last_msg['content']
                        if len(tool_content) > 1000:
                            self._log_writer.write_tool(tool_name, tool_content[:1000] + f"\n... (截断，共 {len(tool_content)} 字符)")
                        else:
                            self._log_writer.write_tool(tool_name, tool_content)
                
                # 调用 AI
                response = self.provider.chat(messages, tools=tools)
                content = response.get('content', '')
                
                # 记录 AI 响应
                if content:
                    self._write_to_log(f"[轮次 {iteration}] AI 响应:")
                    if len(content) > 10000:
                        self._write_to_log(content[:10000] + f"\n... (截断，共 {len(content)} 字符)")
                    else:
                        self._write_to_log(content)
                
                # 检测阶段变化
                # 先尝试检测数字格式的阶段编号（向后兼容）
                new_stage_num = self._detect_stage(content)
                
                # 如果检测到数字阶段，尝试映射到配置中的阶段索引
                # AI输出的阶段编号从1开始（阶段1, 阶段2...），配置索引从0开始（0, 1, 2...）
                # 所以：AI阶段编号 - 1 = 配置索引
                # 例如：AI阶段1 -> 索引0, AI阶段2 -> 索引1, ...
                new_stage = None
                if new_stage_num is not None:
                    if self._filtered_stages:
                        # 直接使用索引映射：AI阶段编号从1开始，配置索引从0开始
                        stage_index = new_stage_num - 1
                        if 0 <= stage_index < len(self._filtered_stages):
                            new_stage = stage_index
                        elif new_stage_num > 0:
                            # 如果超出范围，忽略此次阶段检测，保持在当前阶段
                            # AI可能输出了错误的阶段编号（如阶段7），不应该强制映射到最后一个阶段
                            self._log('warning', f'检测到阶段编号 {new_stage_num} 超出配置范围（共{len(self._filtered_stages)}个阶段，索引0-{len(self._filtered_stages)-1}），忽略此次阶段切换，保持在当前阶段 {current_stage}')
                            new_stage = None  # 忽略，保持在当前阶段
                    else:
                        # 没有配置，AI的阶段编号从1开始，索引从0开始
                        new_stage = new_stage_num - 1 if new_stage_num > 0 else 0
                
                # 如果仍然没有检测到，尝试从文本中直接匹配配置中的阶段 ID 和名称
                if new_stage is None and self._filtered_stages:
                    for idx, stage in enumerate(self._filtered_stages):
                        stage_id = str(stage.get('id', ''))
                        stage_name = stage.get('name', '')
                        # 检查文本中是否包含阶段 ID 或阶段名称
                        if (f'阶段{stage_id}' in content or f'阶段 {stage_id}' in content or 
                            stage_name in content):
                            new_stage = idx
                            break
                
                if new_stage is not None and new_stage != current_stage:
                    # 如果检测到阶段回退，可能是AI的错误输出，忽略或者记录警告
                    if new_stage < current_stage:
                        current_stage_id = self._get_stage_id_by_index(current_stage) if self._filtered_stages else str(current_stage)
                        new_stage_id = self._get_stage_id_by_index(new_stage) if self._filtered_stages else str(new_stage)
                        self._log('warning', f'检测到阶段回退：从阶段 {current_stage} (ID: {current_stage_id}) 回退到阶段 {new_stage} (ID: {new_stage_id})，忽略此变更以保护已完成的工作')
                        # 不处理阶段回退，继续使用当前阶段，跳过阶段切换处理
                        # 将 new_stage 设为 None，让代码进入 else 分支（同一阶段处理）
                        new_stage = None
                    elif new_stage > current_stage:
                        # 阶段推进：正常处理
                        # === 上下文管理：保存当前阶段的关键信息 ===
                        if content:
                            self._save_stage_artifact(current_stage, content)
                        # 完成当前阶段和所有中间阶段
                        for s in range(current_stage, new_stage):
                            self._update_stage_progress(s, 'completed', f'阶段 {s} 完成')
                        # 开始新阶段
                        current_stage = new_stage
                        stage_iteration_count = 0  # 重置阶段迭代计数
                        # 获取阶段ID（而不是索引）用于获取阶段名称
                        stage_id = self._get_stage_id_by_index(current_stage) if self._filtered_stages else str(current_stage)
                        stage_name = self._get_stage_name(stage_id)
                        self._update_stage_progress(current_stage, 'processing', f'正在执行: {stage_name}')
                        self._log('info', f'进入阶段 {current_stage} (ID: {stage_id}): {stage_name}')

                else:
                    # 同一阶段，增加迭代计数
                    stage_iteration_count += 1
                    
                    # 获取配置中的阶段列表，确定最大阶段索引
                    max_stage_index = len(self._filtered_stages) - 1 if self._filtered_stages else 5
                    
                    # 检查是否在当前阶段停留太久
                    if stage_iteration_count >= max_stage_iterations and current_stage < max_stage_index:
                        # 强制推进到下一阶段
                        next_stage = current_stage + 1
                        self._update_stage_progress(current_stage, 'completed', f'阶段 {current_stage} 完成（超时推进）')
                        current_stage = next_stage
                        stage_iteration_count = 0
                        # 获取阶段ID（而不是索引）用于获取阶段名称
                        stage_id = self._get_stage_id_by_index(current_stage) if self._filtered_stages else str(current_stage)
                        stage_name = self._get_stage_name(stage_id)
                        self._update_stage_progress(current_stage, 'processing', f'正在执行: {stage_name}')
                        self._log('warning', f'阶段超时，强制推进到阶段 {current_stage} (ID: {stage_id}): {stage_name}')
                        
                        # 基于配置的阶段列表生成引导消息
                        if self._filtered_stages and current_stage < len(self._filtered_stages):
                            stage_config = self._filtered_stages[current_stage]
                            stage_id = str(stage_config.get('id', current_stage))
                            stage_name = stage_config.get('name', f'阶段 {stage_id}')
                            prompt_msg = f"请开始阶段{stage_id}：{stage_name}。"
                            messages.append({'role': 'user', 'content': prompt_msg})
                            self._write_to_log(f"[系统] 强制推进到阶段 {current_stage}（{stage_id}: {stage_name}），发送引导消息")

                
                # 检测输出目录
                if output_dir is None:
                    detected_dir = self._detect_output_dir(content)
                    if detected_dir:
                        output_dir = detected_dir
                        self._log('info', f'检测到输出目录: {output_dir}')
                
                # 检查是否有工具调用
                if response.get('tool_calls'):
                    total_tool_calls += len(response['tool_calls'])
                    
                    # 记录工具调用信息
                    self._write_to_log(f"[轮次 {iteration}] AI 请求执行 {len(response['tool_calls'])} 个工具:")
                    
                    # 添加助手消息
                    assistant_msg = {
                        'role': 'assistant',
                        'content': content,
                        'tool_calls': response['tool_calls']
                    }
                    messages.append(assistant_msg)
                    
                    # 执行每个工具调用
                    for tool_call in response['tool_calls']:
                        tool_name = tool_call['function']['name']
                        tool_args = tool_call.get('function', {}).get('arguments', '{}')
                        
                        # 记录工具调用详情（使用统一格式）
                        if len(tool_args) > 500:
                            self._log_writer.write_tool(tool_name, f"参数: {tool_args[:500]}... (截断)")
                        else:
                            self._log_writer.write_tool(tool_name, f"参数: {tool_args}")
                        
                        self._log('info', f'执行工具: {tool_name}')
                        
                        # 跟踪最近的工具调用（用于检测循环执行）
                        # 对于 list_directory 和 read_file，存储工具名和关键参数（路径）来区分不同的路径
                        try:
                            import json
                            parsed_args = json.loads(tool_args) if isinstance(tool_args, str) else tool_args
                            if tool_name in ['list_directory', 'read_file'] and 'path' in parsed_args:
                                # 存储工具名和路径，用于区分不同路径的调用
                                tool_call_key = f"{tool_name}:{parsed_args['path']}"
                            else:
                                # 其他工具只存储工具名
                                tool_call_key = tool_name
                        except:
                            # 解析失败时只存储工具名
                            tool_call_key = tool_name
                        
                        recent_tool_calls.append(tool_call_key)
                        if len(recent_tool_calls) > 6:  # 保留最近6次，用于检测4次重复
                            recent_tool_calls.pop(0)
                        
                        # 从工具调用中检测输出目录
                        if output_dir is None:
                            output_dir = self._detect_output_dir_from_tool(tool_call)
                        
                        # === 上下文管理：跟踪生成的文件 ===
                        generated_file = self._extract_file_from_tool_call(tool_call)
                        if generated_file:
                            self._track_generated_file(generated_file)
                        
                        # 执行工具
                        tool_result = self.tool_executor.execute_tool_call(tool_call)
                        
                        # 记录工具执行结果（使用统一格式）
                        if len(tool_result) > 500:
                            self._log_writer.write_tool(tool_name, f"结果: {tool_result[:500]}... (截断，共 {len(tool_result)} 字符)")
                        else:
                            self._log_writer.write_tool(tool_name, f"结果: {tool_result}")
                        
                        # 添加工具结果消息
                        messages.append({
                            'role': 'tool',
                            'tool_call_id': tool_call['id'],
                            'name': tool_name,
                            'content': tool_result
                        })
                    
                    # 检测重复工具调用（循环执行）
                    # 只检测可能导致循环的只读工具（如 list_directory），不影响正常的写文件操作
                    # 如果在最后阶段且有输出目录，且最近连续多次执行相同的只读工具，强制停止
                    if self._filtered_stages:
                        is_last_stage = (current_stage >= len(self._filtered_stages) - 1)
                        if is_last_stage and output_dir:
                            # 定义可能导致循环的只读工具列表（这些工具重复执行通常无意义）
                            read_only_tools = ['list_directory', 'read_file']
                            
                            # 检查最近是否连续执行相同的只读工具（至少4次，更严格）
                            # 注意：这里检查的是工具名和路径的组合，相同路径的重复调用才会被检测为循环
                            # 如果路径不同，则认为是正常的探索过程，不是循环
                            if len(recent_tool_calls) >= 4:
                                last_four = recent_tool_calls[-4:]
                                # 提取工具名（如果包含冒号，则取冒号前的部分）
                                first_tool = last_four[0].split(':')[0] if ':' in last_four[0] else last_four[0]
                                
                                # 只有当工具名在只读工具列表中，且4次调用完全相同（工具名+路径）时，才认为是循环
                                if (last_four[0] == last_four[1] == last_four[2] == last_four[3] and 
                                    first_tool in read_only_tools):
                                    # 提取路径信息用于日志
                                    path_info = last_four[0].split(':', 1)[1] if ':' in last_four[0] else ''
                                    path_msg = f" (路径: {path_info})" if path_info else ""
                                    self._log('warning', f'检测到重复执行只读工具 {first_tool} 4次{path_msg}，在最后阶段强制停止')
                                    self._write_to_log(f"[系统] 检测到重复执行只读工具 {first_tool} 4次{path_msg}，强制停止循环执行")
                                    # 标记当前阶段为完成
                                    self._update_stage_progress(current_stage, 'completed', '检测到循环执行，强制完成')
                                    # 标记所有阶段完成并返回
                                    if self._filtered_stages:
                                        for idx in range(len(self._filtered_stages)):
                                            self._update_stage_progress(idx, 'completed', '')
                                    self._log('info', '题目生成完成!')
                                    return {
                                        'status': 'success',
                                        'response': content,
                                        'tool_calls_count': total_tool_calls,
                                        'output_dir': output_dir
                                    }
                    
                    continue
                else:
                    # 没有工具调用，检查是否完成
                    # 如果 AI 认为完成了，进行验证
                    last_stage_id = self._get_last_stage_id()
                    # 判断是否到达最后一个阶段（基于配置的阶段列表索引）
                    is_last_stage = False
                    if self._filtered_stages:
                        is_last_stage = (current_stage >= len(self._filtered_stages) - 1)
                    
                    if is_last_stage and output_dir:
                        # 最后一个阶段: 成品输出 - 直接完成
                        last_stage_index = len(self._filtered_stages) - 1 if self._filtered_stages else current_stage
                        self._update_stage_progress(last_stage_index, 'completed', '题目生成完成')
                        self._log('info', '题目生成完成!')
                        
                        # 标记所有阶段完成
                        if self._filtered_stages:
                            for idx in range(len(self._filtered_stages)):
                                self._update_stage_progress(idx, 'completed', '')
                        
                        return {
                            'status': 'success',
                            'response': content,
                            'tool_calls_count': total_tool_calls,
                            'output_dir': output_dir
                        }
                    else:
                        # 还没到验证阶段，继续生成
                        messages.append({'role': 'assistant', 'content': content})
                        
                        # 检查是否所有阶段都已完成
                        all_stages_completed = False
                        if self._filtered_stages:
                            from app.routes.generator.utils import generation_status
                            step_statuses = generation_status.get("step_statuses", {})
                            completed_count = 0
                            total_count = len(self._filtered_stages)
                            for idx in range(total_count):
                                status = step_statuses.get(idx)
                                if status == 'completed':
                                    completed_count += 1
                            # 如果所有阶段都已完成，标记为完成
                            if completed_count >= total_count:
                                all_stages_completed = True
                        
                        # 如果所有阶段都已完成且有输出目录，直接返回
                        if all_stages_completed and output_dir:
                            self._log('info', '检测到所有阶段已完成，返回结果')
                            self._update_stage_progress(len(self._filtered_stages) - 1 if self._filtered_stages else current_stage, 'completed', '题目生成完成')
                            return {
                                'status': 'success',
                                'response': content,
                                'tool_calls_count': total_tool_calls,
                                'output_dir': output_dir
                            }
                        
                        # 如果 AI 停止了但还没完成，提示继续（但要确保没有重复发送）
                        # 重要：只有当确实还有阶段未完成时，才发送继续消息
                        if not is_last_stage and self._filtered_stages and not all_stages_completed:
                            # 检查当前阶段是否已经完成（避免重复发送消息）
                            from app.routes.generator.utils import generation_status
                            step_statuses = generation_status.get("step_statuses", {})
                            current_status = step_statuses.get(current_stage)
                            
                            # 只有当当前阶段确实还在处理中，且不是已完成状态时，才发送继续消息
                            if current_status != 'completed':
                                last_stage_config = self._filtered_stages[-1]
                                last_stage_id = str(last_stage_config.get('id', ''))
                                last_stage_name = last_stage_config.get('name', '成品输出')
                                continue_msg = f"请继续执行，当前在阶段 {current_stage}，需要完成到阶段 {last_stage_id}（{last_stage_name}）。"
                                messages.append({'role': 'user', 'content': continue_msg})
                                continue
                            else:
                                # 当前阶段已完成，但还有后续阶段，直接推进到下一阶段
                                if current_stage < len(self._filtered_stages) - 1:
                                    current_stage += 1
                                    # 获取阶段ID（而不是索引）用于获取阶段名称
                                    stage_id = self._get_stage_id_by_index(current_stage) if self._filtered_stages else str(current_stage)
                                    stage_name = self._get_stage_name(stage_id)
                                    self._update_stage_progress(current_stage, 'processing', f'正在执行: {stage_name}')
                                    self._log('info', f'自动推进到阶段 {current_stage} (ID: {stage_id}): {stage_name}')
                                    continue
                        
                        # 如果没有输出目录，提示创建
                        if output_dir is None:
                            messages.append({
                                'role': 'user', 
                                'content': '请创建输出目录并生成题目文件。输出目录格式：output/YYYYMMDD_HHMMSS_题目名称/'
                            })
                            continue
                        
                        break
            
            # 达到最大迭代次数
            self._log('warning', f'达到最大迭代次数: {max_iterations}')
            
            # 如果有输出目录，标记为完成并返回
            if output_dir:
                if self._filtered_stages:
                    for idx in range(len(self._filtered_stages)):
                        self._update_stage_progress(idx, 'completed', '')
                else:
                    # 使用实际的阶段数量，如果没有配置则标记为完成状态（不强制8个阶段）
                    if self._filtered_stages:
                        for idx in range(len(self._filtered_stages)):
                            self._update_stage_progress(idx, 'completed', '')
                    else:
                        # 如果没有配置，标记当前步骤为完成
                        self._update_stage_progress(current_stage, 'completed', '')
                return {
                    'status': 'success',
                    'response': '达到最大迭代次数，但生成已完成',
                    'tool_calls_count': total_tool_calls,
                    'output_dir': output_dir,
                    'warning': '达到最大迭代次数'
                }
            
            return {
                'status': 'error',
                'message': f'达到最大迭代次数 ({max_iterations})，生成未完成',
                'tool_calls_count': total_tool_calls
            }
            
        except Exception as e:
            self._log('error', f'生成失败: {str(e)}')
            return {
                'status': 'error',
                'message': f'生成失败: {str(e)}',
                'tool_calls_count': total_tool_calls
            }

    def _detect_stage(self, content: str) -> Optional[int]:
        """Extract the stage number announced in an AI response.

        Parsing is delegated entirely to ``StageDetector.detect_stage``; its
        result (``None`` when nothing was detected) is returned unchanged.
        """
        detected = StageDetector.detect_stage(content)
        return detected

    def _detect_output_dir(self, content: str) -> Optional[str]:
        """Find an output directory mentioned in free-form AI text.

        The search base handed to ``StageDetector.detect_output_dir`` is the
        ``ge10`` folder under the project root (``Path(__file__)`` followed by
        four ``.parent`` hops), passed as a string.
        """
        base = Path(__file__).parent.parent.parent.parent
        return StageDetector.detect_output_dir(content, str(base.joinpath('ge10')))

    def _detect_output_dir_from_tool(self, tool_call: Dict) -> Optional[str]:
        """从工具调用中检测输出目录
        
        Args:
            tool_call: 工具调用信息
            
        Returns:
            输出目录路径或 None
        """
        import re
        import json
        
        try:
            args = tool_call.get('function', {}).get('arguments', '{}')
            if isinstance(args, str):
                args = json.loads(args)
            
            # 检查各种可能包含路径的参数
            for key in ['path', 'file_path', 'directory', 'command']:
                if key in args:
                    value = args[key]
                    # 匹配输出目录模式（支持新格式：ge10/{category}/output/{challenge}）
                    match = re.search(r'(?:ge10/[^/]+/)?output/(\d{8}_\d{6}_[^/\s"\']+)', str(value))
                    if match:
                        dir_name = match.group(1)
                        project_root = Path(__file__).parent.parent.parent.parent
                        # 尝试在所有方向的 output 目录中查找
                        ge10_dir = project_root / 'ge10'
                        for category_dir in ge10_dir.iterdir():
                            if category_dir.is_dir() and not category_dir.name.startswith('.'):
                                full_path = category_dir / 'output' / dir_name
                                if full_path.exists():
                                    return str(full_path)
        except:
            pass
        
        return None

    def _get_stage_name(self, stage_id: str) -> str:
        """获取阶段名称（优先从配置中获取，否则委托给 StageDetector）"""
        # 首先尝试从配置的阶段列表中获取
        if self._filtered_stages:
            for stage in self._filtered_stages:
                if str(stage.get('id', '')) == str(stage_id):
                    return stage.get('name', f'阶段 {stage_id}')
        
        # 如果配置中没有，尝试使用数字 ID 从 StageDetector 获取（传入配置）
        try:
            numeric_id = int(stage_id) if stage_id.replace('.', '').isdigit() else None
            if numeric_id is not None:
                # 传入阶段配置，让StageDetector从配置中获取名称
                return StageDetector.get_stage_name(numeric_id, self._filtered_stages)
        except:
            pass
        
        return f'阶段 {stage_id}'
    
    def _get_last_stage_id(self) -> Optional[str]:
        """获取配置中的最后一个阶段 ID"""
        if self._filtered_stages:
            last_stage = self._filtered_stages[-1]
            return str(last_stage.get('id', ''))
        return None
    
    def _get_stage_index_by_id(self, stage_id: str) -> Optional[int]:
        """根据阶段 ID 获取阶段索引"""
        if not self._filtered_stages:
            return None
        for idx, stage in enumerate(self._filtered_stages):
            if str(stage.get('id', '')) == str(stage_id):
                return idx
        return None
    
    def _get_stage_id_by_index(self, index: int) -> Optional[str]:
        """根据阶段索引获取阶段 ID"""
        if self._filtered_stages and 0 <= index < len(self._filtered_stages):
            return str(self._filtered_stages[index].get('id', ''))
        return None

    def _update_stage_progress(self, stage: int, status: str, message: str):
        """更新阶段进度（兼容前端显示）
        
        Args:
            stage: 阶段索引 (0-based)
            status: 状态 (waiting, processing, completed, error)
            message: 消息
        """
        # 获取阶段ID用于获取阶段名称
        stage_id = self._get_stage_id_by_index(stage) if self._filtered_stages else str(stage)
        stage_name = self._get_stage_name(stage_id)
        self._log('debug', f'阶段 {stage} (ID: {stage_id}, {stage_name}): {status} - {message}')

    def _load_prompt_from_database(self, category_id: str, difficulty: str, language: str, 
                                   vulnerabilities: List[str], scene: str, form_data: Dict[str, Any]) -> Optional[str]:
        """从数据库加载并编译 Prompt
        
        Args:
            category_id: 方向ID
            difficulty: 难度（中文：入门/简单/中等/困难）
            language: 编程语言
            vulnerabilities: 漏洞列表
            scene: 场景
            form_data: 表单数据
            
        Returns:
            编译后的完整 Prompt 字符串，如果加载失败则返回 None
        """
        try:
            from app.models.database.models import CategoryConfig
            from app.services.prompt.generator import StagePromptGenerator
            from flask import has_app_context, current_app
            
            # 检查是否在应用上下文中（调用时应该在上下文中）
            if not has_app_context():
                self._log('warning', '不在 Flask 应用上下文中，无法访问数据库')
                return None
            
            category = CategoryConfig.query.get(category_id)
            if not category:
                self._log('warning', f'方向配置不存在: {category_id}')
                return None
            
            # 将中文难度转换为英文 key
            difficulty_map = {
                '入门': 'beginner',
                '简单': 'easy',
                '中等': 'medium',
                '困难': 'hard'
            }
            difficulty_key = difficulty_map.get(difficulty, 'beginner')
            
            # 获取全局系统提示（如果存在）
            system_prompt = None
            try:
                config_dict = category.to_dict(include_config=True)
                if 'advanced_config' in config_dict and isinstance(config_dict['advanced_config'], dict):
                    system_prompt = config_dict['advanced_config'].get('system_prompt')
            except:
                pass
            
            # 获取阶段配置（优先使用新格式）
            # 新格式：直接获取该难度的阶段列表
            stages = category.get_stages_by_difficulty(difficulty)
            knowledge_count = len(vulnerabilities) if isinstance(vulnerabilities, list) else 1
            merged_stages = []
            
            if stages:
                # 新格式：每个阶段直接包含 prompt，只需处理条件过滤
                for stage in stages:
                    stage_copy = {**stage}
                    
                    # 检查阶段条件（如 knowledge_count > 1）
                    condition = stage.get('condition')
                    if condition:
                        try:
                            if 'knowledge_count' in condition:
                                if '>' in condition:
                                    threshold = int(condition.split('>')[1].strip())
                                    if not (knowledge_count > threshold):
                                        self._log('debug', f'阶段 {stage.get("id")} 不满足条件 {condition}，跳过')
                                        continue
                                elif '>=' in condition:
                                    threshold = int(condition.split('>=')[1].strip())
                                    if not (knowledge_count >= threshold):
                                        self._log('debug', f'阶段 {stage.get("id")} 不满足条件 {condition}，跳过')
                                        continue
                                elif '<' in condition:
                                    threshold = int(condition.split('<')[1].strip())
                                    if not (knowledge_count < threshold):
                                        self._log('debug', f'阶段 {stage.get("id")} 不满足条件 {condition}，跳过')
                                        continue
                                elif '<=' in condition:
                                    threshold = int(condition.split('<=')[1].strip())
                                    if not (knowledge_count <= threshold):
                                        self._log('debug', f'阶段 {stage.get("id")} 不满足条件 {condition}，跳过')
                                        continue
                                elif '==' in condition:
                                    threshold = int(condition.split('==')[1].strip())
                                    if not (knowledge_count == threshold):
                                        self._log('debug', f'阶段 {stage.get("id")} 不满足条件 {condition}，跳过')
                                        continue
                        except Exception as e:
                            self._log('warning', f'解析阶段条件失败: {condition}, 错误: {str(e)}')
                            if not stage.get('required', False):
                                continue
                    
                    # 新格式中，prompt 是直接字段，将其作为 user_extension
                    if stage.get('prompt'):
                        prompt_content = stage['prompt']
                        if stage_copy.get('user_extension'):
                            stage_copy['user_extension'] = stage_copy['user_extension'] + '\n\n' + prompt_content
                        else:
                            stage_copy['user_extension'] = prompt_content
                    
                    # 检查阶段是否有内容
                    has_content = (
                        stage_copy.get('system_prompt', '').strip() or
                        stage_copy.get('user_extension', '').strip()
                    )
                    
                    if not has_content:
                        is_required = stage_copy.get('required', False) or stage_copy.get('skip_forbidden', False)
                        if not is_required:
                            self._log('debug', f'阶段 {stage_copy.get("id")} ({stage_copy.get("name")}) 没有内容且非必需，跳过')
                            continue
                    
                    merged_stages.append(stage_copy)
            else:
                # 降级到旧格式：从所有阶段中筛选（向后兼容）
                all_stages = category.get_stages()
                if not all_stages:
                    self._log('warning', f'方向 {category_id} 没有配置阶段')
                    return None
                
                for stage in all_stages:
                    stage_copy = {**stage}
                    
                    # 检查阶段条件
                    condition = stage.get('condition')
                    if condition:
                        try:
                            if 'knowledge_count' in condition:
                                if '>' in condition:
                                    threshold = int(condition.split('>')[1].strip())
                                    if not (knowledge_count > threshold):
                                        self._log('debug', f'阶段 {stage.get("id")} 不满足条件 {condition}，跳过')
                                        continue
                                elif '>=' in condition:
                                    threshold = int(condition.split('>=')[1].strip())
                                    if not (knowledge_count >= threshold):
                                        self._log('debug', f'阶段 {stage.get("id")} 不满足条件 {condition}，跳过')
                                        continue
                                elif '<' in condition:
                                    threshold = int(condition.split('<')[1].strip())
                                    if not (knowledge_count < threshold):
                                        self._log('debug', f'阶段 {stage.get("id")} 不满足条件 {condition}，跳过')
                                        continue
                                elif '<=' in condition:
                                    threshold = int(condition.split('<=')[1].strip())
                                    if not (knowledge_count <= threshold):
                                        self._log('debug', f'阶段 {stage.get("id")} 不满足条件 {condition}，跳过')
                                        continue
                                elif '==' in condition:
                                    threshold = int(condition.split('==')[1].strip())
                                    if not (knowledge_count == threshold):
                                        self._log('debug', f'阶段 {stage.get("id")} 不满足条件 {condition}，跳过')
                                        continue
                        except Exception as e:
                            self._log('warning', f'解析阶段条件失败: {condition}, 错误: {str(e)}')
                            if not stage.get('required', False):
                                continue
                    
                    # 旧格式：合并难度配置
                    if stage.get('difficulty_config') and difficulty_key in stage['difficulty_config']:
                        diff_config = stage['difficulty_config'][difficulty_key]
                        if diff_config.get('name'):
                            stage_copy['name'] = diff_config['name']
                        if diff_config.get('description') is not None:
                            stage_copy['description'] = diff_config['description']
                        if diff_config.get('output_format'):
                            stage_copy['output_format'] = diff_config['output_format']
                        if diff_config.get('system_prompt'):
                            stage_copy['system_prompt'] = diff_config['system_prompt']
                        if diff_config.get('user_extension'):
                            stage_copy['user_extension'] = diff_config['user_extension']
                    
                    # 旧格式：从 prompts 字典中提取
                    prompts = stage.get('prompts', {})
                    if isinstance(prompts, dict) and difficulty in prompts:
                        prompt_data = prompts[difficulty]
                        if prompt_data and isinstance(prompt_data, dict):
                            prompt_content = prompt_data.get('content', '')
                            if prompt_content:
                                if stage_copy.get('user_extension'):
                                    stage_copy['user_extension'] = stage_copy['user_extension'] + '\n\n' + prompt_content
                                else:
                                    stage_copy['user_extension'] = prompt_content
                    
                    # 处理知识库脚本说明
                    knowledge_instruction = None
                    if stage.get('difficulty_config') and difficulty_key in stage['difficulty_config']:
                        diff_config = stage['difficulty_config'][difficulty_key]
                        knowledge_instruction = diff_config.get('knowledge_script_instruction')
                    if not knowledge_instruction:
                        knowledge_instruction = stage.get('knowledge_script_instruction')
                    
                    if knowledge_instruction and knowledge_instruction.strip():
                        script_instruction = knowledge_instruction.strip()
                        if stage_copy.get('user_extension'):
                            stage_copy['user_extension'] = stage_copy['user_extension'] + '\n\n' + script_instruction
                        else:
                            stage_copy['user_extension'] = script_instruction
                    
                    # 检查阶段是否有内容
                    has_content = (
                        stage_copy.get('system_prompt', '').strip() or
                        stage_copy.get('user_extension', '').strip()
                    )
                    
                    if not has_content:
                        is_required = stage_copy.get('required', False) or stage_copy.get('skip_forbidden', False)
                        if not is_required:
                            self._log('debug', f'阶段 {stage_copy.get("id")} ({stage_copy.get("name")}) 没有内容且非必需，跳过')
                            continue
                    
                    merged_stages.append(stage_copy)
            
            # 构建上下文变量
            context = {
                'category': category.name,
                'difficulty': difficulty,
                'language': language,
                'knowledge_points': ', '.join(vulnerabilities) if isinstance(vulnerabilities, list) else vulnerabilities,
                'scene': scene.get('name') if isinstance(scene, dict) else scene,
                'available_languages': form_data.get('available_languages', 'PHP, Python, Node.js, Go, Java'),
                'flag_format': form_data.get('flag_format', 'DASCTF{...}'),
            }
            
            # 获取难度规则（优先使用新格式）
            difficulties = category.get_difficulties()
            if difficulties and difficulty in difficulties:
                # 新格式：从 difficulties[difficulty].rules 获取
                rules = difficulties[difficulty].get('rules', {})
                context['writeup_count'] = rules.get('writeup_count', 5)
                context['max_knowledge'] = rules.get('max_count', 5)
                context['depth_range'] = rules.get('depth_range', [1.5, 10.0])
                context['diff_rate'] = rules.get('diff_rate', 0.3)
                
                # 构建难度表格（从所有难度）
                difficulty_table_rows = []
                for diff_name, diff_cfg in difficulties.items():
                    diff_rules = diff_cfg.get('rules', {})
                    max_count = diff_rules.get('max_count', 0)
                    writeup_count = diff_rules.get('writeup_count', 0)
                    depth_range = diff_rules.get('depth_range', [])
                    depth_str = f"[{depth_range[0]}, {depth_range[1]}]" if depth_range else ""
                    difficulty_table_rows.append(f"| {diff_name} | {max_count} 个 | {writeup_count} 篇 | {depth_str} |")
                context['difficulty_table'] = '\n'.join(difficulty_table_rows)
            else:
                # 旧格式：从 difficulty_rules 获取
                difficulty_rules = category.get_difficulty_rules()
                for rule in difficulty_rules:
                    if rule.get('name') == difficulty:
                        context['writeup_count'] = rule.get('writeup_count', 5)
                        context['max_knowledge'] = rule.get('max_count', 5)
                        context['depth_range'] = rule.get('depth_range', [1.5, 10.0])
                        context['diff_rate'] = rule.get('diff_rate', 0.3)
                        break
                
                # 构建难度表格
                if difficulty_rules:
                    difficulty_table_rows = []
                    for rule in difficulty_rules:
                        name = rule.get('name', '')
                        max_count = rule.get('max_count', 0)
                        writeup_count = rule.get('writeup_count', 0)
                        depth_range = rule.get('depth_range', [])
                        depth_str = f"[{depth_range[0]}, {depth_range[1]}]" if depth_range else ""
                        difficulty_table_rows.append(f"| {name} | {max_count} 个 | {writeup_count} 篇 | {depth_str} |")
                    context['difficulty_table'] = '\n'.join(difficulty_table_rows)
            
            # 优先从数据库读取已编译的 Prompt 模板
            compiled_prompts = category.get_compiled_prompts()
            difficulty_map = {
                '入门': 'beginner',
                '简单': 'easy',
                '中等': 'medium',
                '困难': 'hard'
            }
            difficulty_key = difficulty_map.get(difficulty, 'beginner')
            
            if compiled_prompts and difficulty_key in compiled_prompts:
                # 使用已编译的模板，替换占位符
                from app.services.prompt.compiler_service import PromptCompilerService
                compiled_prompt = PromptCompilerService.replace_placeholders(
                    compiled_prompts[difficulty_key],
                    language,
                    vulnerabilities,
                    scene
                )
                self._log('info', f'已从数据库加载已编译的 Prompt 模板（难度: {difficulty}）')
            else:
                # 如果数据库中没有，则重新编译
                self._log('info', f'数据库中未找到已编译的 Prompt 模板，重新编译...')
                global_rules = system_prompt if system_prompt else None
                compiled_prompt = StagePromptGenerator.compile_full_prompt(
                    category_id=category_id,
                    stages=merged_stages,
                    context=context,
                    global_rules=global_rules
                )
            
            if compiled_prompt:
                # 保存过滤后的阶段列表，供后续使用
                self._filtered_stages = merged_stages
                self._log('info', f'已加载 Prompt，包含 {len(merged_stages)} 个阶段')
                return compiled_prompt
            else:
                self._log('warning', '从数据库加载的 Prompt 为空')
                return None
                
        except Exception as e:
            self._log('error', f'从数据库加载 Prompt 失败: {str(e)}')
            import traceback
            self._log('debug', traceback.format_exc())
            return None


def create_ai_service(user_id: int, **kwargs) -> AIService:
    """Convenience factory that builds an ``AIService``.

    Args:
        user_id: ID of the user whose AI configuration the service should use.
        **kwargs: Forwarded unchanged to ``AIService.__init__`` (for example
            ``output_dir``, ``log_callback``, ``category_id``, ``ai_config_id``).

    Returns:
        A newly constructed ``AIService`` instance.
    """
    service = AIService(user_id=user_id, **kwargs)
    return service


class ContinueConversationService:
    """Continue-conversation service.

    Lets a user keep refining an already generated CTF challenge by chatting
    with the AI provider configured for them (whatever
    ``AIProviderFactory.create_for_user`` returns). Tool calls run through a
    ``ToolExecutor`` whose sandbox is rooted at ``workspace_dir``, and the
    message list is persisted to ``<workspace_dir>/.conversation_history.json``.
    """

    # Maximum number of model <-> tool round trips per user instruction; stops
    # a model that keeps requesting tools from looping forever.
    MAX_TOOL_ITERATIONS = 20

    def __init__(self, user_id: int, workspace_dir: str, log_callback: Callable = None):
        """Create the service and restore any previously saved conversation.

        Args:
            user_id: User ID used to look up the user's AI configuration.
            workspace_dir: Output directory of the generated challenge; the tool
                sandbox and the history file both live here.
            log_callback: Optional ``callback(level, message)`` that receives
                every log line in addition to the module logger.
        """
        self.user_id = user_id
        self.workspace_dir = Path(workspace_dir)
        self.log_callback = log_callback
        self.provider: Optional[BaseAIProvider] = None
        self.tool_executor: Optional[ToolExecutor] = None
        self.messages: List[Dict[str, Any]] = []
        self._init_provider()
        self._init_tool_executor()
        self._load_conversation_history()

    def _log(self, level: str, message: str):
        """Send ``message`` to ``log_callback`` (if set) and to the module logger.

        ``level`` is a logger method name such as ``'info'`` or ``'warning'``;
        names the logger does not have fall back to ``logger.info``.
        """
        if self.log_callback:
            self.log_callback(level, message)
        log_func = getattr(logger, level, logger.info)
        log_func(f'[ContinueConversation] {message}')

    def _init_provider(self):
        """Create the user's AI provider; ``self.provider`` stays None on any failure."""
        try:
            self.provider = AIProviderFactory.create_for_user(self.user_id)
            if self.provider:
                self._log('info', f'已初始化 AI 提供商: {self.provider.provider_name}')
            else:
                self._log('warning', '未找到可用的 AI 配置')
        except Exception as e:
            self._log('error', f'初始化 AI 提供商失败: {str(e)}')
            self.provider = None

    def _init_tool_executor(self):
        """Build the sandboxed tool executor for this workspace.

        ``workspace_dir`` is the challenge output directory. Per the original
        author's note the sandbox permits writes to it and to /tmp (that policy
        lives in ``SandboxConfig`` -- not verifiable here).
        """
        sandbox = SandboxConfig.create_for_ctf_generation(
            output_base_dir=str(self.workspace_dir),
            category_id=None,  # no category sub-directory: use workspace_dir directly
        )
        self.tool_executor = ToolExecutor(
            sandbox=sandbox,
            log_callback=lambda level, msg: self._log(level, f'[Tool] {msg}')
        )

    def _get_history_file(self) -> Path:
        """Return the path of the JSON file that stores the conversation."""
        return self.workspace_dir / '.conversation_history.json'

    def _load_conversation_history(self):
        """Load ``self.messages`` from the history file.

        A missing file, unreadable/invalid JSON, or a ``messages`` field that is
        not a list all leave ``self.messages`` empty (invalid data is logged as a
        warning rather than raised).
        """
        history_file = self._get_history_file()
        self.messages = []
        if not history_file.exists():
            return
        try:
            with open(history_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            messages = data.get('messages', []) if isinstance(data, dict) else None
            if not isinstance(messages, list):
                raise ValueError('messages 字段格式无效')
            self.messages = messages
            self._log('info', f'已加载 {len(self.messages)} 条历史消息')
        except Exception as e:
            self._log('warning', f'加载对话历史失败: {e}')
            self.messages = []

    def _save_conversation_history(self):
        """Persist ``self.messages`` to the history file.

        The JSON is serialized in memory first and then written via a temporary
        file that replaces the real one, so a serialization or write failure
        leaves the previous history intact instead of truncating it. Failures
        are logged as warnings, not raised.
        """
        history_file = self._get_history_file()
        tmp_file = history_file.with_name(history_file.name + '.tmp')
        try:
            payload = json.dumps({'messages': self.messages}, ensure_ascii=False, indent=2)
            tmp_file.write_text(payload, encoding='utf-8')
            tmp_file.replace(history_file)
        except Exception as e:
            self._log('warning', f'保存对话历史失败: {e}')

    def _build_system_prompt(self) -> str:
        """Return the system prompt describing the tools and the required test workflow."""
        return f"""你是一个 CTF 题目完善助手。用户已经生成了一个 CTF 题目，现在需要你帮助完善。

工作目录: {self.workspace_dir}

## 工具
- run_command: 执行 shell 命令
- read_file: 读取文件
- write_file: 写入文件
- list_directory: 列出目录内容

## 工作流程（必须遵循）

在优化题目时，你必须按以下流程操作：

1. **先搭建 Docker 环境**：
   - 进入 docker 目录：`cd {self.workspace_dir}/docker`
   - 构建并启动容器：`docker-compose up -d --build`
   - 检查容器状态：`docker ps`

2. **测试验证**：
   - 检查服务是否正常运行：`curl http://localhost:端口`
   - 运行 exp.py 验证漏洞：`python3 {self.workspace_dir}/exp.py localhost 端口 DASCTF{{test12345}}`
   - 查看容器日志：`docker-compose logs`

3. **根据测试结果优化**：
   - 如果测试失败，分析错误原因并修复代码
   - 修复后重新构建并测试：`docker-compose down && docker-compose up -d --build`
   - 重复测试直到 exp.py 能成功获取 flag

4. **清理**：
   - 完成后停止容器：`docker-compose down`

## 重要提醒
- 不要只看代码就下结论，必须实际运行测试
- 每次修改代码后都要重新构建 Docker 并测试
- 确保 exp.py 能成功获取 flag 才算完成"""

    def _ensure_system_prompt(self):
        """Insert the system prompt at the head of the history if none is present.

        Checking only for an empty history is not enough: history saved by the
        ``chat_stream`` branch of ``continue_conversation_stream`` contains user
        and assistant messages but never a system message.
        """
        if not any(m.get('role') == 'system' for m in self.messages):
            self.messages.insert(0, {'role': 'system', 'content': self._build_system_prompt()})

    @staticmethod
    def _tool_call_name(tool_call: Dict[str, Any]) -> str:
        """Return the function name of an OpenAI-style tool call ('' if absent)."""
        return tool_call.get('function', {}).get('name', '')

    def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Any:
        """Run one tool call and append its result to the history as a ``tool`` message.

        Returns:
            Whatever ``ToolExecutor.execute_tool_call`` returned for this call.
        """
        tool_result = self.tool_executor.execute_tool_call(tool_call)
        self.messages.append({
            'role': 'tool',
            'tool_call_id': tool_call.get('id', ''),
            'name': self._tool_call_name(tool_call),
            'content': tool_result
        })
        return tool_result

    def continue_conversation(self, user_instruction: str) -> Dict[str, Any]:
        """Process one user instruction to completion (blocking), executing tool calls.

        Args:
            user_instruction: The user's new message.

        Returns:
            ``{'success': True, 'response': <all assistant text concatenated>,
            'workspace_dir': <str>}`` on success, or
            ``{'success': False, 'error': <message>}`` when no provider is
            configured or a provider call raises (history is not saved then).
        """
        if not self.provider:
            return {'success': False, 'error': '未配置 AI 提供商'}

        self._ensure_system_prompt()
        self.messages.append({'role': 'user', 'content': user_instruction})

        tools = ToolDefinition.get_ctf_tools()
        full_response = ""

        for _ in range(self.MAX_TOOL_ITERATIONS):
            try:
                # NOTE(review): the non-streaming branch of
                # continue_conversation_stream also passes
                # workspace_dir=str(self.workspace_dir); confirm whether this
                # call should do the same.
                response = self.provider.chat(self.messages, tools=tools)
            except Exception as e:
                self._log('error', f'AI 调用失败: {e}')
                return {'success': False, 'error': str(e)}

            content = response.get('content', '')
            tool_calls = response.get('tool_calls')

            if content:
                full_response += content

            assistant_msg = {'role': 'assistant', 'content': content}
            if tool_calls:
                assistant_msg['tool_calls'] = tool_calls
            self.messages.append(assistant_msg)

            if not tool_calls:
                break

            for tool_call in tool_calls:
                self._execute_tool_call(tool_call)
        else:
            # Loop ran out without the model producing a tool-free reply
            self._log('warning', f'已达到最大工具调用轮数 ({self.MAX_TOOL_ITERATIONS})，停止继续调用')

        self._save_conversation_history()

        return {
            'success': True,
            'response': full_response,
            'workspace_dir': str(self.workspace_dir)
        }

    def continue_conversation_stream(self, user_instruction: str):
        """Process one user instruction, yielding display text as it is produced.

        Providers that expose ``chat_stream`` are streamed directly with
        ``tools=None`` and without inserting a system prompt (the original author
        notes the Claude CLI keeps its own history via ``--continue``). Other
        providers fall back to the blocking ``chat`` loop with sandboxed tool
        execution.

        Yields:
            str: assistant text, tool-execution notices, tool results (truncated
            to 500 characters) and ``[错误] ...`` error lines.
        """
        if not self.provider:
            yield "[错误] 未配置 AI 提供商"
            return

        # Record the user message so the frontend can render the history
        self.messages.append({'role': 'user', 'content': user_instruction})

        if hasattr(self.provider, 'chat_stream'):
            full_content = ""
            try:
                for chunk in self.provider.chat_stream(
                    self.messages,
                    tools=None,
                    workspace_dir=str(self.workspace_dir)
                ):
                    chunk_type = chunk.get('type')
                    if chunk_type == 'content':
                        content = chunk.get('content', '')
                        full_content += content
                        yield content
                    elif chunk_type == 'tool':
                        tool_name = chunk.get('name', '')
                        yield f"\n🔧 执行工具: {tool_name}\n"
                    elif chunk_type == 'error':
                        yield f"\n❌ 错误: {chunk.get('message', '')}\n"
                    # 'finish' and any other chunk types produce no output

                # Keep the assistant reply so the frontend can render the history
                if full_content:
                    self.messages.append({'role': 'assistant', 'content': full_content})

            except Exception as e:
                self._log('error', f'AI 调用失败: {e}')
                yield f"[错误] {str(e)}"
                return
        else:
            # Non-streaming providers: blocking chat loop with tool execution
            self._ensure_system_prompt()
            tools = ToolDefinition.get_ctf_tools()

            for _ in range(self.MAX_TOOL_ITERATIONS):
                try:
                    response = self.provider.chat(self.messages, tools=tools, workspace_dir=str(self.workspace_dir))
                except Exception as e:
                    self._log('error', f'AI 调用失败: {e}')
                    yield f"[错误] {str(e)}"
                    return

                content = response.get('content', '')
                tool_calls = response.get('tool_calls')

                if content:
                    yield content

                assistant_msg = {'role': 'assistant', 'content': content}
                if tool_calls:
                    assistant_msg['tool_calls'] = tool_calls
                self.messages.append(assistant_msg)

                if not tool_calls:
                    break

                for tool_call in tool_calls:
                    tool_name = self._tool_call_name(tool_call)
                    # Announce before running: a tool (e.g. a docker build) may take a while
                    yield f"\n🔧 执行工具: {tool_name}\n"

                    tool_result = self._execute_tool_call(tool_call)

                    if len(tool_result) > 500:
                        yield f"✅ 结果: {tool_result[:500]}...\n"
                    else:
                        yield f"✅ 结果: {tool_result}\n"
            else:
                # Loop ran out without the model producing a tool-free reply
                self._log('warning', f'已达到最大工具调用轮数 ({self.MAX_TOOL_ITERATIONS})，停止继续调用')

        # Persist the conversation so the frontend can show it later
        self._save_conversation_history()

    def get_conversation_history(self) -> List[Dict[str, str]]:
        """Return user messages and non-empty assistant messages for display.

        System and tool messages, assistant messages with empty content, and
        any extra keys (such as ``tool_calls``) are dropped.
        """
        history = []
        for msg in self.messages:
            if msg.get('role') == 'user':
                history.append({'role': 'user', 'content': msg.get('content', '')})
            elif msg.get('role') == 'assistant' and msg.get('content'):
                history.append({'role': 'assistant', 'content': msg.get('content', '')})
        return history
