loss突然变为nan?你可能踩了sqrt()的坑

这几天在写一个网络模型,需要自定义损失函数,关于gps距离误差计算的,写完之后,噩梦开始了…
训练过程中loss总是莫名其妙的突然变为nan,网上查阅了许多资料,做了各种尝试,比如调整学习率、调整batch大小、调整网络复杂度、梯度裁剪、过滤脏数据、检查是否存在除0、log(0), 加入BatchNormalization层等,无奈还是会出现loss变为nan的问题。
后来分析loss函数本身,发现唯一可能出现问题的地方是下面这行代码里的tf.sqrt()函数:

return K.mean(tf.sqrt(tf.add(tf.square(lx), tf.square(ly))))

于是,又上网查发现tensorflow或者pytorch在loss函数中使用sqrt可能导致loss训练变为nan的问题,原因如下:
sqrt()即x^1/2,在x=0处不可导,前向传播过程中,loss的计算不会出问题,但在反向传播进行梯度计算的时候可能会遇到在0处求导的情况,这也是loss突然变为nan的原因,在sqrt()添加一个极小数之后得到解决:

return K.mean(tf.sqrt(tf.add(tf.add(tf.square(lx), tf.square(ly)), 1e-10)))
    原文作者:reniviD
    原文地址: https://blog.csdn.net/ten_k/article/details/106576818
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞