net.apply(weights_init)的理解
在DCGAN的学习中,Pytorch官方对于权重初始化使用了下列方法
# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
在这里对该代码学习后的理解做一些记录。首先是apply(fn)
,根据官网解释该方法是Module类的方法,作用是将fn递归地应用于每个子模块(由.children()返回),以及其自身。典型用途便是初始化模型的参数。我们这里来写一个简单的神经网络 net 并将其实例化
def weights_init(m):
print(m)
net = nn.Sequential(nn.Linear(1, 1), nn.Conv2d(1, 1, 1))
net.apply(weights_init)
我们定义了一个weights_init
函数,和一个Sequential
类,该类有两层,第一层是全链接层,第二层是卷积层。将该类实例化后调用其apply()
方法,我们来运行看看
>>> Linear(in_features=1, out_features=1, bias=True)
Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
)
可以看到apply()
遍历了该类的每一层和其自身,我们这里将打印的参数再改成内建参数m.__class__
看看
def weights_init(m):
print(m.__class__)
net = nn.Sequential(nn.Linear(1, 1), nn.Conv2d(1, 1, 1))
net.apply(weights_init)
>>> <class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.container.Sequential'>
在这里net对象的所在类被递归出来了,最后在把其换成m.__class__.__name__
运行
def weights_init(m):
print(m.__class__.__name__)
net = nn.Sequential(nn.Linear(1, 1), nn.Conv2d(1, 1, 1))
net.apply(weights_init)
>>> Linear
Conv2d
Sequential
可以看出该方法含义是递归神经网络并返回每层名字,如果该名字找到了字符串'Conv'
或者'BatchNorm'
,则对其权重做归一化