"""
阶段 Prompt 生成器

根据方向配置自动生成每个阶段的固定 Prompt 内容。
- 阶段 0（用户输入需求）：根据表单字段自动生成固定内容
- 其他阶段：系统指令为空，由用户通过扩展填充
- 全局规则：可配置的禁止事项、关键信息保持等
"""
from typing import Dict, Any, List, Optional


class StagePromptGenerator:
    """阶段 Prompt 生成器"""
    
    # 全局规则模板（可由用户扩展覆盖）
    GLOBAL_RULES_TEMPLATE = '''## 🚫 禁止事项

**禁止词汇**：
- "由于篇幅/时间/上下文限制" "直接进入" "快速实现" "跳过" "省略"

**禁止行为**：
- 跳过任何阶段（特别是质量检查阶段）
- 只输出标题而无实际内容
- 通过降低难度来解决技术问题

**✅ 你有无限的时间和空间，直接执行，不要解释**

---

## ⚠️ 关键信息保持

1. **题目名称**：一旦确定，后续所有文件必须使用相同名称
2. **利用链设计**：设计阶段的步骤必须在代码生成阶段完整实现
3. **关键 payload**：设计的 payload 必须与代码中的过滤逻辑兼容
4. **输出目录**：所有文件保存到 `output/{YYYYMMDD_HHMMSS}_{题目名称}/`

---'''
    
    # Web 方向的阶段固定 Prompt 模板
    # 只有阶段 0 有固定内容，其他阶段为空（由用户扩展填充）
    WEB_STAGE_TEMPLATES = {
        '0': {
            'title': '用户输入需求',
            'system_prompt': '''**选择顺序**：语言 → 难度 → 漏洞 → 场景

**第一步：选择语言**
可选：{{available_languages}}

**第二步：选择难度**
| 难度 | 允许漏洞数量 | writeup 总数 | 深度范围 |
|------|-------------|-------------|---------|
{{difficulty_table}}

**第三步：选择漏洞**（根据难度限制数量）

**第四步：选择场景**（如留言板、博客、OA系统）

**示例**：
```
语言：{{language}}
难度：{{difficulty}}（最多 {{max_knowledge}} 个漏洞）
漏洞：{{knowledge_points}}
场景：{{scene}}
```'''
        },
        '0.5': {
            'title': '漏洞主次分类',
            'system_prompt': ''  # 由用户扩展填充
        },
        '1': {
            'title': '学习与知识提取',
            'system_prompt': ''  # 由用户扩展填充
        },
        '2': {
            'title': '知识整理',
            'system_prompt': ''  # 由用户扩展填充
        },
        '3': {
            'title': '题目设计',
            'system_prompt': ''  # 由用户扩展填充
        },
        '4': {
            'title': '质量检查',
            'system_prompt': ''  # 由用户扩展填充
        },
        '5': {
            'title': '代码生成',
            'system_prompt': ''  # 由用户扩展填充
        },
        '6': {
            'title': 'Docker 构建与测试',
            'system_prompt': ''  # 由用户扩展填充
        },
        '7': {
            'title': 'exp 和 writeup',
            'system_prompt': ''  # 由用户扩展填充
        },
        '8': {
            'title': '成品输出',
            'system_prompt': ''  # 由用户扩展填充
        }
    }
    
    # Crypto 方向的阶段固定 Prompt 模板
    CRYPTO_STAGE_TEMPLATES = {
        '0': {
            'title': '用户输入需求',
            'system_prompt': '''**选择顺序**：难度 → 算法类型

**第一步：选择难度**
| 难度 | 允许知识点数量 | writeup 总数 |
|------|---------------|-------------|
{{difficulty_table}}

**第二步：选择算法类型**（根据难度限制数量）

**示例**：
```
难度：{{difficulty}}（最多 {{max_knowledge}} 个知识点）
算法：{{knowledge_points}}
```'''
        },
        '1': {'title': '学习与知识提取', 'system_prompt': ''},
        '2': {'title': '知识整理', 'system_prompt': ''},
        '3': {'title': '题目设计', 'system_prompt': ''},
        '4': {'title': '质量检查', 'system_prompt': ''},
        '5': {'title': '代码生成', 'system_prompt': ''},
        '6': {'title': '验证测试', 'system_prompt': ''},
        '7': {'title': 'exp 和 writeup', 'system_prompt': ''}
    }
    
    # Pwn 方向的阶段固定 Prompt 模板
    PWN_STAGE_TEMPLATES = {
        '0': {
            'title': '用户输入需求',
            'system_prompt': '''**选择顺序**：难度 → 漏洞类型 → 保护机制

**第一步：选择难度**
| 难度 | 允许漏洞数量 | writeup 总数 |
|------|-------------|-------------|
{{difficulty_table}}

**第二步：选择漏洞类型**（根据难度限制数量）

**第三步：选择保护机制**（NX, ASLR, Canary, PIE 等）

**示例**：
```
难度：{{difficulty}}（最多 {{max_knowledge}} 个漏洞）
漏洞：{{knowledge_points}}
保护：{{protections}}
```'''
        },
        '1': {'title': '学习与知识提取', 'system_prompt': ''},
        '2': {'title': '知识整理', 'system_prompt': ''},
        '3': {'title': '题目设计', 'system_prompt': ''},
        '4': {'title': '质量检查', 'system_prompt': ''},
        '5': {'title': '代码生成', 'system_prompt': ''},
        '6': {'title': 'Docker 构建与测试', 'system_prompt': ''},
        '7': {'title': 'exp 和 writeup', 'system_prompt': ''}
    }
    
    # Reverse 方向的阶段固定 Prompt 模板
    REVERSE_STAGE_TEMPLATES = {
        '0': {
            'title': '用户输入需求',
            'system_prompt': '''**选择顺序**：难度 → 逆向类型 → 目标平台

**第一步：选择难度**
| 难度 | 允许知识点数量 | writeup 总数 |
|------|---------------|-------------|
{{difficulty_table}}

**第二步：选择逆向类型**（根据难度限制数量）

**第三步：选择目标平台**（Windows, Linux, Android 等）

**示例**：
```
难度：{{difficulty}}（最多 {{max_knowledge}} 个知识点）
类型：{{knowledge_points}}
平台：{{platform}}
```'''
        },
        '1': {'title': '学习与知识提取', 'system_prompt': ''},
        '2': {'title': '知识整理', 'system_prompt': ''},
        '3': {'title': '题目设计', 'system_prompt': ''},
        '4': {'title': '质量检查', 'system_prompt': ''},
        '5': {'title': '代码生成', 'system_prompt': ''},
        '6': {'title': '验证测试', 'system_prompt': ''},
        '7': {'title': 'exp 和 writeup', 'system_prompt': ''}
    }
    
    # 方向到模板的映射
    CATEGORY_TEMPLATES = {
        'web': WEB_STAGE_TEMPLATES,
        'crypto': CRYPTO_STAGE_TEMPLATES,
        'pwn': PWN_STAGE_TEMPLATES,
        'reverse': REVERSE_STAGE_TEMPLATES
    }
    
    @classmethod
    def get_stage_templates(cls, category_id: str) -> Dict[str, Dict]:
        """获取指定方向的阶段模板"""
        return cls.CATEGORY_TEMPLATES.get(category_id, cls.WEB_STAGE_TEMPLATES)
    
    @classmethod
    def generate_stage_prompt(cls, category_id: str, stage_id: str, 
                              config: Dict[str, Any] = None) -> Dict[str, str]:
        """
        生成指定阶段的 Prompt
        
        Args:
            category_id: 方向 ID
            stage_id: 阶段 ID
            config: 方向配置（用于变量替换）
            
        Returns:
            包含 title 和 system_prompt 的字典
        """
        templates = cls.get_stage_templates(category_id)
        template = templates.get(stage_id, {
            'title': f'阶段 {stage_id}',
            'system_prompt': '请完成本阶段的任务。'
        })
        
        return {
            'title': template.get('title', f'阶段 {stage_id}'),
            'system_prompt': template.get('system_prompt', '')
        }
    
    @classmethod
    def generate_all_stages(cls, category_id: str, stages: List[Dict], 
                           config: Dict[str, Any] = None) -> List[Dict]:
        """
        为所有阶段生成 Prompt
        
        Args:
            category_id: 方向 ID
            stages: 阶段列表
            config: 方向配置
            
        Returns:
            包含 system_prompt 和 user_extension 的阶段列表
        """
        result = []
        templates = cls.get_stage_templates(category_id)
        
        for stage in stages:
            stage_id = stage.get('id', '')
            template = templates.get(stage_id, {})
            
            result.append({
                'id': stage_id,
                'name': stage.get('name', template.get('title', f'阶段 {stage_id}')),
                'output': stage.get('output', ''),
                'required': stage.get('required', True),
                'condition': stage.get('condition'),
                'skip_forbidden': stage.get('skip_forbidden', False),
                'system_prompt': template.get('system_prompt', f"完成「{stage.get('name', '')}」阶段的任务。"),
                'user_extension': stage.get('user_extension', '')
            })
        
        return result
    
    @classmethod
    def compile_full_prompt(cls, category_id: str, stages: List[Dict],
                           context: Dict[str, Any] = None,
                           global_rules: str = None) -> str:
        """
        编译完整的 Prompt（固定部分 + 用户扩展）
        
        Args:
            category_id: 方向 ID
            stages: 包含 user_extension 的阶段列表
            context: 变量上下文
            global_rules: 自定义全局规则（为空则使用默认）
            
        Returns:
            完整的 Prompt 字符串
        """
        context = context or {}
        lines = []
        
        # 添加头部
        category_name = context.get('category', category_id.upper())
        lines.append(f"# 🎯 CTF {category_name} 题目设计专家系统\n")
        lines.append("你是一位经验丰富的 CTF 题目设计专家，擅长从已有题目中学习并创新设计。\n")
        
        # 添加全局规则
        rules = global_rules if global_rules else cls.GLOBAL_RULES_TEMPLATE
        lines.append(rules)
        lines.append("")
        
        # 添加任务流程表格
        lines.append("## 🧩 任务流程\n")
        lines.append("| 阶段 | 名称 | 关键输出 |")
        lines.append("|------|------|---------|")
        for stage in stages:
            stage_id = stage.get('id', '')
            name = stage.get('name', '')
            output = stage.get('output', '')
            required = '⚠️ 不可跳过' if stage.get('skip_forbidden') else ''
            lines.append(f"| {stage_id} | {name} | {output} {required} |")
        lines.append("")
        lines.append("**每个阶段开始时必须输出**：`阶段X：阶段名称`\n")
        lines.append("")
        lines.append("**注意**：阶段编号从1开始（阶段1、阶段2、阶段3...），不是从0开始。")
        lines.append("---\n")
        
        # 添加各阶段详细内容
        for idx, stage in enumerate(stages):
            stage_id = stage.get('id', '')
            name = stage.get('name', '')
            system_prompt = stage.get('system_prompt', '')
            user_extension = stage.get('user_extension', '')
            
            # 阶段标题
            lines.append(f"## 阶段 {stage_id}：{name}\n")
            
            # 系统固定内容（变量替换）
            if system_prompt:
                rendered = cls._render_variables(system_prompt, context)
                lines.append(rendered)
                lines.append("")
            
            # 用户扩展内容
            if user_extension:
                lines.append(user_extension)
                lines.append("")
            
            # 如果系统指令和用户扩展都为空，显示提示
            if not system_prompt and not user_extension:
                lines.append("*（请在管理后台添加此阶段的详细指导）*\n")
            
            # 如果是最后一个阶段，添加固定的结束标记要求
            is_last_stage = (idx == len(stages) - 1)
            if is_last_stage:
                lines.append("")
                lines.append("## ⚠️ 完成标记要求")
                lines.append("")
                lines.append("**完成所有任务后，必须输出以下结束标记（这是强制要求）**：")
                lines.append("")
                lines.append("```")
                lines.append("## [CTF_GENERATION_COMPLETE]")
                lines.append("```")
                lines.append("")
                lines.append("**重要说明**：")
                lines.append("- 只有输出上述标记后，系统才会认为生成已完成")
                lines.append("- 请在所有文件生成完成后，立即输出此标记")
                lines.append("- 不要使用其他格式，必须严格按照上述格式输出")
                lines.append("")
            
            lines.append("---\n")
        
        return "\n".join(lines)
    
    @classmethod
    def _render_variables(cls, template: str, context: Dict[str, Any]) -> str:
        """替换模板中的变量"""
        import re
        
        def replace_var(match):
            var_name = match.group(1)
            value = context.get(var_name, f'{{{{{var_name}}}}}')
            if isinstance(value, list):
                return ', '.join(str(v) for v in value)
            return str(value)
        
        return re.sub(r'\{\{(\w+)\}\}', replace_var, template)
