"""
Prompt 编译服务

提供自动编译和存储 Prompt 模板的功能
"""
import logging
import re
from typing import Dict, Any, Optional

from app.models.database.models import CategoryConfig
from app.services.prompt.generator import StagePromptGenerator


class PromptCompilerService:
    """Prompt compilation service.

    Compiles the per-difficulty prompt templates of a ``CategoryConfig``,
    stores them with ``set_compiled_prompts`` and commits the database
    session. ``replace_placeholders`` fills the runtime placeholders
    (``{{language}}``, ``{{knowledge_points}}``, ``{{scene}}``) of a template.
    """

    # Chinese difficulty label -> English key of the compiled-prompt dict.
    # Difficulties are compiled in this order.
    _DIFFICULTY_MAP = {
        '入门': 'beginner',
        '简单': 'easy',
        '中等': 'medium',
        '困难': 'hard',
    }

    # Patterns that strip legacy "阶段X末：输出XXX" ("end of stage X: output
    # XXX") instructions from new-format stage prompts. They are applied in
    # this order by _clean_legacy_stage_end.
    # Form 1: a whole line such as **阶段X末：输出XXX**
    _STAGE_END_BOLD_LINE = re.compile(
        r'^\s*\*\*?阶段\s*\d+\s*末[：:].*?\*\*?\s*$', re.MULTILINE)
    # Form 2: a whole line such as 阶段X末：输出XXX
    _STAGE_END_LINE = re.compile(
        r'^\s*阶段\s*\d+\s*末[：:].*?\s*$', re.MULTILINE)
    # Form 3: the instruction embedded in a line (removed up to end of line)
    _STAGE_END_INLINE = re.compile(r'阶段\s*\d+\s*末[：:][^\n]*')
    # Any remaining "阶段X末" mention; lines containing it are dropped.
    _STAGE_END_MARKER = re.compile(r'阶段\s*\d+\s*末')
    _EXCESS_BLANK_LINES = re.compile(r'\n{3,}')

    _logger = logging.getLogger(__name__)

    @staticmethod
    def compile_and_save_prompts(category: "CategoryConfig") -> bool:
        """Compile the prompt template of every difficulty and persist them.

        The per-difficulty configuration (``category.get_difficulties()``) is
        used when it is non-empty; otherwise the legacy ``get_stages()`` +
        ``get_difficulty_rules()`` configuration is compiled. The resulting
        ``{difficulty_key: prompt}`` dict (keys ``beginner``/``easy``/
        ``medium``/``hard``) is stored with ``set_compiled_prompts`` and the
        database session is committed.

        Args:
            category: CategoryConfig instance.

        Returns:
            True when the prompts were compiled and committed. False when the
            legacy configuration has no stages (nothing is saved), when the
            commit fails (the session is rolled back), or when any other error
            occurs (it is logged, not raised).
        """
        try:
            advanced = category.get_advanced_config() or {}
            # An empty system prompt is passed on as "no global rules".
            global_rules = advanced.get('system_prompt', '') or None

            difficulties = category.get_difficulties()
            if difficulties:
                # Preferred format: each difficulty has its own stages and rules.
                compiled_prompts = PromptCompilerService._compile_difficulties_format(
                    category, difficulties, global_rules)
            else:
                # Legacy format: shared stages + difficulty_rules (backward compatible).
                compiled_prompts = PromptCompilerService._compile_legacy_format(
                    category, global_rules)
                if compiled_prompts is None:
                    return False

            category.set_compiled_prompts(compiled_prompts)
            return PromptCompilerService._commit(category)
        except Exception as e:
            # Top-level boundary: report failure to the caller instead of raising.
            PromptCompilerService._logger.exception("编译 Prompt 失败: %s", e)
            return False

    @classmethod
    def _compile_difficulties_format(cls, category, difficulties: Dict[str, Any],
                                     global_rules: Optional[str]) -> Dict[str, Any]:
        """Compile one prompt per difficulty from the ``difficulties`` configuration.

        Difficulties that are missing or have no stages are skipped.
        """
        # The table lists every configured difficulty and does not depend on
        # the difficulty being compiled, so it is rendered once.
        difficulty_table = cls._build_difficulty_table(
            (name, (cfg or {}).get('rules') or {})
            for name, cfg in difficulties.items()
        )

        compiled_prompts = {}
        for difficulty_name, difficulty_key in cls._DIFFICULTY_MAP.items():
            if difficulty_name not in difficulties:
                continue
            diff_config = difficulties[difficulty_name] or {}
            stages = diff_config.get('stages', [])
            if not stages:
                continue
            rules = diff_config.get('rules') or {}

            context = cls._base_context(category, difficulty_name)
            context.update({
                'writeup_count': rules.get('writeup_count', 5),
                'max_knowledge': rules.get('max_count', 5),
                'depth_range': rules.get('depth_range', [1.5, 10.0]),
                'diff_rate': rules.get('diff_rate', 0.3),
                'difficulty_table': difficulty_table,
            })

            processed_stages = [cls._apply_stage_prompt(stage) for stage in stages]
            compiled_prompt = StagePromptGenerator.compile_full_prompt(
                category_id=category.id,
                stages=processed_stages,
                context=context,
                global_rules=global_rules,
            )
            if compiled_prompt:
                compiled_prompts[difficulty_key] = compiled_prompt
        return compiled_prompts

    @classmethod
    def _compile_legacy_format(cls, category,
                               global_rules: Optional[str]) -> Optional[Dict[str, Any]]:
        """Compile one prompt per difficulty from shared stages + difficulty_rules.

        Returns None when ``category.get_stages()`` yields no stages.
        """
        stages = category.get_stages()
        if not stages:
            return None
        difficulty_rules = category.get_difficulty_rules() or []

        # Shared by every difficulty; only set when rules exist.
        difficulty_table = cls._build_difficulty_table(
            (rule.get('name', ''), rule) for rule in difficulty_rules
        ) if difficulty_rules else None

        compiled_prompts = {}
        for difficulty_name, difficulty_key in cls._DIFFICULTY_MAP.items():
            difficulty_rule = next(
                (rule for rule in difficulty_rules if rule.get('name') == difficulty_name),
                None,
            )

            context = cls._base_context(category, difficulty_name)
            if difficulty_rule:
                context.update({
                    'writeup_count': difficulty_rule.get('writeup_count', 5),
                    'max_knowledge': difficulty_rule.get('max_count', 5),
                    'depth_range': difficulty_rule.get('depth_range', [1.5, 10.0]),
                    'diff_rate': difficulty_rule.get('diff_rate', 0.3),
                })
            if difficulty_table is not None:
                context['difficulty_table'] = difficulty_table

            merged_stages = [
                cls._merge_legacy_stage(stage, difficulty_name, difficulty_key)
                for stage in stages
            ]
            compiled_prompt = StagePromptGenerator.compile_full_prompt(
                category_id=category.id,
                stages=merged_stages,
                context=context,
                global_rules=global_rules,
            )
            if compiled_prompt:
                compiled_prompts[difficulty_key] = compiled_prompt
        return compiled_prompts

    @staticmethod
    def _base_context(category, difficulty_name: str) -> Dict[str, Any]:
        """Return the context entries shared by both configuration formats."""
        return {
            'category': category.name,
            'difficulty': difficulty_name,
            # Placeholder tokens; replace_placeholders() substitutes these
            # exact tokens at runtime.
            'language': '{{language}}',
            'knowledge_points': '{{knowledge_points}}',
            'scene': '{{scene}}',
            'available_languages': 'PHP, Python, Node.js, Go, Java',
            'flag_format': 'DASCTF{...}',
        }

    @classmethod
    def _build_difficulty_table(cls, entries) -> str:
        """Render ``(name, rules_dict)`` pairs as pipe-delimited table rows, one per line."""
        return '\n'.join(
            f"| {name} | {rules.get('max_count', 0)} 个 | "
            f"{rules.get('writeup_count', 0)} 篇 | "
            f"{cls._format_depth_range(rules.get('depth_range', []))} |"
            for name, rules in entries
        )

    @staticmethod
    def _format_depth_range(depth_range: Any) -> str:
        """Format a ``[low, high]`` depth range for the difficulty table.

        Returns "" for an empty/missing range. A value without two indexable
        items (e.g. a one-element list) is rendered with ``str()`` instead of
        raising and aborting the whole compilation.
        """
        if not depth_range:
            return ''
        try:
            low, high = depth_range[0], depth_range[1]
        except (IndexError, KeyError, TypeError):
            return str(depth_range)
        return f"[{low}, {high}]"

    @staticmethod
    def _append_user_extension(stage_copy: Dict[str, Any], content: Any) -> None:
        """Append ``content`` to ``stage_copy['user_extension']`` (blank-line separated)."""
        existing = stage_copy.get('user_extension')
        stage_copy['user_extension'] = existing + '\n\n' + content if existing else content

    @classmethod
    def _clean_legacy_stage_end(cls, prompt_content: str) -> str:
        """Remove legacy "阶段X末…" instructions and collapse runs of blank lines."""
        text = cls._STAGE_END_BOLD_LINE.sub('', prompt_content)
        text = cls._STAGE_END_LINE.sub('', text)
        text = cls._STAGE_END_INLINE.sub('', text)
        # Thorough pass: drop every line that still mentions "阶段X末".
        text = '\n'.join(
            line for line in text.split('\n') if not cls._STAGE_END_MARKER.search(line)
        )
        return cls._EXCESS_BLANK_LINES.sub('\n\n', text).strip()

    @classmethod
    def _apply_stage_prompt(cls, stage: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of a new-format stage with its cleaned ``prompt`` appended to ``user_extension``."""
        stage_copy = {**stage}
        prompt_content = stage.get('prompt')
        if prompt_content:
            cleaned = cls._clean_legacy_stage_end(prompt_content)
            # Only report when legacy markers were actually removed, not when
            # the prompt merely had surrounding whitespace.
            if cls._STAGE_END_MARKER.search(prompt_content):
                cls._logger.info(
                    "已清理阶段 %s 的旧格式内容（%d -> %d 字符）",
                    stage.get('id'), len(prompt_content), len(cleaned),
                )
            cls._append_user_extension(stage_copy, cleaned)
        return stage_copy

    @classmethod
    def _merge_legacy_stage(cls, stage: Dict[str, Any], difficulty_name: str,
                            difficulty_key: str) -> Dict[str, Any]:
        """Return a copy of a legacy stage with the given difficulty's overrides applied."""
        stage_copy = {**stage}

        # Per-difficulty overrides, keyed by the English difficulty key.
        difficulty_config = stage.get('difficulty_config')
        if difficulty_config and difficulty_key in difficulty_config:
            diff_config = difficulty_config[difficulty_key]
            if diff_config.get('name'):
                stage_copy['name'] = diff_config['name']
            # An empty description is a deliberate override; only None is ignored.
            if diff_config.get('description') is not None:
                stage_copy['description'] = diff_config['description']
            if diff_config.get('output_format'):
                stage_copy['output'] = diff_config['output_format']
            if diff_config.get('system_prompt'):
                stage_copy['system_prompt'] = diff_config['system_prompt']
            if diff_config.get('user_extension'):
                stage_copy['user_extension'] = diff_config['user_extension']

        # Legacy per-difficulty prompt text, keyed by the Chinese difficulty label.
        prompts = stage.get('prompts', {})
        if isinstance(prompts, dict):
            prompt_data = prompts.get(difficulty_name)
            if isinstance(prompt_data, dict):
                prompt_content = prompt_data.get('content', '')
                if prompt_content:
                    cls._append_user_extension(stage_copy, prompt_content)
        return stage_copy

    @staticmethod
    def _commit(category) -> bool:
        """Commit the session holding the compiled prompts; roll back and return False on failure."""
        from app.models.database.models import db
        try:
            db.session.commit()
        except Exception as e:
            PromptCompilerService._logger.exception(
                "❌ 保存编译后的 Prompt 到数据库失败: %s", e)
            db.session.rollback()
            return False
        PromptCompilerService._logger.info(
            "✅ 已保存编译后的 Prompt 到数据库（方向: %s）", category.id)
        return True

    @staticmethod
    def replace_placeholders(prompt_template: Optional[str], language: Optional[str],
                             vulnerabilities: Any, scene: Any) -> str:
        """Fill the runtime placeholders of a compiled prompt template.

        Replaces ``{{language}}``, ``{{knowledge_points}}`` and ``{{scene}}``,
        in that order.

        Args:
            prompt_template: compiled template; None or "" yields "".
            language: programming language; falsy values become "".
            vulnerabilities: list or tuple of vulnerability types (joined with
                ", "); any other truthy value is converted with ``str()``,
                falsy values become "".
            scene: string, dict (its ``'name'`` entry is used) or None.

        Returns:
            The template with the placeholders substituted.
        """
        if not prompt_template:
            return ''

        language_str = str(language) if language else ''

        if isinstance(vulnerabilities, (list, tuple)):
            vuln_str = ', '.join(str(v) for v in vulnerabilities)
        else:
            vuln_str = str(vulnerabilities) if vulnerabilities else ''

        if isinstance(scene, dict):
            scene_name = scene.get('name')
            scene_str = str(scene_name) if scene_name else ''
        else:
            scene_str = str(scene) if scene else ''

        return (
            prompt_template
            .replace('{{language}}', language_str)
            .replace('{{knowledge_points}}', vuln_str)
            .replace('{{scene}}', scene_str)
        )

