26.1.4 Model extraction techniques

2025.10.06.
AI Security Blog

Model extraction, or model stealing, involves creating a functional replica of a target model using only query access. As a red teamer, you treat the target model as a black box—an API that returns predictions—and use its responses to train a local “surrogate” model. This surrogate can then be used for offline analysis, discovering vulnerabilities, or crafting more effective adversarial attacks without repeatedly querying the target and risking detection.

The Core Loop: Query, Label, Train

The fundamental principle behind model extraction is straightforward. You systematically query the target model with a set of inputs and record its outputs. This creates a new labeled dataset where the “labels” are the predictions from the target model. You then use this dataset to train your own local model to mimic the target’s behavior.

[Diagram: Attacker (you) → 1. send queries (e.g., images, text) → Target Model (black-box API) → 2. receive predictions (e.g., class labels, logits) → 3. train Surrogate Model on (query, prediction) pairs]

Figure 1: The model extraction workflow. The attacker queries the target, collects its predictions, and uses this data to train a local surrogate model.

Classifier Extraction: A Practical Example

Extracting classifiers is the most common use case. The effectiveness of the attack depends heavily on the query data and the architecture of your surrogate model. You don’t need the exact same architecture as the target; you just need a model complex enough to approximate its decision boundary.

Step 1: Define the Target API and Surrogate Model

First, let’s simulate a black-box API and define a simple surrogate model, using scikit-learn for the demonstration. The target model’s internals are unknown to us; we can only interact with its prediction endpoint, simulated here by the `query_target_api` function.


import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression

# --- Target Model (This part is the "black box") ---
# In a real scenario, you wouldn't have this code.
# It's a complex model we want to steal.
_target_model = DecisionTreeClassifier(max_depth=10)
_target_model.fit(np.random.rand(100, 10), np.random.randint(0, 2, 100))

def query_target_api(data):
    # This function simulates making an API call.
    return _target_model.predict(data)

# --- Attacker's Surrogate Model ---
# We choose a simple model to train on the stolen data.
surrogate_model = LogisticRegression()

Step 2: The Extraction Loop

Now we execute the core loop: generate random query data, send it to the API to collect labels, and build the training set from the resulting pairs. After collecting enough data, we train our local surrogate model.


# 1. Generate query data
# The quality of this data is critical for a good extraction.
num_queries = 5000
query_data_dimension = 10
attacker_queries = np.random.rand(num_queries, query_data_dimension)

# 2. Query the API and collect "stolen" labels
stolen_labels = query_target_api(attacker_queries)

# 3. Train the surrogate model on the collected data
surrogate_model.fit(attacker_queries, stolen_labels)

print("Surrogate model trained successfully.")

# Now, test the surrogate's accuracy against the target
test_data = np.random.rand(100, query_data_dimension)
target_preds = query_target_api(test_data)
surrogate_preds = surrogate_model.predict(test_data)
accuracy = np.mean(target_preds == surrogate_preds)
print(f"Surrogate model fidelity (accuracy): {accuracy:.2f}")

This simple example demonstrates the principle, though a linear surrogate such as logistic regression can only coarsely approximate the depth-10 tree’s decision boundary; in practice you would choose a surrogate with enough capacity to fit the target. For more complex targets like deep neural networks, the query data strategy becomes even more important. You might use data from a public dataset in the same domain (e.g., ImageNet samples) or craft inputs, including adversarial examples, that specifically probe the decision boundary, as sketched below.
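One simple way to probe the boundary is uncertainty sampling: draw a large pool of random candidates, keep those the current surrogate is least confident about, and spend the query budget labeling only those with the target. The sketch below is a minimal illustration of one refinement round; it reuses `query_target_api`, `surrogate_model`, `attacker_queries`, and `stolen_labels` from the example above, and the pool sizes are arbitrary.


import numpy as np

def boundary_probe_round(surrogate, n_candidates=10000, n_keep=500, dim=10):
    # Draw a large pool of random candidate inputs.
    candidates = np.random.rand(n_candidates, dim)
    # Confidence = max class probability under the *surrogate*.
    confidence = surrogate.predict_proba(candidates).max(axis=1)
    # Keep the least confident candidates: those nearest the boundary.
    return candidates[np.argsort(confidence)[:n_keep]]

# One refinement round: probe, query the target, retrain on the union.
probe_queries = boundary_probe_round(surrogate_model)
probe_labels = query_target_api(probe_queries)
all_queries = np.vstack([attacker_queries, probe_queries])
all_labels = np.concatenate([stolen_labels, probe_labels])
surrogate_model.fit(all_queries, all_labels)

Repeating this loop concentrates queries where the surrogate and target are most likely to disagree, which typically raises fidelity faster than purely random sampling.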

LLM Capability Extraction

Extracting a full large language model (LLM) like GPT-4 is computationally infeasible. However, you can extract specific *capabilities*, such as its writing style, its ability to summarize legal documents, or its proficiency in generating code. This is often framed as a knowledge distillation problem.

The process involves:

  1. Define the Capability: Isolate the specific task you want to replicate (e.g., “translate English questions into SQL queries”).
  2. Generate Prompts: Create a dataset of prompts that elicit this capability from the target LLM.
  3. Query and Collect Responses: Send these prompts to the target LLM API and save the responses. This creates a `(prompt, response)` dataset (a collection sketch follows this list).
  4. Fine-tune a Base Model: Use this dataset to fine-tune a smaller, open-source model (e.g., Mistral 7B, Llama 3 8B) to mimic the target’s responses for that specific task.
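Steps 2 and 3 reduce to looping over task prompts and saving the responses. The sketch below assumes an OpenAI-compatible chat completions API via the `openai` Python client; the model name, the `TASK_PROMPTS` list, and the output path are placeholders.


import json
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Hypothetical prompts that elicit the capability we want to replicate.
TASK_PROMPTS = [
    "Translate into SQL: list all customers who signed up in 2024.",
    "Translate into SQL: count orders per region, highest first.",
    # ... hundreds to thousands more, covering the task's input space
]

dataset = []
for prompt in TASK_PROMPTS:
    completion = client.chat.completions.create(
        model="gpt-4o",  # placeholder name for the target LLM
        messages=[{"role": "user", "content": prompt}],
    )
    dataset.append({
        "prompt": prompt,
        "response": completion.choices[0].message.content,
    })

# Persist the (prompt, response) pairs for the fine-tuning step below.
with open("target_llm_responses.json", "w") as f:
    json.dump(dataset, f, indent=2)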

Conceptual Code for LLM Fine-tuning

The following is a conceptual example using the Hugging Face transformers library. It outlines the steps for preparing data and initiating a fine-tuning process.


from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import Dataset
import json

# `instruction_dataset` is a list of dictionaries like:
# [{'prompt': 'Question: ...', 'response': 'SQL: ...'}, ...]
# generated by querying the target LLM (see the collection sketch above).
with open("path/to/target_llm_responses.json") as f:
    instruction_dataset = json.load(f)

# 1. Load a base model and tokenizer to fine-tune
base_model_name = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(base_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.pad_token = tokenizer.eos_token  # Mistral's tokenizer has no pad token

# 2. Format the dataset for training
hf_dataset = Dataset.from_list(instruction_dataset)
def format_data(examples):
    texts = [f"### Prompt:\n{p}\n\n### Response:\n{r}"
             for p, r in zip(examples['prompt'], examples['response'])]
    # Cap sequence length so "max_length" padding stays reasonable.
    return tokenizer(texts, truncation=True, padding="max_length", max_length=512)
tokenized_dataset = hf_dataset.map(format_data, batched=True)

# 3. Define training arguments and run the trainer
training_args = TrainingArguments(
    output_dir="./surrogate_llm",
    num_train_epochs=1,
    per_device_train_batch_size=4,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    # mlm=False copies input_ids to labels for standard causal-LM training.
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
print("LLM capability extraction (fine-tuning) complete.")

This process effectively transfers a specialized skill from a large, closed-source model to a smaller, locally hosted one, which poses a significant security and intellectual property risk for the owner of the target model.
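As with the classifier example, fidelity should be measured: how often does the surrogate reproduce the target’s output on held-out prompts? A minimal sketch, assuming a hypothetical `eval_prompts` list and `query_target_llm` helper, and reusing the fine-tuned `model` and `tokenizer` from above:


def generate_surrogate_response(prompt, max_new_tokens=128):
    # Mirror the prompt template used during fine-tuning.
    inputs = tokenizer(f"### Prompt:\n{prompt}\n\n### Response:\n",
                       return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    return text.split("### Response:")[-1].strip()

matches = 0
for prompt in eval_prompts:                 # hypothetical held-out prompts
    target_out = query_target_llm(prompt)   # hypothetical target API helper
    surrogate_out = generate_surrogate_response(prompt)
    matches += int(surrogate_out == target_out.strip())

print(f"Exact-match fidelity: {matches / len(eval_prompts):.2f}")

Exact match is a crude proxy; for structured outputs like SQL, execution accuracy against a test database is a far more meaningful fidelity metric.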