tensorflow – How to get the gradient of the loss function twice in PyTorch

Here is what I am trying to implement:

As usual, we compute the loss based on F(X). But we also define an "adversarial loss", which is a loss based on F(X + e). Here e is defined as dF(X)/dX multiplied by some constant. Both the loss and the adversarial loss are backpropagated as part of the total loss.

In TensorFlow, this part (getting dF(X)/dX) can be coded as follows:

  grad, = tf.gradients( loss, X )
  grad = tf.stop_gradient(grad)
  e = constant * grad
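
(For reference, a rough PyTorch sketch of the same pattern. This is illustrative only: the toy linear model, the input X, and the constant 0.5 are placeholders, not part of my actual model.)

import torch

# Toy stand-ins for F, X and the constant, just to make the pattern concrete.
X = torch.randn(4, 3, requires_grad=True)
model = torch.nn.Linear(3, 2)
constant = 0.5

loss = model(X).pow(2).sum()                             # loss based on F(X)
grad, = torch.autograd.grad(loss, X, retain_graph=True)  # dF(X)/dX
grad = grad.detach()                                     # analogous to tf.stop_gradient
e = constant * grad
adv_loss = model(X + e).pow(2).sum()                     # adversarial loss based on F(X + e)
(loss + adv_loss).backward()                             # backpropagate the total loss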

Below is my PyTorch code:

class DocReaderModel(object):
    def __init__(self, embedding=None, state_dict=None):
        self.train_loss = AverageMeter()
        self.embedding = embedding
        self.network = DNetwork(opt, embedding)
        self.optimizer = optim.SGD(parameters)

    def adversarial_loss(self, batch, loss, embedding, y):
        self.optimizer.zero_grad()
        loss.backward(retain_graph=True)
        grad = embedding.grad
        grad.detach_()

        perturb = F.normalize(grad, p=2)* 0.5
        self.optimizer.zero_grad()
        adv_embedding = embedding + perturb
        network_temp = DNetwork(self.opt, adv_embedding) # This is how to get F(X)
        network_temp.training = False
        network_temp.cuda()
        start, end, _ = network_temp(batch) # This is how to get F(X)
        del network_temp # I even deleted this instance.
        return F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])

    def update(self, batch):
        self.network.train()
        start, end, pred = self.network(batch)
        loss = F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])
        loss_adv = self.adversarial_loss(batch, loss, self.network.lexicon_encoder.embedding.weight, y) 
        loss_total = loss + loss_adv 

        self.optimizer.zero_grad()
        loss_total.backward()
        self.optimizer.step()

I have a couple of questions:

1) I replaced tf.stop_gradient with grad.detach_(). Is this correct?

2) I was getting "RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time." So I added retain_graph=True to loss.backward(). That particular error went away, but now I get an out-of-memory error after a few epochs (RuntimeError: cuda runtime error (2): out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu:58). I suspect I am retaining the graph unnecessarily.

Can someone let me know PyTorch's best practice for this? Any hints, or even short comments, would be highly appreciated.

Best answer: I think you are trying to implement a generative adversarial network (GAN), but from the code I cannot follow what you are trying to achieve, because the GAN is missing some pieces it needs in order to work. I can see there is a discriminator network module, DNetwork, but the generator network module is missing.

If I had to guess, when you say "loss function twice", I assume you mean you have one loss function for the discriminator net and another for the generator net. If that is the case, let me share how I would implement a basic GAN model.

As an example, let's take a look at this Wasserstein GAN Jupyter notebook.

I will skip the less important parts and zoom in on the important ones:

> First, import the PyTorch libraries and set things up

# Set up batch size, image size, and size of noise vector:
bs, sz, nz = 64, 64, 100 # nz is the size of the latent z vector for creating some random noise later

> Build the discriminator module

class DCGAN_D(nn.Module):
    def __init__(self):
        ... truncated, the usual neural nets stuffs, layers, etc ...
    def forward(self, input):
        ... truncated, the usual neural nets stuffs, layers, etc ...

> Build the generator module

class DCGAN_G(nn.Module):
    def __init__(self):
        ... truncated, the usual neural nets stuffs, layers, etc ...
    def forward(self, input):
        ... truncated, the usual neural nets stuffs, layers, etc ...

> Put them together

netG = DCGAN_G().cuda()
netD = DCGAN_D().cuda()

> The optimizers need to be told which variables to optimize. A module automatically keeps track of its variables.

optimizerD = optim.RMSprop(netD.parameters(), lr = 1e-4)
optimizerG = optim.RMSprop(netG.parameters(), lr = 1e-4)

> One forward step and one backward step for the Discriminator

Here, the network can compute the gradient during the backward pass, depending on the input to this function. So, in my case, I have 3 kinds of losses: the generator loss, the discriminator loss on real images, and the discriminator loss on fake images. I can get the gradient of the loss function three times, once for each of the 3 different net passes.

def step_D(input, init_grad):
    # input can be from generator's generated image data or input image from dataset
    err = netD(input)
    err.backward(init_grad) # backward pass net to calculate gradient
    return err # loss
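
Note that in PyTorch, successive .backward() calls accumulate gradients into .grad, which is why the real-image and fake-image losses can each run their own backward pass before a single optimizerD.step(). A minimal sketch of that behavior (my own toy example, not from the notebook):

import torch

w = torch.zeros(1, requires_grad=True)

loss_real = (w - 1.0).pow(2).sum()
loss_fake = (w + 1.0).pow(2).sum()

loss_real.backward()   # w.grad is now d(loss_real)/dw, i.e. -2
loss_fake.backward()   # gradients accumulate: w.grad += d(loss_fake)/dw, i.e. +2

print(w.grad)          # tensor([0.]) -- the two contributions have been summed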

> Control the trainable parameters [IMPORTANT]

The trainable parameters in the model are those that require gradients.

def make_trainable(net, val):
    for p in net.parameters():
        p.requires_grad = val # note, i.e, this is later set to False below in netG update in the train loop.

In TensorFlow, this part can be coded as follows:

grad = tf.gradients(loss,X)
grad = tf.stop_gradient(grad)

So, I think this answers your first question: "I replaced tf.stop_gradient with grad.detach_(). Is this correct?"
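
As a quick self-contained illustration (a sketch of my own, not from the notebook): detaching a tensor makes autograd treat it as a constant, which is exactly the role tf.stop_gradient plays above.

import torch

x = torch.ones(3, requires_grad=True)
y = (x ** 2).sum()

g, = torch.autograd.grad(y, x, create_graph=True)  # dy/dx, still attached to the graph

e_tracked = 0.5 * g            # gradients would flow back through g
e_const = 0.5 * g.detach()     # treated as a constant, like tf.stop_gradient

print(e_tracked.requires_grad)  # True
print(e_const.requires_grad)    # False -- nothing is backpropagated through it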
> The training loop

Here you can see how the 3 different loss functions are being called:

    def train(niter, first=True):

        for epoch in range(niter):
            # Make iterable from PyTorch DataLoader
            data_iter = iter(dataloader)
            i = 0

            while i < n:
                ###########################
                # (1) Update D network
                ###########################
                make_trainable(netD, True)

                # train the discriminator d_iters times
                d_iters = 100

                j = 0

                while j < d_iters and i < n:
                    j += 1
                    i += 1

                    # clamp parameters to a cube
                    for p in netD.parameters():
                        p.data.clamp_(-0.01, 0.01)

                    data = next(data_iter)

                    ##### train with real #####
                    real_cpu, _ = data
                    real_cpu = real_cpu.cuda()
                    real = Variable( data[0].cuda() )
                    netD.zero_grad()

                    # Real image discriminator loss
                    errD_real = step_D(real, one)

                    ##### train with fake #####
                    fake = netG(create_noise(real.size()[0]))
                    input.data.resize_(real.size()).copy_(fake.data)

                    # Fake image discriminator loss
                    errD_fake = step_D(input, mone)

                    # Discriminator loss
                    errD = errD_real - errD_fake
                    optimizerD.step()

                ###########################
                # (2) Update G network
                ###########################
                make_trainable(netD, False)
                netG.zero_grad()

                # Generator loss
                errG = step_D(netG(create_noise(bs)), one)
                optimizerG.step()

                print('[%d/%d][%d/%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
                    % (epoch, niter, i, n,
                    errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))

"I was getting 'RuntimeError: Trying to backward through the graph a second time…'"

PyTorch has this behavior: to reduce GPU memory usage, during the .backward() call all the intermediate results (saved activations and the like) are deleted once they are no longer needed. So if you try to call .backward() again, the intermediate results no longer exist and the backward pass cannot be performed (and you get the error you saw).

It depends on what you are trying to do. You can call .backward(retain_graph=True) to do a backward pass that does not delete the intermediate results, so that you can call .backward() again afterwards. Every backward call except the last one should pass the retain_graph=True option.
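
A minimal sketch of that rule (illustrative only): keep the graph on every backward call except the last one.

import torch

x = torch.randn(5, requires_grad=True)
loss = (x ** 2).sum()

loss.backward(retain_graph=True)  # intermediate buffers are kept for another pass
loss.backward()                   # last call: the graph can now be freed

print(x.grad)                     # gradients from both passes, accumulated (here, 4 * x)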

"Can someone let me know pytorch's best practice on this"

As you can see from the PyTorch code above, and from the way PyTorch tries to stay Pythonic, you can get a sense of PyTorch's best practices there.
