ZeRO-3 and device_map are not compatible
Why you can OOM at load time despite having 8×40GB
You are mixing two different distribution mechanisms:
- device_map="auto" / max_memory / offload_folder
  This triggers Accelerate's Big Model Inference style, inference-time dispatch: it "fills GPU(s) first, then CPU, then disk". (Hugging Face) This is not DeepSpeed ZeRO sharding.
- DeepSpeed ZeRO-3 (stage-3 sharding)
  ZeRO-3 shards parameters and optimizer states across ranks, but it only works if the model is constructed/loaded under the ZeRO-3 initialization path (e.g., deepspeed.zero.Init or HfDeepSpeedConfig + from_pretrained), not via device_map.
In an accelerate launch --num_processes 8 run, each of the 8 processes executes your top-level Python code. With device_map="auto", each process will try to use all visible GPUs to dispatch the model, which can lead to "multiple copies' worth" of allocations across the node (or heavy temporary allocations during dequantization), and you OOM before ZeRO-3 ever has a chance to shard things.
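A quick way to see the multiplication (pure PyTorch/stdlib; assumes only that the launcher sets LOCAL_RANK, which accelerate launch does):

import os
import torch

# Under `accelerate launch --num_processes 8`, each of the 8 ranks runs this file.
# Unless you restrict GPU visibility, every rank typically sees all 8 GPUs, so
# device_map="auto" in each rank dispatches weights across all of them,
# i.e. eight overlapping placements on one node.
print(f"rank {os.environ.get('LOCAL_RANK')}: sees {torch.cuda.device_count()} GPUs")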
This is consistent with multiple upstream warnings/issues:
- ZeRO-3 is incompatible with device_map and low_cpu_mem_usage in the Transformers loading path. (GitHub)
- You can't train a model loaded with device_map='auto' in distributed mode (Accelerate/Transformers explicitly error on this in many setups). (GitHub)
Even if your run doesn't hit those exact ValueErrors (because you OOM first), the underlying incompatibility remains.
Why {"": local_rank} still OOMs on a single A100 40GB
Once you set Mxfp4Config(dequantize=True), you are effectively asking to materialize BF16/FP16 weights. A 20B-parameter model at BF16 is ~40GB just for parameters (20e9 × 2 bytes ≈ 40GB), before accounting for:
- embeddings/head tied weights handling
- layernorm/buffers
- temporary tensors during weight loading/dequantization
- fragmentation / allocator reserves
There is a very similar report from an A100 40GB user: they get an OOM while loading because the model already consumes ~37GB and then fails on an extra ~2GB allocation. (Hugging Face)
So: mapping the whole dequantized model onto one 40GB GPU is expected to be right on the edge (and often fails).
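The parameter arithmetic alone shows how tight this is. A back-of-envelope sketch (parameters only, ignoring the extra buffers listed above):

n_params = 20e9          # gpt-oss-20b, roughly
bytes_per_param = 2      # bf16/fp16 after dequantization
print(f"{n_params * bytes_per_param / 1e9:.0f} GB")  # ~40 GB of weights alone

40e9 bytes is about 37 GiB, which is also roughly the ~37GB load-time usage in the report above, before any temporaries or allocator overhead.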
The core fix: donât use device_map for ZeRO-3 training
What to remove from your from_pretrained call
For DeepSpeed ZeRO-3 training, remove:
- device_map="auto"
- max_memory=...
- offload_folder=... (this is for Big Model Inference CPU/disk offload, not ZeRO offload)
Also set:
- use_cache=False (the cache is for generation; for training it's wasted memory and often disabled in examples)
Correct loading patterns for ZeRO-3 sharded training
Option A (recommended): let Trainer/TRL + DeepSpeed handle initialization
If you're using TRL/Trainer, pass a DeepSpeed config into the training arguments and load the model without device_map. The OpenAI cookbook's fine-tuning article is single-H100 oriented (80GB) (OpenAI Cookbook), but the principle is the same: you need ZeRO-3 to own placement, not device_map.
Key idea: the distributed engine must be active during/around model init (or you'll load full weights per process).
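A minimal sketch of Option A, assuming TRL's SFTTrainer/SFTConfig, a stage-3 JSON file of your own (ds_zero3.json is a placeholder name), and a train_dataset defined elsewhere. The ordering is the important part: the training arguments, which carry the DeepSpeed config, are constructed before from_pretrained so ZeRO-3 init is active while the weights are materialized:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
from trl import SFTConfig, SFTTrainer

model_id = "openai/gpt-oss-20b"

# Create the args (and thus the DeepSpeed config) BEFORE loading the model.
training_args = SFTConfig(
    output_dir="gpt-oss-20b-sft",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    bf16=True,
    gradient_checkpointing=True,
    deepspeed="ds_zero3.json",  # your ZeRO-3 config file
)

# No device_map / max_memory / offload_folder: placement belongs to ZeRO-3.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=Mxfp4Config(dequantize=True),
    use_cache=False,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,    # your dataset
    processing_class=tokenizer,     # older TRL versions use tokenizer=
)
trainer.train()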
Option B (robust for non-Trainer setups): HfDeepSpeedConfig before from_pretrained
Transformers documents a "non-Trainer integration" where HfDeepSpeedConfig enables ZeRO-3 partitioning behavior during from_pretrained(). Critically, it must be instantiated before loading the model. (Hugging Face)
Minimal sketch (conceptual; adapt to your actual training loop):
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
from transformers.integrations import HfDeepSpeedConfig
model_id = "openai/gpt-oss-20b"
# Load your DS ZeRO-3 config (json/dict) matching stage-3 + offload settings
ds_config = json.load(open("ds_zero3.json"))
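# For illustration only (not your actual file): ds_zero3.json is expected to
# enable stage 3, e.g. something along the lines of
#   {"bf16": {"enabled": true},
#    "zero_optimization": {"stage": 3},
#    "train_micro_batch_size_per_gpu": 1,
#    "gradient_accumulation_steps": 8}
# (optimizer and offload settings go in the same file as needed)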
# Must be created BEFORE from_pretrained, and kept alive
dschf = HfDeepSpeedConfig(ds_config)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=Mxfp4Config(dequantize=True),
    use_cache=False,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
This avoids the device_map path entirely and uses the ZeRO-3-aware initialization hook described in the docs. (Hugging Face)
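In this non-Trainer path you also have to build the DeepSpeed engine yourself after loading; a sketch of that last step, assuming the same ds_config as above (with no optimizer/offload sections in it, hence the plain torch optimizer here):

import deepspeed

# Wrapping the model in a DeepSpeed engine is what actually shards parameters
# and optimizer state across the 8 ranks during training.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config=ds_config,
)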
Option C (manual init): deepspeed.zero.Init(...)
Accelerate also shows that if the automatic integration isn't in play, you can explicitly use deepspeed.zero.Init to ensure the model is initialized under ZeRO-3 rules. (Hugging Face)
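A sketch of that explicit route, assuming the same ds_config as before. Note that building from a config this way yields randomly initialized (but already partitioned) weights; pretrained weights still come from a checkpoint load or the HfDeepSpeedConfig path above:

import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("openai/gpt-oss-20b")

# Parameters created inside this context are partitioned across ranks at
# construction time instead of being fully materialized on every process.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = AutoModelForCausalLM.from_config(config)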
Notes specific to MXFP4 and A100
- Transformers will try to use MXFP4 Triton kernels only if available and supported; otherwise it falls back. (Hugging Face)
- The gpt-oss model discussions include reports where A100 ends up dequantizing/falling back, and load-time memory becomes the limiter. (Hugging Face)
Also, there was a recent Transformers bug report about device_map="auto" failing to load dequantized gpt-oss on GPU+CPU offload (closed, but relevant if you keep experimenting with device_map). (GitHub)
Given you're training with ZeRO-3 anyway, the clean solution is to stop using device_map in the training job.
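If you want to confirm on your node why the MXFP4 kernel path is skipped, a quick check (plain PyTorch/stdlib; it only reports the GPU generation and whether Triton is importable, which is roughly what the fallback decision depends on):

import importlib.util
import torch

# A100 reports compute capability (8, 0); per the docs/discussions above,
# Transformers falls back to dequantized bf16 on it.
print("capability:", torch.cuda.get_device_capability(0))
print("triton installed:", importlib.util.find_spec("triton") is not None)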
What I think is happening in your exact script
- You launch 8 processes.
- Each process runs from_pretrained(...).
- Because you set device_map="auto" (+ max_memory), you're in the Big Model Inference dispatch path (GPU → CPU → disk). (Hugging Face)
- You also request dequantization to BF16, which creates large allocations and temporary buffers.
- Before ZeRO-3 sharding is applied, one or more processes allocate enough on one GPU to push it over 40GB → torch.OutOfMemoryError.
This matches the A100-40GB OOM pattern reported by others when the model becomes effectively BF16-sized on a single device. (Hugging Face)
Similar cases + high-signal references
Device-map vs distributed training incompatibilities
- Transformers issue: can't train with device_map='auto' in distributed mode. (GitHub)
- Accelerate issue: ZeRO-3 incompatible with device_map / low_cpu_mem_usage. (GitHub)
- PEFT issue discussion explaining why: device_map/low_cpu_mem_usage implies a naive model-parallel style, while ZeRO-3 is sharded DP. (GitHub)
gpt-oss + A100 memory behavior
- Hugging Face model discussions: gpt-oss-20b dequantized to BF16 on a single A100 40GB OOMs during loading (the ~37GB-plus-~2GB failure pattern cited above). (Hugging Face)
Official docs you'll actually use for the fix
- Accelerate docs: Big Model Inference (device_map="auto") is an inference feature; the docs describe how it dispatches memory across GPU/CPU/disk. (Hugging Face)
- Transformers docs: HfDeepSpeedConfig must be instantiated before loading to deploy ZeRO-3 efficiently. (Hugging Face)
- PEFT/Accelerate DeepSpeed guide: explains zero3_init_flag / deepspeed.zero.Init. (Hugging Face)
- Transformers quantization docs: MXFP4 kernels behavior. (Hugging Face)
Minimal actionable change for your code
Replace your model kwargs with something like:
from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
import torch
model_id = "openai/gpt-oss-20b"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=Mxfp4Config(dequantize=True),
    use_cache=False,
    # no device_map, no max_memory, no offload_folder
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
…and ensure ZeRO-3 init is truly active during load (Trainer/TRL DS integration, or HfDeepSpeedConfig, or explicit deepspeed.zero.Init). (Hugging Face)
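For the launch itself (train_sft.py is a placeholder for your script): if Trainer/TRL owns DeepSpeed via deepspeed="ds_zero3.json", a plain distributed launcher is enough; if you drive DeepSpeed through Accelerate's plugin instead, pass the ZeRO flags to accelerate launch. For example:

# Trainer/TRL owns DeepSpeed (deepspeed="ds_zero3.json" in the training args):
torchrun --nproc_per_node 8 train_sft.py
# or:
deepspeed --num_gpus 8 train_sft.py

# Accelerate's DeepSpeed plugin (non-Trainer scripts):
accelerate launch --num_processes 8 --use_deepspeed --zero_stage 3 \
  --zero3_init_flag true train_sft.py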
If you apply only one principle: for ZeRO-3 training, do not use device_map. (GitHub)