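1. Implementation
When preprocessing data with PyTorch, standardization is usually done with torchvision.transforms.Normalize(mean, std), where the arguments mean and std are sequences holding the per-channel mean and standard deviation of the image dataset.
First, the definitions of mean and std, stated mathematically: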
Given a dataset $X_i,\ i\in\{1,2,\cdots,n\}$, its mean is:
$$mean=\frac{\displaystyle\sum_{i=1}^nX_i}{n}\tag{1}$$
The mean of the data is usually written $\overline X$.
The standard deviation of the dataset is:
$$\begin{aligned}
std&=\sqrt{\frac{\displaystyle\sum_{i=1}^n\left(X_i-\overline X\right)^2}{n}}\\[2ex]
&=\sqrt{\frac{\displaystyle\sum_{i=1}^n\left(X_i^2-2X_i\overline X+\overline X^2\right)}{n}}\\[2ex]
&=\sqrt{\frac{\left(\displaystyle\sum_{i=1}^nX_i^2\right)-n\overline X^2}{n}}\\[2ex]
&=\sqrt{\frac{\displaystyle\sum_{i=1}^nX_i^2}{n}-\overline X^2}
\end{aligned}\tag{2}$$
The third line uses $\sum_{i=1}^nX_i=n\overline X$, so the cross term contributes $-2n\overline X^2$ and the constant term contributes $+n\overline X^2$.
The following function computes the per-channel mean and standard deviation of an image dataset:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

batch_size = 64

# Training set (CIFAR-10 as an example)
train_dataset = datasets.CIFAR10(root='G:/datasets/cifar10', train=True, download=False, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)


def get_mean_std_value(loader):
    '''
    Compute the per-channel mean and standard deviation of a dataset.
    :param loader: DataLoader yielding (data, label) batches, with data shaped [batch_size, channels, height, width]
    :return: (mean, std), each a tensor of shape [channels]
    '''
    data_sum, data_squared_sum, num_samples = 0, 0, 0

    for data, _ in loader:
        # data: [batch_size, channels, height, width]
        n = data.size(0)
        # Mean over dim=0,2,3; dim=1 is the channel dimension and is kept.
        # Weight by the batch size so a smaller last batch is not over-counted.
        data_sum += torch.mean(data, dim=[0, 2, 3]) * n  # [channels]
        # Mean of squares over dim=0,2,3, weighted the same way
        data_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3]) * n  # [channels]
        # Count the samples seen so far
        num_samples += n

    # Mean, equation (1)
    mean = data_sum / num_samples
    # Standard deviation, last line of equation (2)
    std = (data_squared_sum / num_samples - mean ** 2) ** 0.5
    return mean, std


mean, std = get_mean_std_value(train_loader)
print('mean = {},std = {}'.format(mean, std))
The mean and standard deviation of the CIFAR-10 dataset are:
mean = tensor([0.4914, 0.4821, 0.4465]),std = tensor([0.2470, 0.2435, 0.2616])
The mean and standard deviation of the MNIST dataset are:
mean = tensor([0.1307]),std = tensor([0.3081])
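The one-pass shortcut in the last line of equation (2) can be checked against the definition in its first line without downloading anything. The sketch below does this on random tensors; `one_pass_mean_std` is a hypothetical helper written for this check, and the data are synthetic. It weights each batch by its number of samples, so the incomplete last batch (4 of 100 images here) does not skew the result.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def one_pass_mean_std(loader):
    """Per-channel mean and std from running sums of X and X^2."""
    total, total_sq, num_samples = 0.0, 0.0, 0
    for data, _ in loader:
        n = data.size(0)
        # Batch mean times batch size keeps every sample equally weighted
        total = total + data.mean(dim=[0, 2, 3]) * n
        total_sq = total_sq + (data ** 2).mean(dim=[0, 2, 3]) * n
        num_samples += n
    mean = total / num_samples
    std = (total_sq / num_samples - mean ** 2).sqrt()
    return mean, std


torch.manual_seed(0)
images = torch.rand(100, 3, 8, 8)             # 100 synthetic 3-channel images
labels = torch.zeros(100, dtype=torch.long)   # dummy labels
loader = DataLoader(TensorDataset(images, labels), batch_size=32)

mean, std = one_pass_mean_std(loader)

# Reference: the definition, computed on the whole tensor at once
ref_mean = images.mean(dim=[0, 2, 3])
ref_std = ((images - ref_mean.view(1, -1, 1, 1)) ** 2).mean(dim=[0, 2, 3]).sqrt()

print(torch.allclose(mean, ref_mean, atol=1e-5))
print(torch.allclose(std, ref_std, atol=1e-5))
```

Both quantities are population statistics (division by n, not n - 1), matching equations (1) and (2).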