魔改 PyTorch 模型

魔改 PyTorch 模型

在许多地方都有修改原模型的需求,比如最简单和常见的例子就是修改模型最后的全链接层,来获得针对特定任务的输出结果。对于 PyTorch,可以通过多种方式来达到魔改模型的效果。

PyTorch 的模型是一个 torch.nn.Module 的某个子类的对象,那魔改模型实际就等价于修改某个类,对面向对象熟悉的同学都知道,对类做修改有两个经典的方法:组合和继承。

通过继承修改模型

首先创建自己需要的模型类,然后其父类指向需要被修改的模型,这时自己的模型则具有完备的父类行为,再在子类中实现魔改的逻辑。其大致的框架代码如下所示:

from torchvision.models import ResNet

class CustomizedResNet(ResNet):

def __init__(self):

super().__init__()

...

def forward(self, x):

...

下面这个例子,我们将对 ResNet 进行魔改,把 ResNet 4 个 stage 输出的特征连起来,最后过一个全链接后输出一个标量。

from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet

class CustomizedResNet(ResNet):

def __init__(self, block, layers, num_classes=2):

super().__init__(block, layers, num_classes)

self.fc = torch.nn.Linear(int(512 * block.expansion * 1.875), num_classes)

def forward(self, x):

x = self.conv1(x)

x = self.bn1(x)

x = self.relu(x)

x = self.maxpool(x)

x1 = self.layer1(x)

x2 = self.layer2(x1)

x3 = self.layer3(x2)

x4 = self.layer4(x3)

x = torch.cat(

[

self.avgpool(x1),

self.avgpool(x2),

self.avgpool(x3),

self.avgpool(x4),

],

dim=1,

)

x = torch.flatten(x, 1)

x = self.fc(x)

return x

new_resnet34 = CustomizedResNet(BasicBlock, [3, 4, 6, 3], num_classes=1)

new_resnet50 = CustomizedResNet(Bottleneck, [3, 4, 6, 3], num_classes=1)

new_resnet101 = CustomizedResNet(Bottleneck, [3, 4, 23, 3], num_classes=1)

new_resnet200 = CustomizedResNet(Bottleneck, [3, 24, 36, 3], num_classes=1)

通过组合修改模型

在面向对象编程中我们可能听说过「组合优于继承」,在模型修改的场景中其实也是这样,大多数情况下我们可能都适用组合而非继承。

首先依然需要创建模型的类,但这个类不再继承自魔改的类,而是直接继承 PyTorch 的模型基类 torch.nn.Module,然后将需要魔改的类作为类变量融入到模型中,下面是大致的框架代码:

from torchvision.models import resnet18

class CustomizedResNet(torch.nn.Module):

def __init__(self, backbone):

super().__init__()

self.backbone = backbone

...

def forward(self, x):

...

my_resnet18 = CustomizedResNet(resnet18)

同样我们来实现跟上面同样的需求:

from torchvision.models import resnet50

class CustomizedResNet(torch.nn.Module):

def __init__(self, backbone, num_classes=2):

super().__init__()

self.backbone = backbone

self.fc = torch.nn.Linear(3840, num_classes)

def forward(self, x):

x = self.backbone.conv1(x)

x = self.backbone.bn1(x)

x = self.backbone.relu(x)

x = self.backbone.maxpool(x)

x1 = self.backbone.layer1(x)

x2 = self.backbone.layer2(x1)

x3 = self.backbone.layer3(x2)

x4 = self.backbone.layer4(x3)

x = torch.cat(

[

self.backbone.avgpool(x1),

self.backbone.avgpool(x2),

self.backbone.avgpool(x3),

self.backbone.avgpool(x4),

],

dim=1,

)

x = torch.flatten(x, 1)

x = self.fc(x)

return x

new_resnet50 = CustomizedResNet(resnet50())

通过猴子补丁修改模型

最后再介绍一个最简单粗暴的方法:猴子补丁(Monkey Patch)。之所以叫猴子补丁,是因为这种方法在程序设计的角度上看,是具有破坏性的。我们尽可能得还是应该使用上面两种方法。

猴子补丁修改模型非常简单粗暴,直接使用需要修改的模型创建对象,然后直接对对象的属性做出修改,下面是把 ResNet34 的输出从 1000 改为 1 的例子:

from torchvision.models import resnet50

model = resnet50()

model.fc = torch.nn.Linear(2048, 1)

此外这种方法也仅能实现一些简单的需求,对于复杂的需求还是推荐使用组合的方法来完成。

相关推荐