import torch
import torch.nn as nn
from torch.nn import functional as F


class RestNetBasicBlock(nn.Module):
    """Residual block with an identity shortcut.

    Runs conv3x3 -> BN -> ReLU -> conv3x3 -> BN, adds the untouched block
    input to the result, and applies a final ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride handed to *both* 3x3 convolutions.

    Note:
        The shortcut is the raw input, so the addition in ``forward`` only
        lines up when the main branch keeps the input's shape. Every use in
        this file passes ``stride=1`` with ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``relu(x + main_branch(x))``."""
        output = self.conv1(x)
        output = F.relu(self.bn1(output))
        output = self.conv2(output)
        output = self.bn2(output)
        return F.relu(x + output)


class RestNetDownBlock(nn.Module):
    """Residual block with a projection (1x1 conv + BN) shortcut.

    Used where the main branch changes the channel count and/or spatial
    size, so the input must be projected before it can be added back.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Two-element sequence. ``stride[0]`` is used by the first 3x3
            convolution and by the 1x1 shortcut convolution; ``stride[1]``
            is used by the second 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Projection shortcut: shares stride[0] with conv1 so that both
        # branches are downsampled by the same factor.
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        """Return ``relu(shortcut(x) + main_branch(x))``."""
        extra_x = self.extra(x)
        out = self.conv1(x)
        out = F.relu(self.bn1(out))

        out = self.conv2(out)
        out = self.bn2(out)
        return F.relu(extra_x + out)


class RestNet18(nn.Module):
    """ResNet-18 style classifier built from the blocks defined above.

    Layout: 7x7 stem convolution -> BN -> ReLU -> 3x3 max-pool -> four stages
    of two residual blocks each (64, 128, 256, 512 channels) -> global
    average pool -> linear classifier.

    Args:
        num_classes: Number of outputs of the final linear layer.
            Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),
                                    RestNetBasicBlock(64, 64, 1))

        self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
                                    RestNetBasicBlock(128, 128, 1))

        self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
                                    RestNetBasicBlock(256, 256, 1))

        self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]),
                                    RestNetBasicBlock(512, 512, 1))

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Compute class scores for a batch of 3-channel images.

        Args:
            x: Batched input whose first dimension is the batch size and
                whose channel dimension is 3 (``conv1`` takes 3 channels).

        Returns:
            Tensor of shape ``(x.shape[0], num_classes)``.
        """
        # Stem: the BN, ReLU and max-pool must run here. They were built in
        # __init__ but previously skipped, which fed un-normalised,
        # un-pooled features into layer1.
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avgpool(out)
        # Collapse the (512, 1, 1) pooled features to one vector per sample.
        out = out.reshape(x.shape[0], -1)
        out = self.fc(out)
        return out


if __name__ == '__main__':
    # build model
    model = RestNet18()
    model.eval()

    # export onnx (rknn-toolkit2 only supports opset_version up to 12)
    x = torch.randn((1, 3, 224, 224))
    torch.onnx.export(model, x, './resnet18.onnx', opset_version=12, input_names=['input'], output_names=['output'])