import json
import random
import sys
import os

def find_category(data, target_category):
    """递归查找匹配的分类名，返回该分类的数据

    Args:
        data: JSON数据（字典或列表）
        target_category: 目标分类名称

    Returns:
        dict|None: 找到的分类数据，如果不存在返回None
    """
    if isinstance(data, dict):
        for key, value in data.items():
            if key == target_category:
                return value
            # 递归查找子分类
            result = find_category(value, target_category)
            if result is not None:
                return result
    elif isinstance(data, list):
        for item in data:
            result = find_category(item, target_category)
            if result is not None:
                return result
    return None


def check_writeup_exists(filename, raw_dir="data/writeups"):
    """检查writeup文件是否存在"""
    writeup_path = os.path.join(raw_dir, filename)
    return os.path.exists(writeup_path)


def find_writeup_file(challenge_name, raw_dir="data/writeups"):
    """查找writeup文件，支持模糊匹配

    Args:
        challenge_name: 挑战名称
        raw_dir: writeup目录

    Returns:
        str|None: 找到的文件名，如果不存在返回None
    """
    # 1. 精确匹配
    exact_filename = f"{challenge_name}.md"
    if check_writeup_exists(exact_filename, raw_dir):
        return exact_filename

    # 2. 模糊匹配：查找包含挑战名称的文件
    try:
        all_files = os.listdir(raw_dir)
        # 不区分大小写的匹配
        challenge_lower = challenge_name.lower()
        for filename in all_files:
            if filename.endswith('.md'):
                # 检查文件名是否包含挑战名称
                if challenge_lower in filename.lower():
                    return filename
                # 或者挑战名称包含在文件名中（去掉前缀后）
                # 例如: "JuniorCTF - Pizzagate.md" 匹配 "Pizzagate"
                name_without_ext = filename[:-3]  # 去掉.md
                if challenge_lower in name_without_ext.lower():
                    return filename
    except Exception as e:
        print(f"搜索writeup文件时出错: {e}")

    return None


def pick_challenges(json_path, categories, total_count=14):
    """从主数据中抽取挑战writeup

    Args:
        json_path: 挑战数据JSON文件路径（aaa.json）
        categories: 漏洞类型列表
        total_count: 总共挑选的文章数量（默认14篇）
    """
    # 载入挑战数据
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # 计算每个分类应该选择的数量（平均分配）
    num_categories = len(categories)
    if num_categories == 0:
        print("⚠️ 未指定任何分类。")
        return

    # 第一轮：收集每个分类的可用writeup
    category_data = {}
    for cat in categories:
        # 查找分类数据
        cat_data = find_category(data, cat)
        if not cat_data:
            print(f"⚠️ 分类 [{cat}] 在挑战文件中未找到。")
            category_data[cat] = []
            continue

        # 获取挑战名称列表
        # vulnerability_db.json中的结构是: {"漏洞名": {"challenges": [{filename, quality, ...}, ...], ...}}
        challenges = cat_data.get("challenges", [])

        if not challenges:
            print(f"⚠️ 分类 [{cat}] 下无挑战。")
            category_data[cat] = []
            continue

        # 使用模糊匹配查找writeup文件
        existing_files = []
        for challenge in challenges:
            # 支持新格式（字典）和旧格式（字符串）
            if isinstance(challenge, dict):
                challenge_name = challenge.get('filename', '')
            else:
                challenge_name = challenge
            
            found_file = find_writeup_file(challenge_name)
            if found_file:
                existing_files.append(found_file)

        if not existing_files:
            print(f"⚠️ 分类 [{cat}] 下没有找到任何存在的writeup文件。")
            category_data[cat] = []
        else:
            category_data[cat] = existing_files
            print(f"📁 分类 [{cat}]: 找到 {len(existing_files)}/{len(challenges)} 个有效writeup")

    # 第二轮：动态分配数量，确保总数达到total_count
    # 过滤出有writeup的分类
    valid_categories = [cat for cat, files in category_data.items() if len(files) > 0]

    if not valid_categories:
        print("⚠️ 所有分类都没有找到有效的writeup文件。")
        return

    all_selected = []

    # 计算初始分配
    pick_per_category = total_count // len(valid_categories)
    remainder = total_count % len(valid_categories)

    # 第一次分配
    category_allocations = {}
    for idx, cat in enumerate(valid_categories):
        # 前remainder个分类多分配1篇
        allocated = pick_per_category + (1 if idx < remainder else 0)
        # 不能超过该分类的可用数量
        actual = min(allocated, len(category_data[cat]))
        category_allocations[cat] = actual

    # 计算未分配的数量（由于某些分类writeup不足）
    total_allocated = sum(category_allocations.values())
    unallocated = total_count - total_allocated

    # 如果有未分配的，重新分配给有余量的分类
    if unallocated > 0:
        # 找出还有余量的分类
        categories_with_余量 = [
            cat for cat in valid_categories
            if category_allocations[cat] < len(category_data[cat])
        ]

        # 将未分配的数量平均分配给有余量的分类
        while unallocated > 0 and categories_with_余量:
            for cat in categories_with_余量:
                if unallocated <= 0:
                    break
                available = len(category_data[cat]) - category_allocations[cat]
                if available > 0:
                    category_allocations[cat] += 1
                    unallocated -= 1

            # 更新有余量的分类列表
            categories_with_余量 = [
                cat for cat in valid_categories
                if category_allocations[cat] < len(category_data[cat])
            ]

    # 第三轮：根据分配数量随机选择
    for cat in valid_categories:
        pick_count = category_allocations[cat]
        existing_files = category_data[cat]

        selected = random.sample(existing_files, pick_count)
        all_selected.extend(selected)

        print(f"✅ 分类 [{cat}]: 从 {len(existing_files)} 个有效writeup中选择了 {pick_count} 篇")

    unique_selected = sorted(set(all_selected))

    # 检查是否达到目标数量
    actual_count = len(unique_selected)
    print("\n" + "=" * 60)

    if actual_count < total_count:
        print(f"⚠️  警告：目标数量为 {total_count} 篇，但只找到 {actual_count} 篇有效writeup")
        print(f"📚 最终汇总选出的 {actual_count} 篇文章：")
    else:
        print(f"📚 最终汇总选出的 {actual_count} 篇文章：")

    for fn in unique_selected:
        print(f"  - {fn}")

    # 输出统计信息
    if actual_count < total_count:
        shortage = total_count - actual_count
        print("\n" + "=" * 60)
        print(f"💡 提示：还差 {shortage} 篇才能达到目标数量")
        print(f"   建议：选择writeup数量更多的分类，或减少分类数量")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("用法: python choice.py [--difficulty=难度] [--count=数量] <分类1> <分类2> ...")
        print("示例: python choice.py --difficulty=简单 SSTI模板注入 SQL注入")
        print("      python choice.py --difficulty=中等 --count=6 SSTI模板注入")
        print("      python choice.py 命令注入  # 不指定难度时默认为中等")
        print("\n难度选项:")
        print("  - 入门/简单: 默认选出5篇writeup")
        print("  - 中等/困难: 默认选出10篇writeup")
        print("\n数量选项:")
        print("  - --count=N: 指定选出N篇writeup（覆盖难度默认值）")
        sys.exit(1)

    # 解析参数
    args = sys.argv[1:]
    difficulty = None
    custom_count = None
    categories = []

    # 检查是否有 --difficulty 和 --count 参数
    for arg in args:
        if arg.startswith('--difficulty='):
            difficulty = arg.split('=', 1)[1]
        elif arg.startswith('--count='):
            try:
                custom_count = int(arg.split('=', 1)[1])
            except ValueError:
                print(f"⚠️  无效的数量参数: {arg}")
                sys.exit(1)
        else:
            categories.append(arg)

    if not categories:
        print("❌ 错误：至少需要指定一个分类")
        sys.exit(1)

    # 确定选择数量
    if custom_count is not None:
        # 使用用户指定的数量
        total_count = custom_count
        print(f"📊 指定数量: {total_count} 篇writeup\n")
    elif difficulty in ['入门', '简单']:
        total_count = 5
        print(f"📊 难度: {difficulty} - 将选出 {total_count} 篇writeup\n")
    elif difficulty in ['中等', '困难']:
        total_count = 10
        print(f"📊 难度: {difficulty} - 将选出 {total_count} 篇writeup\n")
    else:
        # 默认为中等难度
        total_count = 10
        if difficulty:
            print(f"⚠️  未识别的难度 '{difficulty}'，使用默认值（中等）\n")
        else:
            print(f"📊 未指定难度，使用默认值（中等）- 将选出 {total_count} 篇writeup\n")

    # 使用漏洞库文件
    json_file = "data/vulnerability_db.json"

    # 根据难度挑选writeup
    pick_challenges(json_file, categories, total_count=total_count)
