!!! I am just starting to understand PyTorch !!!
Suppose the model has the following architecture:
(conv1): Conv2d(2, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=256, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
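The `in_features=256` of `fc1` comes from the conv/pool output sizes. Below is a small sketch of that arithmetic. It assumes a 28x28 input, which the post does not state; it is simply the size for which `16 * 4 * 4 = 256` works out. The helpers `conv2d_out` and `feature_size` are hypothetical names used only for illustration.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1,
               padding: int = 0, dilation: int = 1) -> int:
    """Output size along one spatial dim for Conv2d/MaxPool2d (ceil_mode=False)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def feature_size(h: int = 28) -> int:
    """Flattened feature count fed into fc1, assuming a square h x h input."""
    h = conv2d_out(h, kernel=5)            # conv1: 28 -> 24
    h = conv2d_out(h, kernel=2, stride=2)  # pool:  24 -> 12
    h = conv2d_out(h, kernel=5)            # conv2: 12 -> 8
    h = conv2d_out(h, kernel=2, stride=2)  # pool:   8 -> 4
    return 16 * h * h                      # 16 channels * 4 * 4


print(feature_size())  # 256
```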
For example, how can I add some MyFunction between the conv1 and pool layers?
Here is my current code:
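from torch.nn import Conv2d, Linear, MaxPool2d, Module, ReLU, Sequential


class CNN(Module):
    def __init__(self) -> None:
        super().__init__()
        self.cnn_layer = Sequential(
            Conv2d(in_channels=2, out_channels=6, kernel_size=5),  # conv1
            # MyFunction here
            ReLU(inplace=True),
            MaxPool2d(kernel_size=2, stride=2),  # pool
            Conv2d(in_channels=6, out_channels=16, kernel_size=5),  # conv2
            ReLU(inplace=True),
            MaxPool2d(kernel_size=2, stride=2),  # pool again
        )
        self.linear_layers = Sequential(
            Linear(256, 120), Linear(120, 84), Linear(84, 10)
        )

    def forward(self, image):
        image = self.cnn_layer(image)
        # Flatten to (batch, 16 * 4 * 4); only valid for a 28x28 input
        image = image.view(-1, 4 * 4 * 16)
        image = self.linear_layers(image)
        return image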
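Here is a minimal sketch of one common way to do this. It assumes MyFunction is an ordinary tensor-to-tensor operation. The idea is to wrap it in a small `nn.Module` so it can sit inside the `Sequential` like any other layer. `my_function` and the `MyFunction` wrapper are hypothetical stand-ins: the placeholder just squares its input, so swap in the real operation. The 28x28 input size is also an assumption.

```python
import torch
from torch import nn


def my_function(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical placeholder for MyFunction: any tensor -> tensor op works here.
    return x * x


class MyFunction(nn.Module):
    """Wraps a plain function so it can be used as a layer in nn.Sequential."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return my_function(x)


class CNN(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.cnn_layer = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=6, kernel_size=5),  # conv1
            MyFunction(),  # runs right after conv1, before ReLU and pool
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),  # conv2
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.linear_layers = nn.Sequential(
            nn.Linear(256, 120), nn.Linear(120, 84), nn.Linear(84, 10)
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        image = self.cnn_layer(image)
        image = torch.flatten(image, 1)  # (N, 16, 4, 4) -> (N, 256)
        return self.linear_layers(image)


model = CNN()
out = model(torch.randn(4, 2, 28, 28))  # batch of 4, assumed 28x28 input
print(out.shape)  # torch.Size([4, 10])
```

Because the wrapper is a module, it also appears in `print(model)` alongside the other layers. If the function has no parameters, calling it directly inside `forward()` is an equally valid alternative. That option fits a model defined with separate `conv1`/`pool` attributes, as in the printout.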