# -*- coding: utf-8 -*-
"""
CTF 方向配置相关模型

包含方向配置、知识库、Prompt 模板等相关模型
"""

from .base import db, get_beijing_now
from .user import User, Role

class CategoryConfig(db.Model):
    """CTF 方向配置模型
    
    用于存储不同 CTF 方向（Web、Crypto、Reverse、Pwn、Misc）的配置信息，
    包括表单字段、生成阶段、Prompt 模板等。
    """
    __tablename__ = 'category_configs'

    id = db.Column(db.String(50), primary_key=True)  # web, crypto, reverse, pwn, misc
    name = db.Column(db.String(100), nullable=False)  # 显示名称
    icon = db.Column(db.String(50), default='folder')  # FontAwesome 图标名
    description = db.Column(db.Text)  # 方向描述
    enabled = db.Column(db.Boolean, default=False)  # 是否启用
    sort_order = db.Column(db.Integer, default=0)  # 排序顺序

    # 表单配置 (JSON)
    form_fields = db.Column(db.Text)  # 表单字段配置
    form_layout = db.Column(db.Text)  # 表单布局配置

    # 阶段配置 (JSON)
    stages = db.Column(db.Text)  # 生成阶段配置（旧格式，保留用于兼容）

    # 难度规则配置 (JSON)
    difficulty_rules = db.Column(db.Text)  # 难度规则（旧格式，保留用于兼容）

    # 难度独立配置 (JSON) - 新格式，每个难度有独立的 stages 和 rules
    difficulties = db.Column(db.Text)  # 格式: {"入门": {"stages": [...], "rules": {...}}, ...}

    # 输出配置 (JSON)
    output_config = db.Column(db.Text)  # 输出结构配置

    # UI 定制配置 (JSON)
    ui_config = db.Column(db.Text)  # UI 定制（颜色、文案等）

    # 高级设置 (JSON)
    advanced_config = db.Column(db.Text)  # 高级设置（超时、重试等）

    # 编译后的 Prompt 模板 (JSON，按难度分别存储)
    compiled_prompts = db.Column(db.Text)  # 存储编译后的 prompt 模板，格式: {"beginner": "...", "easy": "...", "medium": "...", "hard": "..."}

    # 文件路径配置
    prompt_template_path = db.Column(db.String(255))  # Prompt 模板路径
    knowledge_base_path = db.Column(db.String(255))  # 知识库路径
    knowledge_db_path = db.Column(db.String(255))  # 知识点数据库路径
    choice_script_path = db.Column(db.String(255))  # choice.py 脚本路径
    output_dir = db.Column(db.String(255))  # 题目生成输出目录路径

    # 时间戳
    created_at = db.Column(db.DateTime, default=get_beijing_now)
    updated_at = db.Column(db.DateTime, default=get_beijing_now, onupdate=get_beijing_now)

    # 关联（PromptTemplate 已移除，不再使用 prompt_templates 表）

    def to_dict(self, include_config=True):
        """转换为字典"""
        import json
        data = {
            'id': self.id,
            'name': self.name,
            'icon': self.icon,
            'description': self.description,
            'enabled': self.enabled,
            'sort_order': self.sort_order,
            'prompt_template_path': self.prompt_template_path,
            'knowledge_base_path': self.knowledge_base_path,
            'knowledge_db_path': self.knowledge_db_path,
            'choice_script_path': self.choice_script_path,
            'output_dir': self.output_dir,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None
        }

        if include_config:
            data['form_fields'] = self.get_form_fields()
            data['form_layout'] = self.get_form_layout()
            data['stages'] = self.get_stages()
            data['difficulty_rules'] = self.get_difficulty_rules()
            # 添加新格式的 difficulties
            data['difficulties'] = self.get_difficulties()
            data['output_config'] = self.get_output_config()
            data['ui_config'] = self.get_ui_config()
            data['advanced_config'] = self.get_advanced_config()

        return data

    def get_form_fields(self):
        """获取表单字段配置"""
        import json
        try:
            return json.loads(self.form_fields) if self.form_fields else []
        except:
            return []

    def set_form_fields(self, fields):
        """设置表单字段配置"""
        import json
        self.form_fields = json.dumps(fields, ensure_ascii=False)

    def get_form_layout(self):
        """获取表单布局配置"""
        import json
        try:
            return json.loads(self.form_layout) if self.form_layout else {}
        except:
            return {}

    def set_form_layout(self, layout):
        """设置表单布局配置"""
        import json
        self.form_layout = json.dumps(layout, ensure_ascii=False)

    def get_stages(self):
        """获取阶段配置"""
        import json
        try:
            return json.loads(self.stages) if self.stages else []
        except:
            return []

    def set_stages(self, stages):
        """设置阶段配置"""
        import json
        self.stages = json.dumps(stages, ensure_ascii=False)

    def get_difficulty_rules(self):
        """获取难度规则"""
        import json
        try:
            return json.loads(self.difficulty_rules) if self.difficulty_rules else []
        except:
            return []

    def set_difficulty_rules(self, rules):
        """设置难度规则"""
        import json
        self.difficulty_rules = json.dumps(rules, ensure_ascii=False)

    def get_output_config(self):
        """获取输出配置"""
        import json
        try:
            return json.loads(self.output_config) if self.output_config else {}
        except:
            return {}

    def set_output_config(self, config):
        """设置输出配置"""
        import json
        self.output_config = json.dumps(config, ensure_ascii=False)

    def get_ui_config(self):
        """获取 UI 配置"""
        import json
        try:
            return json.loads(self.ui_config) if self.ui_config else {}
        except:
            return {}

    def set_ui_config(self, config):
        """设置 UI 配置"""
        import json
        self.ui_config = json.dumps(config, ensure_ascii=False)

    def get_advanced_config(self):
        """获取高级配置"""
        import json
        try:
            return json.loads(self.advanced_config) if self.advanced_config else {}
        except:
            return {}

    def set_advanced_config(self, config):
        """设置高级配置"""
        import json
        self.advanced_config = json.dumps(config, ensure_ascii=False)

    def get_compiled_prompts(self):
        """获取编译后的 Prompt 模板"""
        import json
        try:
            return json.loads(self.compiled_prompts) if self.compiled_prompts else {}
        except:
            return {}

    def set_compiled_prompts(self, prompts_dict):
        """设置编译后的 Prompt 模板
        
        Args:
            prompts_dict: 字典，格式: {"beginner": "...", "easy": "...", "medium": "...", "hard": "..."}
        """
        import json
        self.compiled_prompts = json.dumps(prompts_dict, ensure_ascii=False) if prompts_dict else None

    def get_difficulties(self):
        """获取难度独立配置（新格式）
        
        Returns:
            dict: {"入门": {"stages": [...], "rules": {...}}, ...}
            如果 difficulties 字段存在，直接返回
            否则从旧格式 stages + difficulty_rules 自动转换（只读转换，不写回数据库）
        """
        import json
        try:
            # 优先返回新格式
            if self.difficulties:
                return json.loads(self.difficulties)
        except:
            pass
        
        # 如果新格式不存在，从旧格式转换
        return self._convert_old_to_new_format()
    
    def _convert_old_to_new_format(self):
        """从旧格式（stages + difficulty_rules）转换为新格式（difficulties）
        
        注意：这是只读转换，不会写回数据库
        """
        difficulties = {}
        stages = self.get_stages() or []
        rules = self.get_difficulty_rules() or []
        
        # 如果没有任何配置，返回空字典
        if not stages and not rules:
            return difficulties
        
        # 遍历难度规则，为每个难度创建配置
        for rule in rules:
            if not isinstance(rule, dict) or 'name' not in rule:
                continue
                
            diff_name = rule['name']
            # 提取规则（排除 name 字段）
            diff_rules = {k: v for k, v in rule.items() if k != 'name'}
            
            # 为该难度创建阶段列表
            diff_stages = []
            for stage in stages:
                if not isinstance(stage, dict):
                    continue
                    
                stage_copy = dict(stage)  # 深拷贝
                
                # 从 stage.prompts[diff_name] 提取 prompt
                if 'prompts' in stage_copy and isinstance(stage_copy['prompts'], dict):
                    if diff_name in stage_copy['prompts']:
                        prompt_data = stage_copy['prompts'][diff_name]
                        if isinstance(prompt_data, dict):
                            # 如果是字典，提取 content 字段作为 prompt
                            stage_copy['prompt'] = prompt_data.get('content', '')
                        elif isinstance(prompt_data, str):
                            # 如果是字符串，直接使用
                            stage_copy['prompt'] = prompt_data
                    # 移除 prompts 字段，因为新格式中 prompt 是直接字段
                    del stage_copy['prompts']
                
                # 从 difficulty_config 中提取当前难度的配置
                if 'difficulty_config' in stage_copy and isinstance(stage_copy['difficulty_config'], dict):
                    # 尝试中文难度名和英文键名
                    diff_config = stage_copy['difficulty_config'].get(diff_name) or \
                                  stage_copy['difficulty_config'].get('beginner') or \
                                  stage_copy['difficulty_config'].get('easy') or \
                                  stage_copy['difficulty_config'].get('medium') or \
                                  stage_copy['difficulty_config'].get('hard')
                    
                    if diff_config and isinstance(diff_config, dict):
                        # 合并配置
                        stage_copy.update({k: v for k, v in diff_config.items() if k != 'prompts'})
                        if 'prompt' in diff_config:
                            stage_copy['prompt'] = diff_config['prompt']
                    
                    # 移除 difficulty_config 字段
                    del stage_copy['difficulty_config']
                
                diff_stages.append(stage_copy)
            
            difficulties[diff_name] = {
                'stages': diff_stages,
                'rules': diff_rules
            }
        
        # 如果没有难度规则，但有阶段，创建一个默认配置
        if not difficulties and stages:
            # 使用默认难度名称
            default_difficulties = ['入门', '简单', '中等', '困难']
            for diff_name in default_difficulties:
                # 为每个默认难度创建配置
                diff_stages = []
                for stage in stages:
                    if not isinstance(stage, dict):
                        continue
                    stage_copy = dict(stage)
                    # 尝试从 prompts 中提取
                    if 'prompts' in stage_copy and isinstance(stage_copy['prompts'], dict):
                        if diff_name in stage_copy['prompts']:
                            prompt_data = stage_copy['prompts'][diff_name]
                            if isinstance(prompt_data, dict):
                                stage_copy['prompt'] = prompt_data.get('content', '')
                            elif isinstance(prompt_data, str):
                                stage_copy['prompt'] = prompt_data
                        del stage_copy['prompts']
                    diff_stages.append(stage_copy)
                
                if diff_stages:  # 只有当有阶段时才添加
                    difficulties[diff_name] = {
                        'stages': diff_stages,
                        'rules': {}
                    }
        
        return difficulties
    
    def set_difficulties(self, difficulties):
        """设置难度独立配置（新格式）
        
        Args:
            difficulties: dict, 格式: {"入门": {"stages": [...], "rules": {...}}, ...}
        """
        import json
        self.difficulties = json.dumps(difficulties, ensure_ascii=False) if difficulties else None
    
    def get_stages_by_difficulty(self, difficulty: str):
        """获取指定难度的阶段列表
        
        Args:
            difficulty: str, 难度名称（中文：入门/简单/中等/困难）
            
        Returns:
            list: 该难度的阶段列表，如果难度不存在则返回空列表
        """
        difficulties = self.get_difficulties()
        if difficulty in difficulties and isinstance(difficulties[difficulty], dict):
            return difficulties[difficulty].get('stages', [])
        return []
    
    def get_rules_by_difficulty(self, difficulty: str):
        """获取指定难度的规则
        
        Args:
            difficulty: str, 难度名称（中文：入门/简单/中等/困难）
            
        Returns:
            dict: 该难度的规则，如果难度不存在则返回空字典
        """
        difficulties = self.get_difficulties()
        if difficulty in difficulties and isinstance(difficulties[difficulty], dict):
            return difficulties[difficulty].get('rules', {})
        return {}

    @staticmethod
    def get_enabled_categories():
        """获取所有启用的方向"""
        return CategoryConfig.query.filter_by(enabled=True).order_by(CategoryConfig.sort_order).all()

    @staticmethod
    def get_all_categories():
        """获取所有方向"""
        return CategoryConfig.query.order_by(CategoryConfig.sort_order).all()


# PromptTemplate 模型已移除，不再使用 prompt_templates 表
# Prompt 模板现在存储在 CategoryConfig.compiled_prompts 字段中（JSON格式）

class CategoryAdmin(db.Model):
    """方向管理员模型
    
    用于分配不同用户管理不同方向的权限。
    """
    __tablename__ = 'category_admins'

    id = db.Column(db.Integer, primary_key=True)
    category_id = db.Column(db.String(50), db.ForeignKey('category_configs.id', ondelete='CASCADE'), nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id', ondelete='CASCADE'), nullable=False)
    role = db.Column(db.String(50), default='editor')  # admin, editor
    created_at = db.Column(db.DateTime, default=get_beijing_now)

    # 唯一约束
    __table_args__ = (
        db.UniqueConstraint('category_id', 'user_id', name='uq_category_user'),
    )

    # 关联
    category_config = db.relationship('CategoryConfig', backref=db.backref('admins', lazy='dynamic'))
    user = db.relationship('User', backref=db.backref('managed_categories', lazy='dynamic'))

    def to_dict(self):
        return {
            'id': self.id,
            'category_id': self.category_id,
            'category_name': self.category_config.name if self.category_config else None,
            'user_id': self.user_id,
            'username': self.user.username if self.user else None,
            'role': self.role,
            'created_at': self.created_at.isoformat() if self.created_at else None
        }

    @staticmethod
    def can_manage(user_id, category_id):
        """检查用户是否可以管理指定方向"""
        # 超级管理员可以管理所有方向
        user = User.query.get(user_id)
        if user and user.role == Role.ADMIN:
            return True
        
        # 检查是否为方向管理员
        admin = CategoryAdmin.query.filter_by(
            user_id=user_id,
            category_id=category_id
        ).first()
        return admin is not None

    @staticmethod
    def get_user_categories(user_id):
        """获取用户可管理的方向列表"""
        user = User.query.get(user_id)
        if user and user.role == Role.ADMIN:
            return CategoryConfig.get_all_categories()
        
        admin_records = CategoryAdmin.query.filter_by(user_id=user_id).all()
        category_ids = [a.category_id for a in admin_records]
        return CategoryConfig.query.filter(CategoryConfig.id.in_(category_ids)).all()