Why DeepSeek-R1-Distill-Qwen-1.5B Uses Far More GPU Memory Than Its Downloaded Weights

Phenomenon

The DeepSeek-R1-Distill-Qwen-1.5B weights downloaded from Hugging Face are only 3.5 GB, yet once the model is loaded onto the GPU it occupies 5.4 GB of memory, 1.9 GB more than the weights.

Investigation

The extra 1.9 GB most likely comes from memory the model allocates at construction time beyond the weights themselves; that part is simply not reflected in the size of the weight files.

Inspecting GPU memory with PyTorch

The PyTorch blog post Understanding GPU Memory 1: Visualizing All Allocations over Time describes PyTorch's built-in memory profiling utilities, all defined in torch.cuda.memory. See the official blog for the full code; the resulting memory snapshot is shown below.
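As a minimal sketch of that workflow (the file name and the prompt step are placeholders; the _record_memory_history / _dump_snapshot calls follow the blog post above):

import torch
from transformers import AutoModelForCausalLM

# Start recording CUDA allocation history before the model is created.
torch.cuda.memory._record_memory_history(max_entries=100000)

model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
# ... run a short prompt through the model here ...

# Dump the snapshot to a pickle file and stop recording; the file can then be
# inspected at https://pytorch.org/memory_viz.
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)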

(Figure: pytorch memory-viz result)

From program start to finish the GPU memory stays at roughly 5.4 GB. The thin band at the top is activation memory, which is tiny here because the input is very short. This tells us that the memory beyond the weights is allocated as soon as the model is created, and the application layer cannot do anything about it afterwards.

Inspecting GPU memory with transformers

Every model loaded through a transformers AutoModel class provides a get_memory_footprint method that reports the model's memory usage. Its source code is as follows:

def get_memory_footprint(self, return_buffers=True):
    r"""
    Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
    Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
    PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

    Arguments:
        return_buffers (`bool`, *optional*, defaults to `True`):
            Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
            are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
            norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
    """
    mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
    if return_buffers:
        mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
        mem = mem + mem_bufs
    return mem

As the code shows, the memory footprint consists of two parts: parameters and buffers. Using this method on DeepSeek-R1-Distill-Qwen-1.5B:

import torch
from transformers import AutoModelForCausalLM

model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16)

footprint = model.get_memory_footprint() / 10**9
param_size = model.get_memory_footprint(False) / 10**9
print(f"footprint: {footprint:.3f} GB, param_size: {param_size:.3f} GB")

The output:

footprint: 5.433 GB, param_size: 3.554 GB

The parameter size matches the downloaded weight size, which confirms that param_size measures only the weights, while the full footprint includes both parameters and buffers; that is where the extra 1.9 GB comes from.

So what is the difference between a buffer and a parameter? According to the PyTorch discussion, the difference is small; the main points are:

  1. A parameter is a mutable, trainable weight of the model, while a buffer is an immutable, static, non-trainable tensor. Parameters always appear in model.state_dict(); buffers appear there only when registered as persistent. Because a buffer never needs to be learned by backpropagation and is recomputed identically every run (position encodings, for example), it is often registered with persistent=False and therefore does not show up in the saved checkpoint, which is exactly why these tensors are missing from the downloaded weights.
  2. Declaring a tensor as a buffer makes it explicit that the tensor takes no part in backpropagation, which improves readability. The toy module below sketches the difference.
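For illustration, here is a minimal sketch (the ToyRoPE module and its names are invented for this example, not part of transformers) showing how parameters and buffers differ in what they report:

import torch
import torch.nn as nn

class ToyRoPE(nn.Module):
    def __init__(self, dim=4, max_len=8):
        super().__init__()
        # Trainable weight: listed by parameters() and saved in state_dict().
        self.proj = nn.Linear(dim, dim)
        # Static cache: listed by buffers(); persistent=False keeps it out of
        # state_dict(), just like the RoPE caches in Qwen2.
        cache = torch.arange(max_len).float().unsqueeze(-1).repeat(1, dim)
        self.register_buffer("pos_cache", cache, persistent=False)

m = ToyRoPE()
print([name for name, _ in m.named_parameters()])  # ['proj.weight', 'proj.bias']
print([name for name, _ in m.named_buffers()])     # ['pos_cache']
print(list(m.state_dict().keys()))                 # ['proj.weight', 'proj.bias']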

Inspecting the model's buffer sizes

PyTorch gives every nn.Module a named_buffers method that returns an iterator over all buffers defined in the model. Here we use it to list the tensors that are not part of the weights:

named_buffers = list(model.named_buffers())
name_to_shape = [(name, tensor.shape) for name, tensor in named_buffers]
for name, shape in name_to_shape[:6]:
    print(f"{name}: {shape}")

The output:

model.layers.0.self_attn.rotary_emb.inv_freq: torch.Size([64])
model.layers.0.self_attn.rotary_emb.cos_cached: torch.Size([131072, 128])
model.layers.0.self_attn.rotary_emb.sin_cached: torch.Size([131072, 128])
model.layers.1.self_attn.rotary_emb.inv_freq: torch.Size([64])
model.layers.1.self_attn.rotary_emb.cos_cached: torch.Size([131072, 128])
model.layers.1.self_attn.rotary_emb.sin_cached: torch.Size([131072, 128])

All of the buffers are created by the rotary_embedding module; they hold the model's positional encoding information. Unlike parameters, they are never updated by backpropagation, which is why they are defined as buffers. The 131072 is the maximum context length Qwen supports. Since the position encoding is static, the rotary position encodings for the entire maximum context are precomputed up front so they do not have to be recomputed at every inference step. The downside is that with such a long context the rotary caches alone take 131072 * 128 * 2 * 28 * 2 / 10**9 = 1.88 GB (sequence length * head dimension * cos and sin * 28 layers * 2 bytes for bf16), and the caches are identical in every layer, so a lot of memory is wasted.
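We can double-check that estimate by summing the rotary buffers of the loaded model directly:

# Sum the bytes taken by the rotary-embedding buffers and compare with the 1.88 GB estimate above.
rope_bytes = sum(
    buf.nelement() * buf.element_size()
    for name, buf in model.named_buffers()
    if "rotary_emb" in name
)
print(f"rotary buffers: {rope_bytes / 10**9:.2f} GB")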

How to fix it?

Manually set the maximum context length

Setting the context length manually when loading the model shrinks the buffers and therefore the memory footprint:

model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    max_position_embeddings=4096)

footprint = model.get_memory_footprint() / 10**9
param_size = model.get_memory_footprint(False) / 10**9
print(f"footprint: {footprint:.3f} GB, param_size: {param_size:.3f} GB")

named_buffers = list(model.named_buffers())
name_to_shape = [(name, tensor.shape) for name, tensor in named_buffers]
for name, shape in name_to_shape[:6]:
    print(f"{name}: {shape}")

The output:

footprint: 3.613 GB, param_size: 3.554 GB
model.layers.0.self_attn.rotary_emb.inv_freq: torch.Size([64])
model.layers.0.self_attn.rotary_emb.cos_cached: torch.Size([4096, 128])
model.layers.0.self_attn.rotary_emb.sin_cached: torch.Size([4096, 128])
model.layers.1.self_attn.rotary_emb.inv_freq: torch.Size([64])
model.layers.1.self_attn.rotary_emb.cos_cached: torch.Size([4096, 128])
model.layers.1.self_attn.rotary_emb.sin_cached: torch.Size([4096, 128])

Setting the maximum context length via the max_position_embeddings argument reduces the memory footprint from 5.433 GB to 3.613 GB.

transformers 4.45.1 no longer caches the rotary position encodings and computes them on the fly

In transformers commit 65bb284, all decoder models were refactored so that torch.compile can compile them correctly. As a result, the cos and sin values of rotary_embedding are no longer cached but recomputed in every forward pass. The implementation before and after the change:

Comparison of the Qwen2RotaryEmbedding implementation in transformers 4.44 and 4.45

transformers 4.44:
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
transformers 4.45:

# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[Qwen2Config] = None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

Since transformers 4.45.1, the model no longer caches sin_cached and cos_cached; they are recomputed in every forward pass.
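A quick way to confirm this in your own environment (a sketch assuming transformers >= 4.45.1 is installed) is to reload the model and list the remaining rotary buffers; only inv_freq should be left, so the footprint should land close to the parameter size:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

# Expect only inv_freq buffers here; cos_cached / sin_cached are gone.
print([name for name, _ in model.named_buffers() if "rotary" in name])
print(f"footprint: {model.get_memory_footprint() / 10**9:.3f} GB")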

