关于 DeepSeek-R1-Distill-Qwen-1.5B 显存占用比下载权重高非常多的现象分析
现象
从 huggingface 上下载下来的 DeepSeek-R1-Distill-Qwen-1.5B 模型权重大小只有 3.5GB,但是加载到 GPU 上之后,占用显存达到 5.4GB,多了 1.9GB。
排查
显存占用比权重多 1.9GB 的原因应该是模型在创建的时候申请了权重以外的内存,这部分内存没有体现在权重的大小中。
使用 pytorch 查看显存占用
pytorch 博客 Understanding GPU Memory 1: Visualizing All Allocations over Time 中给出了 pytorch 内部的显存统计方法,函数都定义在 torch.cuda.memory
中。具体代码可以参考官方博客,下面展示显存统计结果。
可以看到从程序开始到结束,显存一直都是 5.4GB 左右,最上方有一小部分激活占用的显存,但是由于我们输入非常短,所以占用很小。从这个结果我们可以想象到,模型总显存占用比权重多的那部分显存在模型一开始就创建了,后续的应用层是无法操作的。
使用 transformers 查看显存占用
transformers 的 AutoModel
都提供了一个 get_memory_footprint
方法,可以查看模型的显存占用。其函数源码如下:
1 |
|
可以看到,模型的显存占用包括两个部分:parameter 和 buffer。使用这个方法查看 DeepSeek-R1-Distill-Qwen-1.5B 的显存占用,代码如下:
1 |
|
打印结果如下:
1 |
|
parameter size 和下载的模型权重大小相同,这表明 parameter size 展示的是模型的权重,而模型整体的 footprint 包括 parameter 和 buffer,所以显存占用比权重多 1.9GB。
buffer 和 parameter 的区别是?根据 PyTorch 讨论,二者的区别非常小,主要区别如下:
- parameter 是模型的可变/动态/待训练权重,而 buffer 是模型的不可变/静态/不可训练权重。
model.state_dict()
保存了模型当前的状态,所以这个函数会将 parameter 保存下来,而 buffer 则不会,因为 buffer 是不需要通过反向传播训练的,每次训练时它都是完全相同的,比如位置编码。 - 将一个张量声明为模型的 buffer,能够明确地表明这个张量是不需要反向传播的,可读性更高。
查看模型的 buffer 大小
pytorch 对 nn.Module 都提供了一个 named_buffers
方法,可以以迭代器返回模型中所有定义的 buffer,在这里我们使用它来查看模型中所有与权重无关的张量。代码如下:
1 |
|
打印结果如下:
1 |
|
所有 buffer 都是被 rotary_embedding 模块创建的,它们表达了模型中的位置编码信息。与 parameter 不同,它们不需要被反向传播更新,因此被定义为 buffer。其中,131072 是 qwen 支持的最大上下文长度。由于位置编码是静态的,所以这里提前计算好最大上下文的旋转位置编码,避免在推理时每次都计算。但是这样导致了在上下文非常长时,光是旋转位置编码的内存占用就达到了 131072 * 128 * 2 * 28 * 2 / 10**9 = 1.88GB,而且每一层的旋转位置编码还是完全相同的,浪费了许多显存。
如何解决?
手动设置最大上下文
在模型加载时手动设置上下文长度,可以减少 buffer 长度,从而减少显存占用。代码如下:
1 |
|
打印结果如下:
1 |
|
通过 max_position_embeddings
参数设置最大上下文长度,显存占用从 5.433GB 减小到 3.613GB。
transformers 在 4.45.1 版本不再缓存旋转位置编码,而是手动计算
在 transformers commit 65bb284 中,为了确保 torch.compile
能够正确编译模型,对所有 Decoder 模型进行了一次重构。这导致 rotary_embedding 的 cos 和 sin 值不再被缓存,而是每次都手动计算。修改前后的代码如下:
transformers 4.44 与 4.45.1 的 QwenRotaryEmbedding 实现对比
transformers - 4.44 | transformers - 4.45 |
---|---|
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2 class Qwen2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype), ) |
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 class Qwen2RotaryEmbedding(nn.Module): def __init__( self, dim=None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, rope_type="default", config: Optional[Qwen2Config] = None, ): super().__init__() # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} if config is None: logger.warning_once( "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the " "`config` argument. All other arguments will be removed in v4.46" ) self.rope_kwargs = { "rope_type": rope_type, "factor": scaling_factor, "dim": dim, "base": base, "max_position_embeddings": max_position_embeddings, } self.rope_type = rope_type self.max_seq_len_cached = max_position_embeddings self.original_max_seq_len = max_position_embeddings else: # BC: "rope_type" was originally "type" if config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: 1 - growing beyond the cached sequence length (allow scaling) 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len, **self.rope_kwargs ) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() def forward(self, x, position_ids): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention cos = cos * self.attention_scaling sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
在 transformers-4.45.1 之后,模型不再缓存 sin_cached
和 cos_cached
,而是每次都重新计算。