Pytorch 一些tensor常用操作方法总结

1、tensor自身属性判断

torch.is_tensor(obj):若obj为Tensor类型,那么返回True。

torch.numel(obj):返回Tensor对象中的元素总数。

obj.size():返回Tensor对象的维度。

2、tensor与numpy array 之间相互转换

torch.from_numpy(obj):利用一个numpy的array创建Tensor。注意,若obj原来是1列或者1行,无论obj是否为2维,所生成的Tensor都是一阶的,若需要2阶的Tensor,需要利用view()函数进行转换。

torch.numpy(obj):利用一个tensor创建numpy narray。

3、生成一些特定的tensor

torch.eye(n):返回一个单位方阵,和MATLAB的eye()非常像。还有其他参数。

torch.linspace(start, end, steps),返回一个1维的Tensor。

torch.ones(),与MATLAB的ones很接近。

torch.ones_like(input),返回一个全1的Tensor,其维度与input相一致。

torch.arange(start, end, step),直接返回一个Tensor而不是一个迭代器。

torch.zeros(),与MATLAB的zeros很像。

torch.zeros_like(),与torch.ones_like()类似。

3、由tensor拼接、拆分、维度变换、部分数据提取

这个用于构建模型时,多个layers之间进行数据交互时,维度匹配使用。有时为了提高效率,并行计算时也会涉及到拼接和拆分;

3.1、tensor拼接

torch.cat(seq, dim),将tuple seq中描述的Tensor进行连接,通过实例说明用法。dim指拼接的维度。

3.2、tensor拆分

torch.chunk(input, chunks, dim),与torch.cat()的作用相反。注意,返回值的数量会随chunks的值而发生变化。chunks指拆分后tuple个数,dim指拆分的tensor维度。

3.3、tensor维度变换

obj.view(dim0.dim1.dim2,…),将tensor obj 按照之前的tensor顺序重新按照(dim0.dim1.dim2,…)定义的维度生成新的tensor。

3.4、tensor部分数据提取

torch.index_select(input, dim, index),注意,index是一个1D的Tensor。

torch.masked_select(input, mask),有点像MATLAB中利用bool类型矩阵进行索引的功能,要求mask是ByteTensor类型的Tensor。参考示例代码。注意,执行结果是一个1D的Tensor。

4、存在维度为1的tensor实施维度扩大与压缩

一般这个用于batch中,若batch size = 1,这个操作非常简单;

4.1 存在维度为1的tensor实施维度压缩

torch.squeeze(input),将input中维度数值为1的维度去除。可以指定某一维度。结果是共享input的内存的。

torch.squeeze(input, dim),将input中维度dim(若其数值为1)的维度去除。可以指定某一维度。结果是共享input的内存的。

4.1 存在维度为1的tensor实施维度扩大

torch.unsqeeze(input, dim),在input目前的dim维度上增加一维。

5、tensor 转置

torch.t(input),将input进行转置,不是in place。输出的结果是共享内存的。要求input为2D。

6、tensor 保存和读取

torch.save()和torch.load()

7、tensor获取最大值与最小值

min(input) -> Tensor ,返回全部元素的最小值

min(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor) 返回依据dim维度求最小值,得到两个结果,第一个是最小值,第二个是对应的索引值;

min(input, other, out=None) -> Tensor 返回同维度的input,和other中,对应元素的最小值

torch.max()用法同上

8、数据并行计算,见

https://pytorch.apachecn.org/#/docs/6pytorch.apachecn.org

9、dataSet抽象类的使用

pytorch读取训练集需要使用到2个类:
(1)torch.utils.data.Dataset
(2)torch.utils.data.DataLoader

class RandomDataset(Dataset):

def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)

def __getitem__(self, index):
return self.data[index]

def __len__(self):
return self.len

rand_loader = DataLoader(dataset=RandomDataset(input_size, 100),
batch_size=batch_size, shuffle=True)

    原文作者:知客
    原文地址: https://zhuanlan.zhihu.com/p/53229690
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞