torch.Tensor是包含一种数据类型元素的多维矩阵。
A
torch.Tensor is a multi-dimensional matrix containing elements of a single data type.
torch.Tensor有两个实例方法可以用来扩展某维的数据的尺寸,分别是repeat()和expand():
expand()
expand(*sizes) -> Tensor
*sizes(torch.Size or int) – the desired expanded
sizeReturns a new view of the self tensor with singleton dimensions expanded to a larger size.
返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小(size)等于1的维度扩展到更大的尺寸。
例子:
import torch
>> x = torch.tensor([1, 2, 3])
>> x.expand(2, 3)
tensor([[1, 2, 3],
[1, 2, 3]])
>> x = torch.randn(2, 1, 1, 4)
>> x.expand(-1, 2, 3, -1)
torch.Size([2, 2, 3, 4])
repeat()
repeat(*sizes) -> Tensor
*size(torch.Size or int) – The
number of times to repeat this tensor along each dimension.Repeats this tensor along the specified dimensions.
沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据。
例子:
import torch
>> x = torch.tensor([1, 2, 3])
>> x.repeat(3, 2)
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
>> x2 = torch.randn(2, 3, 4)
>> x2.repeat(2, 1, 3).shape
torch.Tensor([4, 3, 12])