# PyTorch 数据类型转换
## 载入模块并生成数据
import torch
import numpy as np
# Sample 1-D integer array (NumPy's default integer dtype) that the
# NumPy -> Tensor conversion below starts from.
a_numpy = np.array((1, 2, 3))
## NumPy 数组转换为 Tensor
# NumPy -> Tensor. torch.from_numpy does not copy: the resulting tensor
# shares its memory with a_numpy, so in-place edits to one show up in the other.
a_tensor = torch.from_numpy(a_numpy)

print(a_tensor)
## Tensor 转换为 NumPy 数组
# Tensor -> NumPy. For a CPU tensor, .numpy() returns an array that shares
# memory with the tensor rather than a copy.
a_numpy = a_tensor.numpy()

print(a_numpy)
## int、float 转换为 Tensor
# Python number -> Tensor. torch.tensor always copies its input; a bare
# integer yields a 0-dim (scalar) integer tensor, while a float literal
# such as 2.0 would yield a floating-point tensor instead.
c = torch.tensor(2)

print(c)
## Tensor 转换为 int（Python 数值）
# Tensor -> Python number. .item() unwraps a single-element tensor into a
# plain Python scalar; c held an integer tensor, so the result is an int.
c = c.item()

print(c)
## NumPy 数组转换为 Variable
# NumPy -> Variable. `Variable` lives in torch.autograd and must be imported
# explicitly; without this import the next line raises NameError.
# Note: Variable has been deprecated since PyTorch 0.4. Variable(t) now simply
# returns a Tensor, so torch.from_numpy(a_numpy) on its own is the modern
# equivalent.
from torch.autograd import Variable

a_variable = Variable(torch.from_numpy(a_numpy))
print(a_variable)
## Variable 转换为 NumPy 数组
# Variable/Tensor -> NumPy. Use .detach() rather than the legacy .data
# attribute: .detach() is the documented way to take a tensor out of the
# autograd graph. .numpy() refuses tensors that require grad, so detaching
# first keeps this conversion working even for tensors created with
# requires_grad=True.
a_numpy = a_variable.detach().numpy()
print(a_numpy)