import sys
import os

# Command-line contract: <version> <is_fixed>, where is_fixed == "1" marks the
# patched release. The log is written to $LOGS (default ".") as
# variant_test_<version>.log.
if len(sys.argv) < 3:
    # Exit 2 (usage error) so a missing argument is not mistaken for the
    # "no bypass" status (1) that the fixed-mode run reports.
    print(f"usage: {sys.argv[0]} <version> <is_fixed: 1|0>", file=sys.stderr)
    sys.exit(2)
version = sys.argv[1]
is_fixed = sys.argv[2] == "1"
# os.path.join avoids "/variant_test_..." (filesystem root) when LOGS is empty.
log_file = os.path.join(os.environ.get("LOGS", "."), f"variant_test_{version}.log")

with open(log_file, "w", encoding="utf-8") as log:
    def log_write(msg):
        """Print *msg* and append it, newline-terminated, to the log file."""
        print(msg)
        log.write(msg + "\n")

    log_write(f"=== Testing semantic-kernel {version} (is_fixed={is_fixed}) ===\n")

    try:
        from semantic_kernel.connectors.in_memory import InMemoryStore
        from semantic_kernel.data.vector import VectorStoreCollectionDefinition, VectorStoreField
        from semantic_kernel.data.vector import FieldTypes
        from pydantic import BaseModel
        from typing import List
    except ImportError as e:
        log_write(f"[-] Failed to import semantic-kernel: {e}")
        # NOTE(review): exit status 1 is also the fixed-mode "no bypass" result
        # below, so a caller that only checks the status cannot tell the two apart.
        sys.exit(1)

    # Simple data model for the test record.
    # Deliberately no dunder-named attribute here: pydantic does not turn
    # underscore-prefixed names into fields, and a class-level `__class__`
    # value would shadow object.__class__ on instances (x.__class__ would be
    # that value), breaking the x.__class__ / __mro__ probes below.
    class TestDataModel(BaseModel):
        id: str
        content: str
        embedding: List[float]

        class Config:
            arbitrary_types_allowed = True

    # Create vector store
    vector_store = InMemoryStore()

    # Collection definition mirroring TestDataModel (3-dimensional float vector).
    definition = VectorStoreCollectionDefinition(
        fields=[
            VectorStoreField(name="id", field_type=FieldTypes.KEY, type="str"),
            VectorStoreField(name="content", field_type=FieldTypes.DATA, type="str"),
            VectorStoreField(name="embedding", field_type=FieldTypes.VECTOR, type="float", dimensions=3),
        ]
    )

    collection = vector_store.get_collection(
        collection_name="test_collection",
        record_type=TestDataModel,
        definition=definition
    )

    # Record every validated filter is applied to.
    test_record = TestDataModel(id="1", content="test", embedding=[0.1, 0.2, 0.3])

    # Counters live at module scope (a `with` body does not open a new scope),
    # hence the `global` declaration inside test_filter.
    bypasses_found = 0   # fixed mode: expected-blocked filters that ran successfully
    total_tests = 0      # every test_filter call
    executed_count = 0   # filters that passed validation and ran without raising
    runtime_errors = 0   # filters that passed validation but raised when applied

    def test_filter(name, filter_str, *, expect_blocked=True):
        """Validate *filter_str* through the collection, apply it to test_record, log the outcome.

        Args:
            name: Label used in the log line.
            filter_str: Lambda source passed to ``collection._parse_and_validate_filter``.
            expect_blocked: False for control filters the fixed version is expected
                to accept; when such a filter runs in fixed mode it is logged as
                [ALLOWED] and not counted as a bypass.

        Outcomes logged:
            validation raises                  -> [BLOCKED] (fixed) / [FAIL] (vuln)
            validation passes, applying raises -> [RUNTIME-ERROR] (counted in runtime_errors)
            filter runs                        -> [BYPASS] (fixed, counted) / [VULN] (vuln)
        """
        global bypasses_found, total_tests, executed_count, runtime_errors
        total_tests += 1

        # Validation and execution are handled separately: a filter the
        # validator accepted but that raised on this particular record was not
        # blocked, and must not be reported as such.
        try:
            result = collection._parse_and_validate_filter(filter_str)
        except Exception as e:
            tag = "BLOCKED" if is_fixed else "FAIL"
            log_write(f"[{tag}] {name}: {str(e)[:80]}")
            return

        try:
            result_value = result(test_record)
        except Exception as e:
            runtime_errors += 1
            log_write(f"[RUNTIME-ERROR] {name}: passed validation, raised when applied: {str(e)[:80]}")
            return

        executed_count += 1
        if not is_fixed:
            log_write(f"[VULN] {name}: Filter executed (expected in vuln version)")
        elif not expect_blocked:
            log_write(f"[ALLOWED] {name}: control filter executed (expected in fixed version)")
        else:
            # In the fixed version this filter should have been rejected.
            log_write(f"[BYPASS] {name}: Filter executed when it should have been blocked!")
            log_write(f"         Filter: {filter_str}")
            log_write(f"         Result: {result_value}")
            bypasses_found += 1

    log_write("\n=== ORIGINAL VULNERABILITY TESTS ===")
    # Test 1: Original vulnerability - __class__
    test_filter("Original __class__", "lambda x: x.__class__.__name__ == 'TestDataModel'")
    test_filter("Original __dict__", "lambda x: x.__dict__ is not None")
    test_filter("Original __mro__", "lambda x: x.__class__.__mro__ is not None")

    log_write("\n=== VARIANT TESTS: Subscript Access ===")
    # Variant 1: Try to access via subscript
    test_filter("Subscript __class__", "lambda x: x['__class__'] == 'test'")
    test_filter("Subscript dict access", "lambda x: x.__dict__")

    log_write("\n=== VARIANT TESTS: Function Result Chaining ===")
    # Variant 2: Attribute access on function results
    test_filter("len().__class__", "lambda x: len(x.content).__class__ is not None")
    test_filter("str().__class__", "lambda x: str(x.content).__class__ is not None")
    test_filter("int().__class__", "lambda x: int(x.id).__class__ is not None")

    log_write("\n=== VARIANT TESTS: Literal Access ===")
    # Variant 3: Direct access on literals
    test_filter("List literal __class__", "lambda x: [].__class__ is not None")
    test_filter("Dict literal __class__", "lambda x: {}.__class__ is not None")
    test_filter("Tuple literal __class__", "lambda x: ().__class__ is not None")
    test_filter("String literal __class__", "lambda x: ''.__class__ is not None")
    test_filter("Int literal __class__", "lambda x: (1).__class__ is not None")

    log_write("\n=== VARIANT TESTS: Nested/Indirect Access ===")
    # Variant 4: Get attr from values
    test_filter("Content __class__", "lambda x: x.content.__class__ is not None")
    test_filter("Content __dict__", "lambda x: x.content.__dict__ is not None")

    log_write("\n=== VARIANT TESTS: Alternative Dunder Attributes ===")
    # Variant 5: Test attributes that might not be in blocklist
    test_filter("__weakref__", "lambda x: hasattr(x, '__weakref__')")
    test_filter("__doc__", "lambda x: x.__doc__ is not None")
    test_filter("__hash__", "lambda x: x.__hash__ is not None")
    test_filter("__eq__", "lambda x: x.__eq__ is not None")
    test_filter("__repr__", "lambda x: x.__repr__ is not None")
    test_filter("__str__", "lambda x: x.__str__ is not None")
    test_filter("__format__", "lambda x: x.__format__ is not None")
    test_filter("__sizeof__", "lambda x: x.__sizeof__ is not None")

    log_write("\n=== VARIANT TESTS: String Method Chains ===")
    # Variant 6: String methods that might expose internals
    test_filter("Upper then __class__", "lambda x: x.content.upper().__class__ is not None")
    test_filter("Strip then __class__", "lambda x: x.content.strip().__class__ is not None")

    log_write("\n=== VARIANT TESTS: Subscript on Function Results ===")
    # Variant 7: Subscript on function results
    test_filter("Get keys", "lambda x: x.keys()")
    test_filter("Get values", "lambda x: x.values()")
    test_filter("Get items", "lambda x: x.items()")

    log_write("\n=== VARIANT TESTS: Missing Blocklist Items ===")
    # Variant 8: Check attributes that might not be blocked
    test_filter("__wrapped__", "lambda x: hasattr(x, '__wrapped__')")
    test_filter("__code__ via func", "lambda x: hasattr(abs, '__code__')")
    test_filter("__closure__ via func", "lambda x: hasattr(abs, '__closure__')")

    log_write("\n=== VARIANT TESTS: Type-based Introspection ===")
    # Variant 9: Using type() (which is allowed) to get class info.
    # Plain type() is the control for this group, so it is not a bypass when it runs.
    test_filter("type() access", "lambda x: type(x) is not None", expect_blocked=False)
    test_filter("type().__name__", "lambda x: type(x).__name__ == 'AttributeDict'")
    # Try to chain from type() to dangerous attrs
    test_filter("type().__mro__", "lambda x: type(x).__mro__ is not None")
    test_filter("type().__subclasses__", "lambda x: hasattr(type(x), '__subclasses__')")

    log_write("\n=== VARIANT TESTS: Container Manipulation ===")
    # Variant 10: List/dict operations
    test_filter("List append", "lambda x: [].append(1) or True")
    test_filter("List extend", "lambda x: [].extend([1]) or True")
    test_filter("Dict get", "lambda x: {}.get('key', 'default') == 'default'")

    log_write("\n========================================")
    if is_fixed:
        log_write(f"Results for FIXED version ({version}):")
        log_write(f"  Total tests: {total_tests}")
        log_write(f"  Bypasses found: {bypasses_found}")
        log_write(f"  Passed validation but raised when applied: {runtime_errors}")
        # Exit-status contract: 0 = bypass found, 1 = none found.
        if bypasses_found > 0:
            log_write("  STATUS: BYPASSES FOUND!")
            sys.exit(0)  # Success - bypass found
        else:
            log_write("  STATUS: All dangerous filters blocked")
            sys.exit(1)  # No bypass
    else:
        log_write(f"Results for VULNERABLE version ({version}):")
        log_write(f"  Tests that should pass: {total_tests}")
        log_write(f"  Filters executed: {executed_count}")
        log_write(f"  Passed validation but raised when applied: {runtime_errors}")
        log_write("  (Baseline for comparison)")
        sys.exit(0)
