
from flask import Blueprint, render_template, jsonify, redirect, url_for, g, current_app, flash
from .utils import generation_status, generation_output, reset_log_id, load_generation_result, set_current_challenge_id, update_conversation_logs_challenge, current_challenge_id, cancel_generation, reset_cancel_flag, is_generation_cancelled
import datetime
import traceback
import threading
from app.services.auth.decorators import login_required

# 创建蓝图
bp = Blueprint('generate', __name__)

# 标记生成函数是否正在运行的标志（保留占位，当前流程允许多任务并发）
generate_challenge_running = False

@bp.route('/auto-start', methods=['GET'])
@login_required
def auto_start():
    """自动启动页面：使用 JavaScript 自动调用 API 并跳转"""
    from flask import session, render_template
    
    # 检查 session 中是否有生成所需的数据
    form_data = session.get('form_data', {})
    category_id = session.get('category_id', 'web')
    
    if not form_data or not category_id:
        from flask import flash, redirect, url_for
        flash('请先完成题目配置', 'warning')
        return redirect(url_for('wizard.select_category'))
    
    # 渲染一个中间页面，自动调用 start_generate API
    return render_template('pages/generate/wizard/auto_start.html')


@bp.route('/start_generate', methods=['POST'])
@login_required
def start_generate():
    """开始生成题目"""
    import logging
    logger = logging.getLogger(__name__)
    logger.info("开始生成题目")

    # 不再限制只能运行一个任务，支持多任务并发
    # 每个任务都有独立的 task_id，可以同时运行多个任务

    try:
        # 1. 验证用户 AI 配置
        from app.services.generator.validator import GenerationValidator
        
        user_id = g.user.id if hasattr(g, 'user') and g.user else None
        if not user_id:
            return jsonify({
                "status": "error",
                "message": "用户未登录"
            })

        is_valid, error_msg = GenerationValidator.validate_ai_config(user_id)
        if not is_valid:
            return jsonify({
                "status": "error",
                "message": error_msg
            })

        # 2. 获取方向和表单数据（需要先获取，因为 AI 配置选择在 form_data 中）
        from flask import session
        from app.services.generator.extractor import FormDataExtractor
        
        category_id = session.get('category_id', 'web')
        form_data = session.get('form_data', {})

        # 确定使用哪种 AI 模式
        from app.models.database import AIProviderConfig, AIProviderType
        cli_providers = AIProviderType.CLI_PROVIDERS
        
        ai_mode = 'augment'  # 默认使用 Augment CLI
        selected_ai_config = None
        
        # 检查用户是否在向导页面选择了特定的 AI 配置
        ai_config_id = form_data.get('ai_config_id')
        logger.info(f"从 form_data 获取的 ai_config_id: {ai_config_id} (类型: {type(ai_config_id)})")
        logger.info(f"form_data 中的键: {list(form_data.keys())}")
        if ai_config_id:
            try:
                ai_config_id = int(ai_config_id)
                logger.info(f"尝试加载 AI 配置 ID: {ai_config_id}")
                selected_ai_config = AIProviderConfig.query.get(ai_config_id)
                if selected_ai_config:
                    logger.info(f"找到 AI 配置: {selected_ai_config.name}, provider_type: {selected_ai_config.provider_type}, user_id: {selected_ai_config.user_id}")
                    # 验证配置是否属于当前用户或是系统配置
                    if selected_ai_config.user_id == user_id or selected_ai_config.user_id is None:
                        if selected_ai_config.provider_type in cli_providers:
                            ai_mode = selected_ai_config.provider_type
                            logger.info(f"使用用户选择的 CLI 模式: {selected_ai_config.provider_type}")
                        else:
                            ai_mode = 'api'
                            logger.info(f"使用用户选择的 API 模式: {selected_ai_config.provider_type} - {selected_ai_config.model}")
                    else:
                        logger.warning(f"用户选择的 AI 配置不属于当前用户: {ai_config_id} (配置用户ID: {selected_ai_config.user_id}, 当前用户ID: {user_id})")
                        selected_ai_config = None
                else:
                    logger.warning(f"未找到 AI 配置 ID: {ai_config_id}")
            except (ValueError, TypeError) as e:
                logger.warning(f"无效的 AI 配置 ID: {ai_config_id}, 错误: {e}")
        else:
            logger.info("form_data 中未找到 ai_config_id，将使用默认配置")
        
        # 如果没有选择特定配置，使用默认配置
        if not selected_ai_config:
            try:
                selected_ai_config = AIProviderConfig.get_effective_config(user_id)
                if selected_ai_config:
                    if selected_ai_config.provider_type in cli_providers:
                        ai_mode = selected_ai_config.provider_type
                        logger.info(f"使用默认 CLI 模式: {selected_ai_config.provider_type}")
                    else:
                        ai_mode = 'api'
                        logger.info(f"使用默认 API 模式: {selected_ai_config.provider_type} - {selected_ai_config.model}")
            except Exception as e:
                logger.debug(f"获取 AI 配置失败，回退到 Augment 模式: {e}")
        
        # 验证表单数据
        is_valid, error_msg, missing_fields = GenerationValidator.validate_form_data(category_id, form_data)
        if not is_valid:
            return jsonify({
                "status": "error",
                "message": error_msg
            }), 400
        
        # 提取生成参数
        params = FormDataExtractor.extract_generation_params(category_id, form_data)
        language = params.get('language')
        vulnerabilities = params.get('vulnerabilities')
        scene = params.get('scene')
        difficulty = params.get('difficulty')
        extra_requirements = params.get('extra_requirements', '')
        
        # 调试日志：记录提取的参数
        logger.info(f"提取的参数 - language: {language}, difficulty: {difficulty}, vulnerabilities: {vulnerabilities}, scene: {scene}")

        # 3. 验证参数（基于字段配置动态验证）
        is_valid, error_msg = GenerationValidator.validate_parameters(
            category_id, language, vulnerabilities, scene, difficulty, form_data
        )
        if not is_valid:
            return jsonify({
                "status": "error",
                "message": error_msg or "参数验证失败"
            })

        logger.info(f"所有参数验证通过 (方向: {category_id})")
        
        # 清理场景数据：移除 sub_scenes 字段（如果存在）
        if isinstance(scene, dict) and 'sub_scenes' in scene:
            scene = scene.copy()  # 创建副本避免修改原始 session
            del scene['sub_scenes']
            logger.warning("检测到场景数据中包含 sub_scenes 字段，已自动移除")

        logger.info(f"方向: {category_id}")
        logger.info(f"语言: {language}")
        logger.info(f"漏洞/知识点: {vulnerabilities}")
        logger.info(f"场景: {scene}")
        logger.info(f"难度: {difficulty}")
        logger.info(f"额外要求: {extra_requirements[:50]}..." if len(extra_requirements) > 50 else f"额外要求: {extra_requirements}")
        
        # 用户ID已在前面获取
        logger.info(f"当前用户ID: {user_id}，将传递给生成线程")

        # 获取方向配置（用于获取方向名称等）
        from app.models.database.models import CategoryConfig
        category = CategoryConfig.query.get(category_id)
        category_name = category.name if category else category_id.upper()

        # 获取应用对象（在主线程中获取，然后传递给生成线程）
        app = current_app._get_current_object()

        # 重置终止标志
        reset_cancel_flag()
        
        # 创建任务记录（使用线程安全的方式）
        from .tasks import create_task_id, update_task
        task_id = create_task_id()
        
        # 注意：不再使用全局 generation_status，而是直接操作 generation_statuses[task_id]
        # 这样可以避免多任务冲突
        
        # 获取任务名称（使用短UUID）
        import uuid
        task_name = str(uuid.uuid4())[:8].upper()
        
        # 初始化任务状态（使用线程安全的 update_task，但先直接创建基础结构）
        from .tasks import tasks_status, _tasks_lock
        with _tasks_lock:
            tasks_status[task_id] = {
                'user_id': user_id,
                'category_id': category_id,
                'name': task_name,
                'difficulty': difficulty,  # 保存难度信息
                'status': 'running',
                'progress': 0,
                'created_at': datetime.datetime.now(),
                'updated_at': datetime.datetime.now(),
                'form_data': form_data,
                'log_file': None,
                'challenge_id': None
            }
        
        logger.info(f'创建新任务: {task_id} ({task_name})，用户: {user_id}')
        
        # 初始化任务状态到 generation_statuses（按 task_id 索引，支持多任务并发）
        from .utils import generation_statuses, generation_lock
        import time
        with generation_lock:
            generation_statuses[task_id] = {
                "generation_started": True,
                "message": "任务已开始，正在生成中...",
                "step_statuses": {},
                "current_step": 0,
                "completed": False,
                "key_info": None,
                "error": None,
                "log_file": None,
                "log_position": 0,
                "created_at": time.time(),
                "task_id": task_id,
                "ai_mode": ai_mode
            }

        # 初始化阶段总数（从数据库配置获取）
        try:
            if category:
                stages = category.get_stages()
                total_stages = len(stages) if stages else 0
                
                # 构建阶段名称映射（用于阶段检测）
                stage_names = {}
                if stages:
                    for i, stage in enumerate(stages):
                        if isinstance(stage, dict):
                            stage_names[i] = stage.get('name', f'阶段 {i}')
                        else:
                            stage_names[i] = str(stage)
                
                with generation_lock:
                    generation_statuses[task_id]["total_stages"] = total_stages
                    generation_statuses[task_id]["stage_names"] = stage_names
                    # 初始化 step_statuses
                    if total_stages > 0:
                        generation_statuses[task_id]["step_statuses"] = {i: "waiting" for i in range(total_stages)}
                    else:
                        generation_statuses[task_id]["step_statuses"] = {}
                logger.info(f'初始化 {total_stages} 个阶段: {list(stage_names.values())}')
            else:
                with generation_lock:
                    generation_statuses[task_id]["total_stages"] = 0
                    generation_statuses[task_id]["step_statuses"] = {}
        except Exception as e:
            logger.warning(f'初始化阶段总数失败: {e}')
            with generation_lock:
                generation_statuses[task_id]["total_stages"] = 0
                generation_statuses[task_id]["step_statuses"] = {}
        
        # 启动生成进程，传递用户ID、应用对象、AI模式和任务ID
        thread = threading.Thread(target=generate_challenge, args=(language, vulnerabilities, scene, difficulty, user_id, app, ai_mode, extra_requirements, category_id, task_id, form_data))
        thread.daemon = True
        thread.start()
        
        return jsonify({
            "status": "success",
            "message": "生成进程已启动",
            "task_id": task_id,
            "redirect_url": f"/generate/tasks/{task_id}"
        })
        
    except Exception as e:
        logger.error(f"启动生成进程时出错: {str(e)}")
        traceback.print_exc()
        
        generation_status["error"] = str(e)
        
        return jsonify({
            "status": "error",
            "message": f"启动生成进程失败: {str(e)}"
        })

def _extract_challenge_info_from_writeup(writeup, dir_name):
    """从 writeup 中提取题目信息

    Args:
        writeup: writeup 内容
        dir_name: 输出目录名称

    Returns:
        题目信息字典
    """
    import re

    challenge_info = {
        'name': '未命名题目',
        'description': '',
        'difficulty': '未设置',
        'category': 'WEB',
        'estimated_time': '1-2小时'
    }

    # 从 writeup 的题目信息表格中提取信息
    # 支持两种格式：
    # 1. 4列格式：| 题目名 | 题目描述 | 类型 | 预计解题时间 |
    # 2. 6列格式：| 题目名 | 题目描述 | 类型 | 预计解题时间 | 难度 | 是否提供源码 |
    
    # 先尝试匹配 6 列格式（包含难度列）
    table_match_6 = re.search(
        r'\|\s*题目名\s*\|\s*题目描述\s*\|\s*类型\s*\|\s*预计解题时间\s*\|\s*难度\s*\|\s*是否提供源码\s*\|.*?\n'  # 表头（6列）
        r'\|[:\-\s\|]+\|.*?\n'  # 分隔线（匹配多个列）
        r'\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|\s*[^|]+\s*\|',  # 数据行（6列）
        writeup,
        re.DOTALL
    )
    
    if table_match_6:
        challenge_info['name'] = table_match_6.group(1).strip()
        challenge_info['description'] = table_match_6.group(2).strip()
        challenge_info['category'] = table_match_6.group(3).strip()
        challenge_info['estimated_time'] = table_match_6.group(4).strip()
        challenge_info['difficulty'] = table_match_6.group(5).strip()
        import logging
        logger = logging.getLogger(__name__)
        logger.info(f"成功从6列表格提取题目信息:")
        logger.info(f"   - 题目名: {challenge_info['name']}")
        logger.info(f"   - 题目描述: {challenge_info['description']}")
        logger.info(f"   - 类型: {challenge_info['category']}")
        logger.info(f"   - 预计时间: {challenge_info['estimated_time']}")
        logger.info(f"   - 难度: {challenge_info['difficulty']}")
    else:
        # 尝试匹配 4 列格式
        table_match = re.search(
            r'\|\s*题目名\s*\|\s*题目描述\s*\|\s*类型\s*\|\s*预计解题时间\s*\|.*?\n'  # 表头（4列）
            r'\|[:\-\s]+\|[:\-\s]+\|[:\-\s]+\|[:\-\s]+\|.*?\n'  # 分隔线（4列）
            r'\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|',  # 数据行（4列）
            writeup,
            re.DOTALL
        )

        if table_match:
            challenge_info['name'] = table_match.group(1).strip()
            challenge_info['description'] = table_match.group(2).strip()
            challenge_info['category'] = table_match.group(3).strip()
            challenge_info['estimated_time'] = table_match.group(4).strip()
            import logging
            logger = logging.getLogger(__name__)
            logger.info(f"成功从4列表格提取题目信息:")
            logger.info(f"   - 题目名: {challenge_info['name']}")
            logger.info(f"   - 题目描述: {challenge_info['description']}")
            logger.info(f"   - 类型: {challenge_info['category']}")
            logger.info(f"   - 预计时间: {challenge_info['estimated_time']}")
        else:
            import logging
            logger = logging.getLogger(__name__)
            logger.warning("未能从新格式表格提取题目信息，尝试旧格式...")
            # 旧格式兼容：| 题目名 | 类型 | 难度 |
            old_table_match = re.search(
                r'\|\s*题目名\s*\|\s*类型\s*\|\s*难度\s*\|.*?\n'  # 表头
                r'\|[:\-\s]+\|[:\-\s]+\|[:\-\s]+\|.*?\n'  # 分隔线
                r'\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|\s*([^|]+)\s*\|',  # 数据行
                writeup,
                re.DOTALL
            )
        if old_table_match:
            challenge_info['name'] = old_table_match.group(1).strip()
            challenge_info['category'] = old_table_match.group(2).strip()
            challenge_info['difficulty'] = old_table_match.group(3).strip()
            logger.info(f"从旧格式表格提取题目信息:")
            logger.info(f"   - 题目名: {challenge_info['name']}")
            logger.info(f"   - 类型: {challenge_info['category']}")
            logger.info(f"   - 难度: {challenge_info['difficulty']}")
        else:
            # 尝试从 writeup 标题提取（格式：# 题目名 - Writeup 或 # 题目名）
            title_match = re.search(r'^#\s+([^#\n]+?)(?:\s*-\s*[Ww]riteup)?\s*$', writeup, re.MULTILINE)
            if title_match:
                challenge_info['name'] = title_match.group(1).strip()
                logger.info(f"从标题提取题目名: {challenge_info['name']}")
            else:
                # 如果都失败，从目录名提取题目名称
                logger.warning("表格提取失败，从目录名提取题目名称...")
                match = re.search(r'\d{8}_\d{6}_(.+)', dir_name)
                if match:
                    challenge_info['name'] = match.group(1)
                    logger.info(f"   - 从目录名提取题目名: {challenge_info['name']}")

    return challenge_info

def generate_challenge(language, vulnerabilities, scene, difficulty, user_id=None, app=None, ai_mode='augment', extra_requirements='', category_id='web', task_id=None, form_data=None):
    """生成题目的主函数，在单独的线程中运行（已重构为使用服务层）

    Args:
        language: 编程语言
        vulnerabilities: 漏洞列表
        scene: 场景
        difficulty: 难度
        user_id: 用户ID，从主线程中传递过来的
        app: Flask应用对象，从主线程中传递过来的
        ai_mode: AI 模式，'augment' 或 'api'
        extra_requirements: 用户额外要求
        category_id: 方向ID
        task_id: 任务ID
        form_data: 完整的表单数据
    """
    global generation_status
    
    # 更新任务状态（使用线程安全的方式）
    if task_id:
        from .tasks import update_task, get_task
        existing_task = get_task(task_id)
        if existing_task:
            update_task(task_id, status='running')
            
            # 定期同步日志文件路径
            def sync_log_file():
                import time
                from .utils import get_generation_status
                max_iterations = 1800
                iteration = 0
                while iteration < max_iterations:
                    try:
                        task = get_task(task_id)
                        if task:
                            task_status = get_generation_status(task_id)
                            if task_status and task_status.get('task_id') == task_id:
                                if task_status.get('log_file'):
                                    update_task(task_id, log_file=task_status['log_file'])
                                if 'progress' in task_status:
                                    update_task(task_id, progress=task_status.get('progress', 0))
                        task = get_task(task_id)
                        if task and task.get('status') in ['completed', 'failed', 'cancelled']:
                            break
                        time.sleep(1)
                        iteration += 1
                    except Exception as e:
                        import logging
                        logger = logging.getLogger(__name__)
                        logger.warning(f"同步日志文件路径失败: {e}")
                        break
            
            sync_thread = threading.Thread(target=sync_log_file, daemon=True)
            sync_thread.start()

    # 如果没有传递应用对象，尝试获取
    if app is None:
        from flask import current_app
        app = current_app._get_current_object()

    with app.app_context():
        import logging
        logger = logging.getLogger(__name__)
        
        try:
            # 使用服务层生成题目
            from app.services.generator.service import ChallengeGeneratorService
            
            service = ChallengeGeneratorService(
                user_id=user_id,
                category_id=category_id,
                task_id=task_id
            )
            
            # 从 form_data 中提取 ai_config_id
            ai_config_id = None
            if form_data and 'ai_config_id' in form_data:
                try:
                    ai_config_id = int(form_data['ai_config_id'])
                    logger.info(f"传递 AI 配置 ID 到生成服务: {ai_config_id}")
                except (ValueError, TypeError):
                    logger.warning(f"无效的 ai_config_id: {form_data.get('ai_config_id')}")
            
            result = service.generate(
                language=language,
                vulnerabilities=vulnerabilities,
                scene=scene,
                difficulty=difficulty,
                extra_requirements=extra_requirements,
                ai_mode=ai_mode,
                app=app,
                form_data=form_data,  # 传递完整的表单数据
                ai_config_id=ai_config_id
            )
            
            # 处理结果（使用 task_id 更新对应任务的状态，避免多任务冲突）
            # 注意：只有当生成真正完成且有challenge_id时，才标记为完成
            # Augment模式会在_call_augment完成后才返回，此时生成已经完成
            if result['status'] == 'success':
                challenge_id = result.get('challenge_id')
                # 只有当challenge_id存在时（说明已经真正完成并创建了题目记录），才标记为完成
                if challenge_id:
                    from .utils import set_current_challenge_id
                    set_current_challenge_id(challenge_id)
                    
                    # 更新对应任务的状态（使用 task_id）
                    if task_id:
                        from .utils import generation_statuses, generation_lock
                        with generation_lock:
                            if task_id in generation_statuses:
                                generation_statuses[task_id]["message"] = "生成成功!"
                                generation_statuses[task_id]["completed"] = True
                                generation_statuses[task_id]["redirect_url"] = f"/generate/result?challenge_id={challenge_id}"
                                generation_statuses[task_id]["challenge_id"] = challenge_id
                        
                        # 同步到 tasks_status
                        from .tasks import update_task
                        update_task(task_id, 
                                   status='completed',
                                   challenge_id=challenge_id)
            else:
                # 更新失败状态（使用 task_id）
                error_msg = result.get('message', '生成失败')
                if task_id:
                    from .utils import generation_statuses, generation_lock
                    with generation_lock:
                        if task_id in generation_statuses:
                            generation_statuses[task_id]["error"] = error_msg
                            generation_statuses[task_id]["message"] = f"生成失败: {error_msg}"
                
        except Exception as e:
            logger.error(f"生成进程中发生错误: {str(e)}", exc_info=True)
            
            # 更新任务状态为失败
            if task_id:
                from .tasks import update_task
                update_task(task_id, status='failed', error=str(e))
                
                # 更新对应任务的状态（使用 task_id）
                from .utils import generation_statuses, generation_lock
                with generation_lock:
                    if task_id in generation_statuses:
                        generation_statuses[task_id]["error"] = str(e)
                        generation_statuses[task_id]["message"] = f"生成出错: {str(e)}"
        
        finally:
            logger.info("=== 生成题目主进程结束 ===\n")



__all__ = ['bp']