71 lines
2.1 KiB
Python
71 lines
2.1 KiB
Python
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
|
|
|
https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055#0583
|
|
"""
|
|
|
|
import torch
|
|
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
|
|
|
|
from .utils import IntermediateLayerGetter
|
|
from ...core import register
|
|
|
|
|
|
@register()
class TimmModel(torch.nn.Module):
    """Backbone wrapper around a model built with ``timm.create_model``.

    The timm model is created from the constructor arguments and wrapped in
    ``IntermediateLayerGetter`` together with ``return_layers``; ``forward``
    simply delegates to that wrapper.

    Attributes set by ``__init__``:
        model: the ``IntermediateLayerGetter`` wrapping the timm model.
        return_layers: the layer names requested by the caller.
        return_idx: position of each requested layer within
            ``model.feature_info.module_name()``.
        strides: entries of ``model.feature_info.reduction()`` at those
            positions (presumably per-layer downsampling factors -- confirm
            against the timm documentation).
        channels: entries of ``model.feature_info.channels()`` at those
            positions.
    """

    def __init__(self,
                 name,
                 return_layers,
                 pretrained=False,
                 exportable=True,
                 features_only=True,
                 **kwargs) -> None:
        """Create the timm model and select the layers to expose.

        Args:
            name: model name passed to ``timm.create_model``.
            return_layers: names of the layers whose outputs are wanted; each
                must appear in ``model.feature_info.module_name()``.
            pretrained: forwarded to ``timm.create_model``.
            exportable: forwarded to ``timm.create_model``.
            features_only: forwarded to ``timm.create_model``.
            **kwargs: any further keyword arguments for ``timm.create_model``.

        Raises:
            ValueError: if ``return_layers`` contains a name that is not
                listed by ``model.feature_info.module_name()``.
        """
        super().__init__()

        # Function-local import: importing this module does not import timm;
        # it is only needed once a TimmModel is actually constructed.
        import timm

        model = timm.create_model(
            name,
            pretrained=pretrained,
            exportable=exportable,
            features_only=features_only,
            **kwargs
        )

        # Read the feature metadata once instead of re-querying it for every
        # requested layer.
        module_names = model.feature_info.module_name()

        # Explicit raise rather than ``assert``: assertions are removed when
        # Python runs with -O, which would let a bad config slip through.
        if not set(return_layers).issubset(module_names):
            raise ValueError(
                f'return_layers should be a subset of {module_names}')

        self.model = IntermediateLayerGetter(model, return_layers)

        return_idx = [module_names.index(name) for name in return_layers]
        reductions = model.feature_info.reduction()
        channels = model.feature_info.channels()

        self.strides = [reductions[i] for i in return_idx]
        self.channels = [channels[i] for i in return_idx]
        self.return_idx = return_idx
        self.return_layers = return_layers

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the wrapped model and return its result as-is.

        NOTE(review): the exact container type of the result is whatever
        ``IntermediateLayerGetter`` produces -- confirm in ``.utils``.
        """
        return self.model(x)
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke run: build a resnet34 backbone, push one random
    # 1x3x640x640 batch through it and print the shape of every output.
    backbone = TimmModel(name='resnet34', return_layers=['layer2', 'layer3'])
    dummy_batch = torch.rand(1, 3, 640, 640)

    for feature_map in backbone(dummy_batch):
        print(feature_map.shape)
|
|
|
|
"""
|
|
model:
|
|
type: TimmModel
|
|
name: resnet34
|
|
return_layers: ['layer2', 'layer4']
|
|
"""
|