from flask import session
from typing import Dict, List, Any, Optional

class SessionManager:
    """
    会话管理类：提供所有会话相关的操作
    
    包括：
    1. 会话状态管理
    2. 会话数据存取
    3. 数据库会话数据获取
    """
    
    # ==== 会话状态管理 ====
    @staticmethod
    def update_session_state(state):
        """
        更新会话状态
        
        状态流程：
        1. language_selected: 语言选择完成
        2. vulnerability_selected: 漏洞选择完成
        3. scene_selected: 场景选择完成
        4. difficulty_selected: 难度选择完成
        5. functions_allocated: 功能分配完成
        6. challenge_generated: 题目生成完成
        """
        valid_states = [
            'language_selected',
            'difficulty_selected',
            'vulnerability_selected',
            'scene_selected',
            'extra_selected',
            'functions_allocated',
            'challenge_generated'
        ]
        
        if state not in valid_states:
            raise ValueError(f'无效的状态: {state}')
        
        session['current_state'] = state

    @staticmethod
    def get_session_state():
        """获取当前会话状态"""
        return session.get('current_state', None)

    @staticmethod
    def clear_session_state():
        """清除会话状态"""
        if 'current_state' in session:
            del session['current_state'] 
    
    # ==== 会话数据存取 ====
    @staticmethod
    def get_selected_language() -> Optional[str]:
        """获取选择的编程语言"""
        return session.get('language')
    
    @staticmethod
    def set_selected_language(language: str) -> None:
        """设置选择的编程语言"""
        session['language'] = language
        
    @staticmethod
    def get_selected_vulnerabilities() -> List[str]:
        """获取选择的漏洞列表"""
        return session.get('vulnerabilities', [])

    @staticmethod
    def get_selected_scene() -> Optional[Dict]:
        """获取选择的场景"""
        return session.get('scene')

    @staticmethod
    def set_current_step(step: str) -> None:
        """设置当前步骤"""
        session['current_step'] = step
        
    @staticmethod
    def clear_session() -> None:
        """清除所有会话数据"""
        session.clear()

    @staticmethod
    def get_selected_vulnerability_names() -> List[str]:
        """获取选择的漏洞名称列表"""
        return session.get('vulnerability_names', [])
    
    @staticmethod
    def set_selected_vulnerability_names(names: List[str]) -> None:
        """设置选择的漏洞名称列表"""
        session['vulnerability_names'] = names
        
    @staticmethod
    def get_selected_difficulty() -> Optional[str]:
        """获取选择的难度级别"""
        return session.get('difficulty')
    
    @staticmethod
    def set_selected_difficulty(difficulty: str) -> None:
        """设置选择的难度级别"""
        session['difficulty'] = difficulty

    @staticmethod
    def get_extra_requirements() -> Optional[str]:
        """获取额外要求"""
        return session.get('extra_requirements')
    
    @staticmethod
    def set_extra_requirements(requirements: str) -> None:
        """设置额外要求"""
        session['extra_requirements'] = requirements


# 为了保持向后兼容性的别名
class SessionData(SessionManager):
    """兼容层：保持原有的SessionData类对外接口"""
    pass

# 单独函数，保持向后兼容
def update_session_state(state):
    """更新会话状态（向后兼容函数）"""
    return SessionManager.update_session_state(state)

def get_session_state():
    """获取当前会话状态（向后兼容函数）"""
    return SessionManager.get_session_state()

def clear_session_state():
    """清除会话状态（向后兼容函数）"""
    return SessionManager.clear_session_state() 