class TensorsDataset(torch.utils.data.Dataset):
    """Tensor-backed dataset with optional per-sample transforms.

    Works like ``torch.utils.data.TensorDataset`` for one data tensor plus an
    optional target tensor. In addition, it applies a chain of callables to
    each indexed data sample (``transforms``) and to each indexed target
    (``target_transforms``).

    If ``target_tensor`` is ``None``, ``__getitem__`` returns only the
    transformed data sample. Otherwise it returns a ``(data, target)`` pair.
    """

    def __init__(self, data_tensor, target_tensor=None, transforms=None,
                 target_transforms=None):
        """Store the tensors and normalise the transform arguments.

        Args:
            data_tensor: Tensor indexed along its first dimension to produce
                samples; its ``size(0)`` is the dataset length.
            target_tensor: Optional tensor whose ``size(0)`` must equal that
                of ``data_tensor``; ``None`` means the dataset has no targets.
            transforms: ``None``, a single callable, or a list/tuple of
                callables applied in order to each data sample.
            target_transforms: Same as ``transforms``, but applied to each
                target sample.

        Raises:
            ValueError: If ``target_tensor`` is given and its first dimension
                differs from that of ``data_tensor``.
        """
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with ``-O``.
        if target_tensor is not None and data_tensor.size(0) != target_tensor.size(0):
            raise ValueError(
                "data_tensor and target_tensor must have the same size along "
                "dimension 0, got {} and {}".format(
                    data_tensor.size(0), target_tensor.size(0)
                )
            )
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
        self.transforms = self._as_transform_list(transforms)
        self.target_transforms = self._as_transform_list(target_transforms)

    @staticmethod
    def _as_transform_list(transforms):
        """Normalise ``None`` / one callable / a list or tuple into a new list."""
        if transforms is None:
            return []
        if isinstance(transforms, (list, tuple)):
            # Copy so that later mutation of the caller's sequence does not
            # silently change this dataset's pipeline.
            return list(transforms)
        return [transforms]

    @staticmethod
    def _apply_transforms(value, transforms):
        """Feed ``value`` through each callable in ``transforms`` in order."""
        for transform in transforms:
            value = transform(value)
        return value

    def __getitem__(self, index):
        """Return the transformed sample (and target, if any) at ``index``."""
        data = self._apply_transforms(self.data_tensor[index], self.transforms)
        if self.target_tensor is None:
            return data
        target = self._apply_transforms(
            self.target_tensor[index], self.target_transforms
        )
        return data, target

    def __len__(self):
        """Return the number of samples (first dimension of ``data_tensor``)."""
        return self.data_tensor.size(0)
重新定义 PyTorch 中的 TensorDataset，使其支持 transforms
原文作者:pytorch
原文地址: https://www.cnblogs.com/marsggbo/p/10459235.html
本文转自网络文章，转载此文章仅为分享知识，如有侵权，请联系博主进行删除。
原文地址: https://www.cnblogs.com/marsggbo/p/10459235.html
本文转自网络文章，转载此文章仅为分享知识，如有侵权，请联系博主进行删除。