pytorch中nn.ModuleList()使用方法
•
人工智能
定义ModuleList
我们可以将我们需要的层放入到一个集合中,然后将这个集合作为参数传入nn.ModuleList中,但是这个子类并不可以直接使用,因为这个子类并没有实现forward函数,所以要使用还需要放在继承了nn.Module的模型中进行使用。
model_list = nn.ModuleList([nn.Conv2d(1, 5, 2), nn.Linear(10, 2), nn.Sigmoid()])
x = torch.randn(32, 3, 24, 24)
for model in model_list:
model_list(x)
使用ModuleList定义网络
class Net(nn.Module):
def __init__(self):
super().__init__()
self.model_list = nn.ModuleList([nn.Conv2d(1, 5, 2), nn.Linear(10, 2), nn.Sigmoid()])
def forward(self, x):
return self.model_list(x)
打印网络层结构
model = Net() print(model)
Net(
(model_list): ModuleList(
(0): Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))
(1): Linear(in_features=10, out_features=2, bias=True)
(2): Sigmoid()
)
)
本文来自网络,不代表协通编程立场,如若转载,请注明出处:https://www.net2asp.com/3f3a620e60.html
