23 lines
533 B
Python
23 lines
533 B
Python
import torch
|
|
import torch.nn as nn
|
|
from transformers import RegNetModel
|
|
|
|
|
|
from src.core import register
|
|
|
|
__all__ = ['RegNet']
|
|
|
|
@register
|
|
class RegNet(nn.Module):
|
|
def __init__(self, configuration, return_idx=[0, 1, 2, 3]):
|
|
super(RegNet, self).__init__()
|
|
self.model = RegNetModel.from_pretrained("facebook/regnet-y-040")
|
|
self.return_idx = return_idx
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
outputs = self.model(x, output_hidden_states = True)
|
|
x = outputs.hidden_states[2:5]
|
|
|
|
return x |