pytorch 基础笔记

squeeze 函数

官方文档 去除size为1的维度,包括行和列。至于维度大于等于2时,squeeze()不起作用。 ## cat 函数

scatter_ 函数

Pytorch官网对于scatter_方法的定义

scatter_(dim, index, src) → Tensor

参数的意义如下:

  • dim(int): dim=0代表按第0维填充,dim=1代表按第1维填充,dim=2代表按照第2维填充。
  • index(LongTensor): 代表需要填充的位置信息。
  • src(Tensor)/value(Float): 代表需要的填充的内容。

Note: src和value(Float)可以互换,如果存在src就不能存在value,反之亦然。

这里做一个假设:对于一个三维的矩阵,他使用scatter_的方法得到的结果M应该是如下三种情况:

  1. dim=0, M[index[a][b][c]] [b] [c] = src[a][b][c]
  2. dim=1, M[a] [index[a][b][c]] [c] = src[a][b][c]
  3. dim=2, M[a] [b] [index[a][b][c]] = src[a][b][c]

因此,通过对应的关系可以看出index中数值大小其实也是有一定的要求的。

  1. dim=0, index[a][b][c] < src.shape[0], index.shape[1] == src.shape[1] & index.shape[2] == src.shape[2]
  2. dim=1, index[a][b][c] < src.shape[1], index.shape[0] == src.shape[0] & index.shape[2] == src.shape[2]
  3. dim=2, index[a][b][c] < src.shape[2], index.shape[0] == src.shape[0] & index.shape[1] == src.shape[1]

下面是来自Pytorch一些代码展示:

_scatter 使用示例
_scatter 使用示例
感谢搬砖