squeeze 函数
官方文档
去除size为1的维度,包括行和列。至于维度大于等于2时,squeeze()不起作用。
cat 函数
scatter_ 函数
Pytorch官网对于scatter_方法的定义
scatter_(dim, index, src) → Tensor
参数的意义如下:
dim(int)
: dim=0代表按第0维填充,dim=1代表按第1维填充,dim=2代表按照第2维填充。index(LongTensor)
: 代表需要填充的位置信息。src(Tensor)/value(Float)
: 代表需要的填充的内容。
Note: src和value(Float)可以互换,如果存在src就不能存在value,反之亦然。
这里做一个假设:对于一个三维的矩阵,他使用scatter_的方法得到的结果M应该是如下三种情况:
- dim=0, M[index[a][b][c]] [b] [c] = src[a][b][c]
- dim=1, M[a] [index[a][b][c]] [c] = src[a][b][c]
- dim=2, M[a] [b] [index[a][b][c]] = src[a][b][c]
因此,通过对应的关系可以看出index中数值大小其实也是有一定的要求的。
- dim=0, index[a][b][c] < src.shape[0], index.shape[1] == src.shape[1] & index.shape[2] == src.shape[2]
- dim=1, index[a][b][c] < src.shape[1], index.shape[0] == src.shape[0] & index.shape[2] == src.shape[2]
- dim=2, index[a][b][c] < src.shape[2], index.shape[0] == src.shape[0] & index.shape[1] == src.shape[1]
下面是来自Pytorch一些代码展示: