"""
Prompt 模板编译器

将模板化的 Prompt 编译为最终的 AI 指导文件。
支持变量插值、条件渲染、循环渲染等功能。
"""
import re
import json
import os
from typing import Dict, Any, List, Optional
from datetime import datetime

from app.models.database.models import CategoryConfig


class PromptCompiler:
    """Prompt template compiler.

    Template syntax handled by ``_render_template``:

    * ``{{#each key}} ... {{/each}}`` repeats its body for every element of the
      list found at ``key``; inside the body ``this`` is the current element and
      ``index`` its 0-based position. Nested ``{{#each}}`` blocks are not supported.
    * ``{{#if condition}} ... {{/if}}`` keeps its body only when ``condition``
      holds. Blocks may be nested; blocks inside ``{{#each}}`` are evaluated once
      per element, so they may test ``this`` / ``index``.
    * ``{{name}}``, ``{{obj.attr}}``, ``{{items.0}}`` interpolate a value; names
      that do not resolve are left in the output verbatim.
    * ``{{name | filter: arg | filter ...}}`` applies a chain of filters
      (``join``, ``length``, ``upper``, ``lower``).
    * ``{{condition ? "yes" : "no"}}`` is an inline ternary.

    Conditions support ``not <cond>``, the comparisons ``>= <= != == > <``
    (the right-hand side is a literal; quote it to force a text comparison),
    and plain truthiness of a value.
    """

    # System variables (besides the category's form fields) a template may reference
    SYSTEM_VARIABLES = [
        'category', 'difficulty', 'language', 'knowledge_points',
        'scene', 'extra_requirements', 'writeup_count', 'timestamp',
        'max_knowledge', 'depth_range', 'diff_rate'
    ]

    # Names the renderer itself puts in scope (_build_context / _render_loops);
    # referencing them must not trigger an "undefined variable" warning.
    _CONTEXT_VARIABLES = ('stages', 'difficulty_rules', 'output', 'this', 'index')

    # An innermost {{#if}} block: its body may not contain another opening
    # "{{#if ", so repeated substitution resolves nested blocks inside-out.
    _IF_BLOCK_RE = re.compile(
        r'\{\{#if\s+([^}]+)\}\}((?:(?!\{\{#if\s).)*?)\{\{/if\}\}', re.DOTALL
    )
    _EACH_BLOCK_RE = re.compile(r'\{\{#each\s+([^}]+)\}\}(.*?)\{\{/each\}\}', re.DOTALL)
    # {{variable}} / {{object.property}}; block tags starting with # or / are excluded
    _VARIABLE_RE = re.compile(r'\{\{([^#/][^}]*)\}\}')
    # condition ? "true" : "false" -- each value uses matching quotes and may be empty
    _TERNARY_RE = re.compile(r'(.+?)\s*\?\s*(["\'])(.*?)\2\s*:\s*(["\'])(.*?)\4')
    _EXTRACT_VARIABLE_RE = re.compile(r'\{\{([^#/][^}|]*?)(?:\|[^}]*)?\}\}')
    # Stage headings such as "## 阶段 1：" or "# 阶段2.1:" (at least one '#' required)
    _STAGE_RE = re.compile(r'##?\s*阶段\s*(\d+(?:\.\d+)?)[：:]')

    # Checked in order: two-character operators must precede their one-char prefixes.
    _COMPARATORS = (
        ('>=', lambda a, b: a >= b),
        ('<=', lambda a, b: a <= b),
        ('!=', lambda a, b: a != b),
        ('==', lambda a, b: a == b),
        ('>', lambda a, b: a > b),
        ('<', lambda a, b: a < b),
    )

    _KNOWLEDGE_LABELS = {
        'web': '漏洞',
        'crypto': '算法',
        'reverse': '技术',
        'pwn': '漏洞',
        'misc': '技能'
    }
    _DIFFICULTY_NAMES = ('入门', '简单', '中等', '困难', '地狱')
    _DEFAULT_LANGUAGES = ('Python', 'PHP', 'Node.js', 'Java', 'Go')

    def __init__(self, base_path: Optional[str] = None):
        """
        Initialize the compiler.

        Args:
            base_path: Base directory of the Prompt templates; defaults to the
                ``ge10`` directory three levels above this file.
        """
        self.base_path = base_path or os.path.join(os.path.dirname(__file__), '../../../ge10')

    def compile(self, category_id: str, user_input: Optional[Dict[str, Any]] = None) -> str:
        """
        Compile a complete Prompt (deprecated).

        Args:
            category_id: Category ID.
            user_input: User-supplied parameters (optional, for previews).

        Raises:
            NotImplementedError: Always. The prompt_templates table is no longer
                used; call compile_from_template() with template_content, or
                PromptCompilerService.compile_and_save_prompts() instead.
        """
        raise NotImplementedError(
            'PromptCompiler.compile() 方法已废弃，prompt_templates 表已不再使用。'
            '请使用 compile_from_template() 方法并提供 template_content 参数，'
            '或使用 PromptCompilerService.compile_and_save_prompts() 方法。'
        )

    def compile_from_template(self, template_content: str, category_id: str,
                              user_input: Dict[str, Any] = None) -> str:
        """
        Compile a Prompt from raw template content (used for previews).

        Args:
            template_content: Template source text.
            category_id: Category ID, looked up via ``CategoryConfig.query``.
            user_input: User-supplied parameters merged into the render context.

        Returns:
            The rendered Prompt string.

        Raises:
            ValueError: If no category with ``category_id`` exists.
        """
        category = CategoryConfig.query.get(category_id)
        if not category:
            raise ValueError(f"方向 '{category_id}' 不存在")

        context = self._build_context(category, user_input)
        return self._render_template(template_content, context)

    def validate_template(self, template_content: str, category_config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Check that a template is consistent with a category configuration.

        Args:
            template_content: Template source text.
            category_config: Category configuration dict; the keys read are
                ``stages``, ``form_fields``, ``difficulty_rules`` and ``output_config``.

        Returns:
            ``{'valid', 'errors', 'warnings', 'stats'}``. ``valid`` is True when
            there are no errors; warnings do not affect it.
        """
        errors: List[str] = []
        warnings: List[str] = []

        template_variables = self._extract_variables(template_content)
        template_stages = self._extract_stages(template_content)

        # 1. The number of stages must match the configuration exactly.
        config_stages = category_config.get('stages') or []
        if len(config_stages) != len(template_stages):
            errors.append(
                f"阶段数量不一致: 配置 {len(config_stages)} 个, Prompt {len(template_stages)} 个"
            )

        # 2. Every referenced variable should be a form field, a system variable,
        #    or a name the renderer provides itself. Extracted names are already
        #    reduced to their root (``output.docker`` -> ``output``).
        known_names = set(self.SYSTEM_VARIABLES) | set(self._CONTEXT_VARIABLES)
        known_names.update(f.get('id') for f in category_config.get('form_fields') or [])
        for var in template_variables:
            if var not in known_names:
                warnings.append(f"Prompt 引用了未定义的变量: {var}")

        # 3. Difficulty names mentioned in the template should match the config.
        config_difficulties = [d.get('name') for d in category_config.get('difficulty_rules') or []]
        template_difficulties = self._extract_difficulties(template_content)
        if template_difficulties and set(config_difficulties) != set(template_difficulties):
            warnings.append(
                f"难度级别可能不一致: 配置 {config_difficulties}, Prompt 中发现 {template_difficulties}"
            )

        # 4. If Docker output is required, the template should mention Docker (any case).
        output_config = category_config.get('output_config') or {}
        if output_config.get('docker') and 'docker' not in template_content.lower():
            warnings.append("配置需要 Docker 输出但 Prompt 中未提及 Docker")

        return {
            'valid': len(errors) == 0,
            'errors': errors,
            'warnings': warnings,
            'stats': {
                'variables_count': len(template_variables),
                'stages_count': len(template_stages),
                'total_lines': len(template_content.split('\n')),
                'total_chars': len(template_content)
            }
        }

    def _build_context(self, category: 'CategoryConfig', user_input: Dict[str, Any] = None) -> Dict[str, Any]:
        """Build the render context from a category and optional user input.

        User input keys are merged on top of the category-derived values. When
        ``user_input['difficulty']`` names a configured difficulty rule, the rule's
        ``writeup_count``, ``max_count``, ``depth_range`` and ``diff_rate`` are
        exposed as ``writeup_count``, ``max_knowledge``, ``depth_range`` and
        ``diff_rate``, using the defaults below for missing entries.
        """
        context = {
            'category': {
                'id': category.id,
                'name': category.name,
                'icon': category.icon,
                'description': category.description,
                'knowledge_label': self._get_knowledge_label(category.id),
                'has_language': self._has_field(category, 'language'),
                'has_scene': self._has_field(category, 'scene'),
                'languages': self._get_languages(category),
                'no_comments': True  # comments are disallowed by default
            },
            'stages': category.get_stages(),
            'difficulty_rules': category.get_difficulty_rules(),
            'output': category.get_output_config(),
            'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S')
        }

        # Merge the user input
        if user_input:
            context.update(user_input)

            # Derive difficulty-dependent values from the matching rule
            difficulty = user_input.get('difficulty')
            if difficulty:
                rule = self._get_difficulty_rule(category, difficulty)
                if rule:
                    context['writeup_count'] = rule.get('writeup_count', 5)
                    context['max_knowledge'] = rule.get('max_count', 5)
                    context['depth_range'] = rule.get('depth_range', [1.5, 10.0])
                    context['diff_rate'] = rule.get('diff_rate', 0.3)

        return context

    def _render_template(self, template: str, context: Dict[str, Any]) -> str:
        """Render *template* against *context*.

        Loops are expanded first so that conditionals inside a loop body are
        evaluated per element (with ``this`` / ``index`` bound); the remaining
        top-level conditionals and variables are rendered afterwards.
        """
        result = self._render_loops(template, context)
        result = self._render_conditionals(result, context)
        return self._render_variables(result, context)

    def _render_conditionals(self, template: str, context: Dict[str, Any]) -> str:
        """Resolve every ``{{#if cond}} ... {{/if}}`` block, innermost first."""
        def replace_conditional(match):
            condition = match.group(1).strip()
            return match.group(2) if self._evaluate_condition(condition, context) else ''

        result = template
        # Each pass removes at least one tag pair, so the loop terminates.
        while True:
            result, replaced = self._IF_BLOCK_RE.subn(replace_conditional, result)
            if not replaced:
                return result

    def _render_loops(self, template: str, context: Dict[str, Any]) -> str:
        """Expand ``{{#each key}} ... {{/each}}`` blocks.

        A block whose ``key`` does not resolve to a non-empty list renders as
        ``''``. Each iteration renders its own conditionals and variables with
        ``this`` (the element) and ``index`` (0-based) added to the context.
        """
        def replace_loop(match):
            items = self._get_value(match.group(1).strip(), context)
            if not items or not isinstance(items, list):
                return ''

            body = match.group(2)
            rendered_parts = []
            for index, item in enumerate(items):
                loop_context = {**context, 'this': item, 'index': index}
                rendered = self._render_conditionals(body, loop_context)
                rendered_parts.append(self._render_variables(rendered, loop_context))
            return ''.join(rendered_parts)

        return self._EACH_BLOCK_RE.sub(replace_loop, template)

    def _render_variables(self, template: str, context: Dict[str, Any]) -> str:
        """Interpolate ``{{...}}`` expressions: filter chains, ternaries, or plain lookups.

        A plain lookup that resolves to ``None`` keeps its original placeholder.
        """
        def replace_variable(match):
            var_expr = match.group(1).strip()

            # Filter chain, e.g. {{items | join: ", " | upper}}; '|' inside quotes is literal
            segments = self._split_pipes(var_expr)
            if len(segments) > 1:
                value = self._get_value(segments[0].strip(), context)
                for filter_expr in segments[1:]:
                    value = self._apply_filter(value, filter_expr.strip())
                return value

            # Ternary, e.g. {{condition ? "yes" : "no"}}
            if '?' in var_expr and ':' in var_expr:
                return self._evaluate_ternary(var_expr, context)

            value = self._get_value(var_expr, context)
            if value is None:
                return match.group(0)  # keep the original placeholder
            return str(value)

        return self._VARIABLE_RE.sub(replace_variable, template)

    @staticmethod
    def _split_pipes(expr: str) -> List[str]:
        """Split *expr* on ``|`` characters that are not inside single/double quotes.

        Like ``str.split('|')``, empty segments are preserved.
        """
        segments: List[str] = []
        current: List[str] = []
        quote = None
        for ch in expr:
            if quote:
                if ch == quote:
                    quote = None
            elif ch in ('"', "'"):
                quote = ch
            elif ch == '|':
                segments.append(''.join(current))
                current = []
                continue
            current.append(ch)
        segments.append(''.join(current))
        return segments

    def _get_value(self, key: str, context: Dict[str, Any]) -> Any:
        """Look up a dotted path in *context*.

        Each segment is resolved as a dict key, a non-negative list index, or an
        attribute. Returns ``None`` for an empty key or any segment that cannot
        be resolved.
        """
        if not key:
            return None

        parts = key.split('.')
        value = context

        for part in parts:
            if isinstance(value, dict):
                value = value.get(part)
            elif isinstance(value, list) and part.isdigit():
                index = int(part)
                value = value[index] if index < len(value) else None
            elif hasattr(value, part):
                value = getattr(value, part)
            else:
                return None

            if value is None:
                return None

        return value

    @staticmethod
    def _is_quoted(text: str) -> bool:
        """Return True if *text* is wrapped in a matching pair of ' or " quotes."""
        return len(text) >= 2 and text[0] == text[-1] and text[0] in ('"', "'")

    def _evaluate_condition(self, condition: str, context: Dict[str, Any]) -> bool:
        """Evaluate an ``{{#if}}`` / ternary condition against *context*.

        Supports a ``not `` prefix, a single comparison whose left side is a
        context path and whose right side is a literal, and plain truthiness
        (empty str/list/dict and ``None`` are false). A quoted right-hand literal
        is compared as text against ``str(left)``; otherwise both sides are
        converted to numbers when possible. Incomparable operands yield False.
        """
        condition = condition.strip()

        # "not" prefix
        if condition.startswith('not '):
            return not self._evaluate_condition(condition[4:], context)

        # Comparison operators
        for op, compare in self._COMPARATORS:
            if op not in condition:
                continue

            left, right = condition.split(op, 1)
            left_val = self._get_value(left.strip(), context)
            right_val = right.strip()

            if self._is_quoted(right_val):
                # Quoted literal: textual comparison; a missing value compares as ''
                right_val = right_val[1:-1]
                left_val = '' if left_val is None else str(left_val)
            else:
                # Unquoted literal: try numeric comparison (a falsy left side counts as 0)
                try:
                    left_val = float(left_val) if left_val else 0
                    right_val = float(right_val)
                except (ValueError, TypeError):
                    pass

            try:
                return compare(left_val, right_val)
            except TypeError:
                # e.g. ordering a number against a non-numeric literal
                return False

        # Plain truthiness
        value = self._get_value(condition, context)
        if value is None:
            return False
        if isinstance(value, bool):
            return value
        if isinstance(value, (list, dict, str)):
            return len(value) > 0
        return bool(value)

    def _evaluate_ternary(self, expr: str, context: Dict[str, Any]) -> str:
        """Evaluate ``condition ? "true_value" : "false_value"``.

        Either value may be empty. If *expr* does not have this shape it is
        returned unchanged.
        """
        match = self._TERNARY_RE.match(expr)
        if not match:
            return expr

        condition = match.group(1).strip()
        true_val, false_val = match.group(3), match.group(5)
        return true_val if self._evaluate_condition(condition, context) else false_val

    def _apply_filter(self, value: Any, filter_expr: str) -> str:
        """Apply one filter (``name`` or ``name: arg``) to *value*.

        Known filters are ``join: sep`` (lists), ``length``, ``upper`` and
        ``lower``; any other filter just stringifies the value. Falsy values
        render as ``''``, except that ``length`` of an empty container is ``'0'``.
        """
        # Only the first ':' separates name and argument, so ``join: ":"`` works.
        name, _, raw_arg = filter_expr.partition(':')
        filter_name = name.strip()
        filter_arg = raw_arg.strip()
        filter_arg = filter_arg[1:-1] if self._is_quoted(filter_arg) else filter_arg.strip('"\'')

        if filter_name == 'length' and value is not None and hasattr(value, '__len__'):
            return str(len(value))
        if not value:
            return ''
        if filter_name == 'join' and isinstance(value, list):
            return filter_arg.join(str(v) for v in value)
        if filter_name == 'upper':
            return str(value).upper()
        if filter_name == 'lower':
            return str(value).lower()
        return str(value)

    def _extract_variables(self, template: str) -> List[str]:
        """Return the root names of variables referenced by ``{{...}}`` expressions.

        Block tags (``{{#...}}`` / ``{{/...}}``) and ternaries are skipped;
        ``a.b.c`` is reduced to ``a``. Names are returned in first-seen order.
        """
        names: Dict[str, None] = {}
        for raw in self._EXTRACT_VARIABLE_RE.findall(template):
            var = raw.strip()
            if '?' in var:  # ternary expression
                continue
            names.setdefault(var.split('.', 1)[0], None)
        return list(names)

    def _extract_stages(self, template: str) -> List[str]:
        """Return the distinct stage numbers of "#/## 阶段 X：" headings, in first-seen order.

        At least one leading '#' is required; both full-width and ASCII colons match.
        """
        return list(dict.fromkeys(self._STAGE_RE.findall(template)))

    def _extract_difficulties(self, template: str) -> List[str]:
        """Return the well-known difficulty names that occur anywhere in the template."""
        return [name for name in self._DIFFICULTY_NAMES if name in template]

    def _get_knowledge_label(self, category_id: str) -> str:
        """Return the knowledge-point label for a category (default '知识点')."""
        return self._KNOWLEDGE_LABELS.get(category_id, '知识点')

    def _has_field(self, category: 'CategoryConfig', field_id: str) -> bool:
        """Return True if the category's form fields include one with id *field_id*."""
        form_fields = category.get_form_fields()
        return any(f.get('id') == field_id for f in form_fields)

    def _get_languages(self, category: 'CategoryConfig') -> List[str]:
        """Return the category's language options, or a default list.

        Options may be plain values or dicts; for dicts the ``value`` (else
        ``label``, else the dict itself) is used.
        """
        for field in category.get_form_fields():
            if field.get('id') != 'language':
                continue
            options = field.get('options', [])
            if options:
                return [
                    (o.get('value') or o.get('label') or o) if isinstance(o, dict) else o
                    for o in options
                ]
        return list(self._DEFAULT_LANGUAGES)

    def _get_difficulty_rule(self, category: 'CategoryConfig', difficulty: str) -> Optional[Dict[str, Any]]:
        """Return the difficulty rule whose ``name`` equals *difficulty*, or None."""
        rules = category.get_difficulty_rules()
        for rule in rules:
            if rule.get('name') == difficulty:
                return rule
        return None

