pytorch学习笔记_3_pytorch中的一些概念
PyTorch中的广播机制(Broadcasting)
1. 广播机制原理
广播(broadcasting)是指当进行张量(Tensor)运算时,如果两个张量的形状(shape)不同,PyTorch 会自动扩展它们,使它们具有相同的形状,从而能够进行逐元素运算(element-wise operation)。这一机制避免了显式复制数据,提高了计算效率,并节省了内存。
2. 广播的条件
- 维度必须要么相等,要么其中一个维度为 1,要么该维度不存在(会被自动补 1)。
- 不能广播的情况:如果在某个维度上,两个张量的形状既不相等也不包含 1,则无法广播。
3. 广播的规则
PyTorch 在进行广播时,会遵循以下规则来匹配张量的形状:
-
从后向前(右对齐)比较两个张量的形状:
- 若两个张量的维度数不同,则在较小的张量前面补 1,使得两个张量具有相同的维度数。
- 从后向前比较两个张量的维度时 规则参照下一条2.
-
维度匹配规则:
- 如果两个维度相等,则可以直接进行运算。
- 如果某个维度的值为 1,则 PyTorch 会在计算时自动扩展该维度,使其匹配另一个张量的维度大小。
- 如果两个维度不相等且都不为 1,则会报错,表示不能广播。
4. 示例
示例 1:可以广播的情况
1 | import torch |
解释:
A
的形状(3,1)
→ 扩展为(3,4)
B
的形状(1,4)
→ 扩展为(3,4)
- 结果
C
的形状为(3,4)
A
的形状(4)
→ 扩展为(2,4)
B
的形状(2,4)
→ 无须扩展- 结果
C
的形状为(2,4)
示例 2:不能广播的情况
1 | A = torch.rand(2,3) |
解释:
A
的形状是(2,3)
B
的形状是(3,2)
- 在第二个维度上,
A
为3
,B
为2
,不满足广播规则,运算失败。
5. 广播的应用
广播机制在深度学习和科学计算中非常重要,主要用于:
- 不同形状张量的运算:在不显式复制数据的情况下进行高效计算。
- 批量计算:对不同尺寸的数据进行同一操作,常用于计算损失、归一化、特征缩放等。
- 神经网络中的参数运算:如批量归一化(Batch Normalization)中,每个通道的均值和方差都是标量,但会广播到整个批次的数据。
实际应用示例
示例 1:归一化
1 | X = torch.rand(5, 3) # 5个样本,每个样本有3个特征 |
示例 2:计算加权和
1 | X = torch.rand(4, 3) # 4个样本,每个样本有3个特征 |
6. 总结
- 广播机制使不同形状的张量可以进行运算,避免显式复制数据,提高计算效率。
- 广播规则:从右向左匹配维度,维度相等或为 1 时可广播,否则会报错。
- 应用场景:用于特征归一化、加权计算、神经网络参数运算等。
广播机制在 PyTorch 中极大提升了代码的简洁性和运行效率,是高效张量计算的核心功能之一。
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 JLChenBlog!
评论