背景:Transformer 里面多模态的输入 embedding

让我们先看一段代码:

class Qwen3VLVisionPatchEmbed(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
        return hidden_states

这个是一般 VL 代码里面,Vision encoder(例如 clip/siglip 等) 部分输入图像(一般是 pixcel)的处理。经过这个处理后的模型,会被当做文本里面的 tokens 的 embedding,后面就接着是 transformer 里面的 attention 啊 mlp 啥的。看起来很稀松平常。

我们先来看看这个代码里面的实际大小:假设我有10 张图片,每张图片是 512512,in-channel 是标准的 3,patch size 是 1616,时空 patch size=2,那么原始输入是 10512512/2/16/16=5120(相当于 batch size),hidden size=2316*16=1536

速度瓶颈分析

好的,实际在运行过程中,我们发现这个部分极其极其慢,一个简单的映射层的时间竟然会超过 100s!下面我们就对这个问题进行深入分析。先说简单结论:

TL,DR:

<aside> 💡

  1. 在现在大部分的 VLM 中,卷积完全没有起到应有的作用,layout完全等价于一个线性的 linear 操作,因此转化成 linear 肯定是效率最高的。
  2. bias 的引入造成了在我们这个任务中额外的计算复杂度,特别是在conv3d 的 tensor 尺寸是5D 张量的索引,这种不连续的 layout 会更加的让速度变慢,也是索引开销带来的。退化的算子的复杂度会极其恐怖。
  3. fp32 的实验告诉我们,fp16 、bf16下这个计算慢的最重要的原因就是因为内存中的 layout 排布和转置排布等,造成了巨量的计算浪费。bias 在这个里面的影响变小了,说明还是因为索引排布的问题是占大头。
  4. 将 conv 的操作中的 memory format 转化成 channel last 的格式,是最简单直接的。在大部分场景中应该都是最佳选择。
  5. 最终无论如何,conv3d 都是要慢于 linear,说明这种边界判断等标准的 conv操作在这个里面还是造成了了显著的耗时问题。 </aside>

第一层分析:常见的慢的操作

hidden_states = hidden_states.**view**(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )

hidden_states.**to**(dtype=target_dtype)

这个 to,view 操作就老生常谈了,肯定是会出现让整体变慢的情况,我们都是知道的。

具体而言:(gemini 的分析)