"""
CTF 题目生成工具模块

本模块提供题目生成过程中的工具函数和状态管理。

核心功能：
---------
1. 生成状态管理 (generation_statuses)
   - 按 task_id 索引，支持多任务并发
   - 步骤状态管理（阶段数量由配置决定）

2. 日志记录 (write_ai_log)
   - 记录 AI 对话到数据库
   - 关联到题目和用户
   - 支持步骤耗时统计

3. 结果持久化
   - 保存生成结果到数据库
   - 支持会话恢复
"""

from typing import Dict, Any, Optional
import os
import threading
import json
import datetime
import traceback
import uuid
import re
import time

# 导入阶段检测器（统一的阶段定义）
from app.services.ai.core.stage_detector import StageDetector

# 尝试导入数据库操作
try:
    from app.models.database.operations import (
        save_challenge_record,
        get_latest_challenge
    )
    HAS_DB_OPERATIONS = True
except ImportError as e:
    HAS_DB_OPERATIONS = False
    import logging
    logger = logging.getLogger(__name__)
    logger.warning(f"数据库操作模块导入失败: {e}")
    logger.warning("将使用文件方式存储数据")

# 全局变量，用于存储生成状态和输出（按 task_id 索引，支持多任务并发）
# 格式: {task_id: {generation_started, message, step_statuses, ...}}
generation_statuses: Dict[str, Dict[str, Any]] = {}

# 为了向后兼容，提供一个默认的 generation_status 访问接口
# 这个接口会返回当前最新的任务状态（用于旧代码的兼容）
def get_generation_status(task_id: Optional[str] = None) -> Dict[str, Any]:
    """获取生成状态
    
    Args:
        task_id: 任务ID，如果为None则返回最新的任务状态（向后兼容）
    
    Returns:
        生成状态字典
    """
    if task_id:
        # 如果指定了task_id，只返回该任务的状态，如果不存在则返回None的默认状态
        if task_id in generation_statuses:
            return generation_statuses[task_id]
        else:
            # 如果任务ID不存在，返回一个空状态字典（而不是最新任务的状态）
            # 这样可以避免错误的任务状态被使用
            return {
                "generation_started": False,
                "message": "任务状态不存在",
                "step_statuses": {},
                "current_step": 0,
                "completed": False,
                "key_info": None,
                "error": None,
                "log_file": None,
                "log_position": 0,
                "task_id": task_id
            }
    
    # 如果没有指定task_id，返回最新的任务状态（向后兼容，用于旧代码）
    if generation_statuses:
        # 返回最新的任务状态（按创建时间排序）
        latest_task_id = max(generation_statuses.keys(), key=lambda k: generation_statuses[k].get('created_at', 0))
        return generation_statuses[latest_task_id]
    
    # 如果没有任务，返回默认状态（不预设阶段数量，由实际配置决定）
    return {
        "generation_started": False,
        "message": "等待开始生成...",
        "step_statuses": {},  # 阶段状态由配置决定，不硬编码
        "current_step": 0,
        "completed": False,
        "key_info": None,
        "error": None,
        "log_file": None,
        "log_position": 0
    }

# 向后兼容：提供一个字典接口，访问时自动获取当前任务状态
class GenerationStatusProxy:
    """生成状态代理类，提供向后兼容的字典接口"""
    def get(self, key, default=None):
        status = get_generation_status()
        return status.get(key, default)
    
    def __setitem__(self, key, value):
        # 获取当前最新的任务ID
        if generation_statuses:
            latest_task_id = max(generation_statuses.keys(), key=lambda k: generation_statuses[k].get('created_at', 0))
            if latest_task_id in generation_statuses:
                generation_statuses[latest_task_id][key] = value
            else:
                # 如果没有任务，创建一个默认任务状态
                default_task_id = f"default-{int(time.time())}"
                if default_task_id not in generation_statuses:
                    generation_statuses[default_task_id] = get_generation_status()
                generation_statuses[default_task_id][key] = value
    
    def __getitem__(self, key):
        status = get_generation_status()
        return status[key]
    
    def __contains__(self, key):
        status = get_generation_status()
        return key in status

# 为了向后兼容，保持 generation_status 作为代理对象
generation_status = GenerationStatusProxy()

# 存储生成结果（按 task_id 索引，支持多任务并发）
# 格式: {task_id: {result_data}}
generation_output: Dict[str, Dict[str, Any]] = {}

# 存储最后一次生成结果，用于持久化存储
last_generation_result = {}

# 用于控制生成线程的锁
generation_lock = threading.Lock()

# 存储日志的锁
log_lock = threading.Lock()

# 终止生成的标志
generation_cancelled = False

def cancel_generation():
    """请求终止生成过程"""
    global generation_cancelled
    generation_cancelled = True

def reset_cancel_flag():
    """重置终止标志"""
    global generation_cancelled
    generation_cancelled = False

def is_generation_cancelled():
    """检查是否请求终止"""
    global generation_cancelled
    return generation_cancelled

# 存储生成结果的文件路径
GENERATION_RESULT_FILE = "data/last_generation_result.json"

# 是否使用数据库存储
USE_DATABASE_STORAGE = HAS_DB_OPERATIONS

# 当前生成的题目ID
current_challenge_id = None

# 当前用户ID
current_user_id = None

# 存储本次生成中记录的对话日志ID，用于后续关联到题目
conversation_log_ids = []

# 设置当前题目ID
def set_current_challenge_id(challenge_id):
    """设置当前题目ID，用于关联日志
    
    Args:
        challenge_id: 题目ID
    """
    global current_challenge_id
    current_challenge_id = challenge_id
    import logging
    logger = logging.getLogger(__name__)
    logger.debug(f"已设置当前题目ID: {challenge_id}")

# 设置当前用户ID
def set_current_user_id(user_id):
    """设置当前用户ID，用于关联日志
    
    Args:
        user_id: 用户ID
    """
    global current_user_id
    current_user_id = user_id
    print(f"已设置当前用户ID: {user_id}")

# 保存生成结果到文件
def save_generation_result(result: Dict):
    """保存生成结果到数据库，用于会话恢复
    
    Args:
        result: 生成结果字典
    
    Returns:
        bool: 是否成功保存
    """
    global current_challenge_id
    global current_user_id
    
    # 添加时间戳
    if not result.get('timestamp'):
        result['timestamp'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    
    # 保存到数据库
    if USE_DATABASE_STORAGE:
        try:
            # 使用全局变量中的用户ID
            user_id = current_user_id
            if user_id:
                print(f"使用全局变量中的用户ID: {user_id}，将关联到生成的题目")
            else:
                print("警告：未找到当前用户ID，题目将不会关联到用户")
            
            challenge_id = save_challenge_record(result, user_id)
            if challenge_id:
                current_challenge_id = challenge_id
                print(f"已将生成结果保存到数据库，ID: {challenge_id}")
                return True
            else:
                print("保存到数据库失败：未能获取有效的题目ID")
                return False
        except Exception as db_error:
            print(f"保存生成结果到数据库失败: {str(db_error)}")
            traceback.print_exc()
            return False
    else:
        print("数据库操作不可用，无法保存生成结果")
        return False

# 从文件加载生成结果
def load_generation_result():
    """从数据库加载生成结果，用于应用启动时恢复状态
    
    Returns:
        dict: 加载的生成结果，如果没有则返回None
    """
    global current_challenge_id
    
    # 从数据库加载最新的结果
    if USE_DATABASE_STORAGE:
        try:
            latest_challenge = get_latest_challenge()
            if latest_challenge:
                current_challenge_id = latest_challenge.get('id')
                print(f"成功从数据库加载最新题目（ID: {current_challenge_id}）")
                return latest_challenge
        except Exception as db_error:
            print(f"从数据库加载生成结果失败: {str(db_error)}")
            traceback.print_exc()
    
    print("数据库操作不可用或找不到题目记录")
    return None

# 格式化日志内容
def format_content_for_log(content, role):
    """格式化日志内容
    
    Args:
        content: 日志内容
        role: 角色
        
    Returns:
        str: 格式化后的内容
    """
    if not content:
        return "【空内容】"
        
    # 对于代码块，进行适当的缩进
    lines = content.split('\n')
    formatted_lines = []
    
    in_code_block = False
    for line in lines:
        # 检测代码块开始
        if line.strip().startswith("```"):
            in_code_block = not in_code_block
            formatted_lines.append(line)
        elif in_code_block:
            # 代码块内部，保持原样
            formatted_lines.append(line)
        else:
            # 正常文本，进行适当的修饰
            formatted_lines.append(line)
            
    return '\n'.join(formatted_lines)

def write_ai_log(step: int, role: str, content: str):
    """记录AI对话日志到数据库
    
    Args:
        step: 当前步骤
        role: 角色 (user/assistant/system)
        content: 消息内容
    """
    global conversation_log_ids
    global current_user_id
    
    with log_lock:
        # 生成唯一ID，如果还没有
        if not hasattr(write_ai_log, "log_id"):
            now = datetime.datetime.now()
            write_ai_log.log_id = now.strftime("%Y%m%d_%H%M%S")
            write_ai_log.step_start_time = {}
            write_ai_log.logged_messages = {}  # 用于跟踪已记录的消息
            print(f"创建新的日志ID: {write_ai_log.log_id}")
            # 清空先前的日志ID列表
            conversation_log_ids = []
        
        # 检查是否已经记录过这个消息
        message_key = f"{step}_{role}"
        
        # 如果没有logged_messages属性，先创建
        if not hasattr(write_ai_log, "logged_messages"):
            write_ai_log.logged_messages = {}
            
        # 检查是否记录过相同位置的消息
        if message_key in write_ai_log.logged_messages:
            # 内容完全相同，跳过
            if write_ai_log.logged_messages[message_key] == content:
                print(f"跳过已记录的消息: 步骤={step}, 角色={role}")
                return
            # 系统消息可以被替换，其他消息应该保持不变
            elif role != "system":
                print(f"发现不同内容的消息在相同位置: 步骤={step}, 角色={role}")
                return
        
        # 记录这条消息
        write_ai_log.logged_messages[message_key] = content
        
        # 创建日志条目
        timestamp = datetime.datetime.now().isoformat()
        
        # 记录每个步骤的开始时间
        if role == "user" and step not in write_ai_log.step_start_time:
            write_ai_log.step_start_time[step] = timestamp
        
        # 如果是assistant回复，计算步骤耗时
        step_duration = None
        if role == "assistant" and step in write_ai_log.step_start_time:
            start_time = datetime.datetime.fromisoformat(write_ai_log.step_start_time[step])
            end_time = datetime.datetime.fromisoformat(timestamp)
            step_duration = (end_time - start_time).total_seconds()
            
        # AI对话日志功能已禁用（save_ai_conversation_log 函数不存在）
        # 如需启用，请在 app/models/database/operations.py 中实现该函数
        pass

# get_generation_status 函数已在上面定义（第67行），这里不再重复定义

# 获取生成输出
def get_generation_output():
    """获取生成输出
    
    Returns:
        Dict: 生成输出字典
    """
    return generation_output

# 是否生成已完成
def is_generation_completed(task_id: Optional[str] = None):
    """检查生成是否已完成
    
    Args:
        task_id: 任务ID，如果为None则使用最新的任务（向后兼容）
    
    Returns:
        bool: 是否已完成
    """
    status = get_generation_status(task_id)
    return status.get("completed", False)

# 获取生成错误信息
def get_generation_error(task_id: Optional[str] = None):
    """获取生成错误信息
    
    Args:
        task_id: 任务ID，如果为None则使用最新的任务（向后兼容）
    
    Returns:
        str: 错误信息
    """
    status = get_generation_status(task_id)
    return status.get("error", None)

# 获取完整的对话历史
def get_full_conversation_history():
    """获取完整对话历史

    Returns:
        List: 对话历史列表
    """
    # AI对话日志功能已禁用（get_conversation_logs 函数不存在）
    # 如需启用，请在 app/models/database/operations.py 中实现该函数
    return []

def get_step_description(step: int) -> str:
    """获取 Augment 阶段描述（委托给 StageDetector）"""
    return StageDetector.get_stage_name(step)


def reset_log_id():
    """重置日志ID，用于开始新的生成过程时调用"""
    global current_challenge_id
    
    if hasattr(write_ai_log, "log_id"):
        delattr(write_ai_log, "log_id")
    if hasattr(write_ai_log, "step_start_time"):
        delattr(write_ai_log, "step_start_time")
    if hasattr(write_ai_log, "logged_messages"):
        delattr(write_ai_log, "logged_messages")
    
    # 重置当前题目ID为None，确保新日志不会关联到旧题目
    current_challenge_id = None
    
    print("已重置日志ID和题目ID")

def initialize_generation(task_id: Optional[str] = None, total_stages: int = 0):
    """初始化生成状态
    
    Args:
        task_id: 任务ID，如果为None则创建一个默认任务
        total_stages: 总阶段数，如果为0则不预设阶段状态（由配置决定）
    """
    with generation_lock:
        if not task_id:
            task_id = f"default-{int(time.time())}"
        
        generation_statuses[task_id] = {
            "current_step": 0,
            "message": "正在初始化...",
            "generation_started": True,
            "completed": False,
            "step_statuses": {} if total_stages == 0 else {i: "waiting" for i in range(total_stages)},
            "total_stages": total_stages,
            "key_info": None,
            "error": None,
            "log_file": None,
            "log_position": 0,
            "created_at": time.time(),
            "task_id": task_id
        }
        # 初始化该任务的输出（不清空其他任务的输出）
        if task_id not in generation_output:
            generation_output[task_id] = {}
        print(f"已初始化生成状态（任务ID: {task_id}, 阶段数: {total_stages if total_stages > 0 else '由配置决定'}）")

# 更新对话日志关联到题目
def update_conversation_logs_challenge(challenge_id):
    """更新所有本次生成的对话日志，关联到指定题目
    
    Args:
        challenge_id: 题目ID
    
    Returns:
        bool: 是否成功更新
    """
    global conversation_log_ids
    
    if not challenge_id:
        print("无效的题目ID，无法更新对话日志关联")
        return False
        
    if not conversation_log_ids:
        import logging
        logger = logging.getLogger(__name__)
        logger.debug("没有需要更新的对话日志")
        return True
        
    try:
        from app.models.database.models import db, AIConversationLog
        
        # 批量更新对话日志
        updated_count = 0
        for log_id in conversation_log_ids:
            try:
                log = AIConversationLog.query.get(log_id)
                if log:
                    log.challenge_id = challenge_id
                    updated_count += 1
            except Exception as e:
                print(f"更新对话日志 {log_id} 失败: {str(e)}")
                
        # 提交更改
        db.session.commit()
        
        print(f"成功更新 {updated_count}/{len(conversation_log_ids)} 条对话日志关联到题目 {challenge_id}")
        return True
    except Exception as e:
        print(f"更新对话日志失败: {str(e)}")
        traceback.print_exc()
        try:
            db.session.rollback()
        except:
            pass
        return False

# ... existing code ... 