nn.Sequential 、 nn.ModuleList 、 nn.ModuleDict区别
2、各自用法
net = nn.Sequential(nn.Linear(128, 256), nn.ReLU())
net = nn.ModuleList([nn.Linear(128, 256), nn.ReLU()])
net = nn.ModuleDict({'linear': nn.Linear(128, 256), 'act': nn.ReLU()})
3、区别
-
ModuleList 仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度匹配),而且没有实现 forward 功能需要自己实现
-
和 ModuleList 一样, ModuleDict 实例仅仅是存放了一些模块的字典,并没有定义 forward 函数需要自己定义
-
而 Sequential 内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部 forward 功能已经实现,所以,直接如下写模型,是可以直接调用的,不再需要写forward,sequential 内部已经有 forward
4、转换
-
将 nn.ModuleList 转换成 nn.Sequential
module_list = nn.ModuleList([nn.Linear(128, 256), nn.ReLU()])
net = nn.Sequential(*module_list) -
nn.ModuleDict转换为nn.Sequential
module_dict = nn.ModuleDict({'linear': nn.Linear(128, 256), 'act': nn.ReLU()})
net = nn.Sequential(*module_dict.values())