PyTorch学习笔记——repeat()和expand()区别

torch.Tensor是包含一种数据类型元素的多维矩阵。

A
torch.Tensor is a multi-dimensional matrix containing elements of a single data type.

torch.Tensor有两个实例方法可以用来扩展某维的数据的尺寸,分别是repeat()expand()

expand()

expand(*sizes) -> Tensor

*sizes(torch.Size or int) – the desired expanded
size

Returns a new view of the self tensor with singleton dimensions expanded to a larger size.

返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小(size)等于1的维度扩展到更大的尺寸。

例子:

import torch

>> x = torch.tensor([1, 2, 3])
>> x.expand(2, 3)
tensor([[1, 2, 3],
        [1, 2, 3]])

>> x = torch.randn(2, 1, 1, 4)
>> x.expand(-1, 2, 3, -1)
torch.Size([2, 2, 3, 4])

repeat()

repeat(*sizes) -> Tensor

*size(torch.Size or int) – The
number of times to repeat this tensor along each dimension.

Repeats this tensor along the specified dimensions.

沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据。

例子:

import torch

>> x = torch.tensor([1, 2, 3])
>> x.repeat(3, 2)
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])

>> x2 = torch.randn(2, 3, 4)
>> x2.repeat(2, 1, 3).shape
torch.Tensor([4, 3, 12])

    原文作者:德川Captain
    原文地址: https://zhuanlan.zhihu.com/p/58109107
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞