Model extraction, or model stealing, involves creating a functional replica of a target model using only query access. As a red teamer, you treat the target model as a black box—an API that returns predictions—and use its responses to train a local “surrogate” model. This surrogate can then be used for offline analysis, discovering vulnerabilities, or crafting more effective adversarial attacks without repeatedly querying the target and risking detection.
The Core Loop: Query, Label, Train
The fundamental principle behind model extraction is straightforward. You systematically query the target model with a set of inputs and record its outputs. This creates a new labeled dataset where the “labels” are the predictions from the target model. You then use this dataset to train your own local model to mimic the target’s behavior.
Figure 1: The model extraction workflow. The attacker queries the target, collects its predictions, and uses this data to train a local surrogate model.
Classifier Extraction: A Practical Example
Extracting classifiers is the most common use case. The effectiveness of the attack depends heavily on the query data and the architecture of your surrogate model. You don’t need the exact same architecture as the target; you just need a model complex enough to approximate its decision boundary.
Step 1: Define the Target API and Surrogate Model
First, let’s simulate a black-box API and define a simple surrogate model. We’ll use scikit-learn for this demonstration. The target model’s internals are unknown to us; we can only call its `predict` function.
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression

# --- Target Model (This part is the "black box") ---
# In a real scenario, you wouldn't have this code.
# It's a complex model we want to steal.
_target_model = DecisionTreeClassifier(max_depth=10)
_target_model.fit(np.random.rand(100, 10), np.random.randint(0, 2, 100))

def query_target_api(data):
    # This function simulates making an API call.
    return _target_model.predict(data)

# --- Attacker's Surrogate Model ---
# We choose a simple model to train on the stolen data.
surrogate_model = LogisticRegression()
Step 2: The Extraction Loop
Now, execute the core loop. We generate random data, query the API to get labels, and build our training set. After collecting enough data, we train our local surrogate model.
# 1. Generate query data
# The quality of this data is critical for a good extraction.
num_queries = 5000
query_data_dimension = 10
attacker_queries = np.random.rand(num_queries, query_data_dimension)
# 2. Query the API and collect "stolen" labels
stolen_labels = query_target_api(attacker_queries)
# 3. Train the surrogate model on the collected data
surrogate_model.fit(attacker_queries, stolen_labels)
print("Surrogate model trained successfully.")
# Now, test the surrogate's accuracy against the target
test_data = np.random.rand(100, query_data_dimension)
target_preds = query_target_api(test_data)
surrogate_preds = surrogate_model.predict(test_data)
accuracy = np.mean(target_preds == surrogate_preds)
print(f"Surrogate model fidelity (accuracy): {accuracy:.2f}")
These simple examples demonstrate the principle. For more complex models like deep neural networks, the query strategy matters far more: you might draw queries from a public dataset in the same domain (e.g., ImageNet samples) or generate inputs that specifically probe the decision boundary, as sketched below.
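One common way to probe the boundary is an uncertainty-based active-learning loop: from a pool of random candidates, spend queries only on the points the current surrogate is least certain about. A minimal sketch of this idea, reusing the surrogate and API from above (the pool size, batch size, and round count are arbitrary choices):

# Active query selection: prefer points near the surrogate's decision
# boundary, where its predicted class probabilities are closest to 0.5.
for round_idx in range(5):
    candidate_pool = np.random.rand(10000, query_data_dimension)
    probs = surrogate_model.predict_proba(candidate_pool)[:, 1]
    uncertainty = np.abs(probs - 0.5)
    # Pick the 500 candidates closest to the boundary to spend queries on.
    boundary_points = candidate_pool[np.argsort(uncertainty)[:500]]
    # Query the target only on those points and grow the stolen dataset.
    new_labels = query_target_api(boundary_points)
    attacker_queries = np.vstack([attacker_queries, boundary_points])
    stolen_labels = np.concatenate([stolen_labels, new_labels])
    surrogate_model.fit(attacker_queries, stolen_labels)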
LLM Capability Extraction
Extracting a full large language model (LLM) like GPT-4 is computationally infeasible. However, you can extract specific *capabilities*, such as its writing style, its ability to summarize legal documents, or its proficiency in generating code. This is often framed as a knowledge distillation problem.
The process involves:
- Define the Capability: Isolate the specific task you want to replicate (e.g., “translate English questions into SQL queries”).
- Generate Prompts: Create a dataset of prompts that elicit this capability from the target LLM.
- Query and Collect Responses: Send these prompts to the target LLM API and save the responses. This creates a `(prompt, response)` dataset (see the sketch after this list).
- Fine-tune a Base Model: Use this dataset to fine-tune a smaller, open-source model (e.g., Mistral 7B, Llama 3 8B) to mimic the target’s responses for that specific task.
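A minimal sketch of steps 2 and 3, assuming the official `openai` client against an OpenAI-compatible chat API; the model name, the two example prompts, and the output filename are illustrative placeholders:

import json
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Hypothetical prompt set designed to elicit the target capability.
prompts = [
    "Translate to SQL: How many users signed up last week?",
    "Translate to SQL: List the top 5 products by revenue.",
]

dataset = []
for prompt in prompts:
    completion = client.chat.completions.create(
        model="gpt-4",  # placeholder name for the target model
        messages=[{"role": "user", "content": prompt}],
    )
    response = completion.choices[0].message.content
    dataset.append({"prompt": prompt, "response": response})

# Save the (prompt, response) pairs for fine-tuning.
with open("target_llm_responses.json", "w") as f:
    json.dump(dataset, f, indent=2)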
Conceptual Code for LLM Fine-tuning
The following is a conceptual example using the Hugging Face transformers library. It outlines the steps for preparing data and initiating a fine-tuning process.
import json
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import Dataset

# Assume `instruction_dataset` is a list of dictionaries like:
# [{'prompt': 'Question: ...', 'response': 'SQL: ...'}, ...]
# This data was generated by querying the target LLM.
with open("path/to/target_llm_responses.json") as f:
    instruction_dataset = json.load(f)

# 1. Load a base model and tokenizer to fine-tune
base_model_name = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(base_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.pad_token = tokenizer.eos_token  # Mistral's tokenizer has no pad token by default

# 2. Format the dataset for training
hf_dataset = Dataset.from_list(instruction_dataset)

def format_data(examples):
    texts = [f"### Prompt:\n{p}\n\n### Response:\n{r}"
             for p, r in zip(examples["prompt"], examples["response"])]
    return tokenizer(texts, truncation=True, max_length=512)

tokenized_dataset = hf_dataset.map(format_data, batched=True)

# 3. Define training arguments and run the trainer
# The collator pads each batch and copies input_ids into labels,
# which a causal LM needs in order to compute its loss.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
training_args = TrainingArguments(
    output_dir="./surrogate_llm",
    num_train_epochs=1,
    per_device_train_batch_size=4,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)
trainer.train()
print("LLM capability extraction (fine-tuning) complete.")
This process effectively transfers a specialized skill from a large, closed-source model to a smaller, locally hosted one, which can pose a significant security and intellectual property risk for the owner of the target model.