让我们先看一段代码:
class Qwen3VLVisionPatchEmbed(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.patch_size = config.patch_size
self.temporal_patch_size = config.temporal_patch_size
self.in_channels = config.in_channels
self.embed_dim = config.hidden_size
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
target_dtype = self.proj.weight.dtype
hidden_states = hidden_states.view(
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
)
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
return hidden_states
这个是一般 VL 代码里面,Vision encoder(例如 clip/siglip 等) 部分输入图像(一般是 pixcel)的处理。经过这个处理后的模型,会被当做文本里面的 tokens 的 embedding,后面就接着是 transformer 里面的 attention 啊 mlp 啥的。看起来很稀松平常。
我们先来看看这个代码里面的实际大小:假设我有10 张图片,每张图片是 512512,in-channel 是标准的 3,patch size 是 1616,时空 patch size=2,那么原始输入是 10512512/2/16/16=5120(相当于 batch size),hidden size=2316*16=1536
好的,实际在运行过程中,我们发现这个部分极其极其慢,一个简单的映射层的时间竟然会超过 100s!下面我们就对这个问题进行深入分析。先说简单结论:
TL,DR:
<aside> 💡
hidden_states = hidden_states.**view**(
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
)
hidden_states.**to**(dtype=target_dtype)
这个 to,view 操作就老生常谈了,肯定是会出现让整体变慢的情况,我们都是知道的。
具体而言:(gemini 的分析)
view 操作:如果张量在内存中是连续的,view 操作本身是高效的,几乎不会消耗时间。问题可能出现在张量是非连续时,导致额外的内存拷贝,但这种情况通常不会频繁发生。to 操作:如果仅仅是 dtype 相同,to 操作几乎没有开销。如果涉及到类型转换,特别是从浮点数类型到低精度类型(比如 float32 转到 float16),会有额外的计算开销,尤其是如果数据量很大时。