first commit
This commit is contained in:
70
rtdetrv2_pytorch/src/nn/backbone/timm_model.py
Normal file
70
rtdetrv2_pytorch/src/nn/backbone/timm_model.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
||||
|
||||
https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055#0583
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
|
||||
|
||||
from .utils import IntermediateLayerGetter
|
||||
from ...core import register
|
||||
|
||||
|
||||
@register()
|
||||
class TimmModel(torch.nn.Module):
|
||||
def __init__(self, \
|
||||
name,
|
||||
return_layers,
|
||||
pretrained=False,
|
||||
exportable=True,
|
||||
features_only=True,
|
||||
**kwargs) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
import timm
|
||||
model = timm.create_model(
|
||||
name,
|
||||
pretrained=pretrained,
|
||||
exportable=exportable,
|
||||
features_only=features_only,
|
||||
**kwargs
|
||||
)
|
||||
# nodes, _ = get_graph_node_names(model)
|
||||
# print(nodes)
|
||||
# features = {'': ''}
|
||||
# model = create_feature_extractor(model, return_nodes=features)
|
||||
|
||||
assert set(return_layers).issubset(model.feature_info.module_name()), \
|
||||
f'return_layers should be a subset of {model.feature_info.module_name()}'
|
||||
|
||||
# self.model = model
|
||||
self.model = IntermediateLayerGetter(model, return_layers)
|
||||
|
||||
return_idx = [model.feature_info.module_name().index(name) for name in return_layers]
|
||||
self.strides = [model.feature_info.reduction()[i] for i in return_idx]
|
||||
self.channels = [model.feature_info.channels()[i] for i in return_idx]
|
||||
self.return_idx = return_idx
|
||||
self.return_layers = return_layers
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
outputs = self.model(x)
|
||||
# outputs = [outputs[i] for i in self.return_idx]
|
||||
return outputs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model = TimmModel(name='resnet34', return_layers=['layer2', 'layer3'])
|
||||
data = torch.rand(1, 3, 640, 640)
|
||||
outputs = model(data)
|
||||
|
||||
for output in outputs:
|
||||
print(output.shape)
|
||||
|
||||
"""
|
||||
model:
|
||||
type: TimmModel
|
||||
name: resnet34
|
||||
return_layers: ['layer2', 'layer4']
|
||||
"""
|
||||
Reference in New Issue
Block a user