Files
RT-DETR/rtdetrv2_pytorch/src/misc/profiler_utils.py
2026-06-03 12:42:47 +08:00

66 lines
1.7 KiB
Python

"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import re
import torch
import torch.nn as nn
from torch import Tensor
from typing import List
def stats(
    model: nn.Module,
    data: Tensor = None,
    # NOTE: list default kept for interface compatibility; it is never mutated.
    input_shape: List = [1, 3, 640, 640],
    device: str = 'cpu',
    verbose=False,
) -> dict:
    """Count trainable parameters and estimate forward-pass FLOPs of ``model``.

    The model is moved to ``device``, switched to eval mode and run five
    times under ``torch.profiler`` (1 wait step, 1 warmup step, 2 recorded
    steps). The FLOPs recorded by the profiler are averaged over the recorded
    steps.

    Args:
        model: Module to profile. It is moved to ``device`` in place; its
            train/eval mode is restored before returning, even if the forward
            pass raises.
        data: Input passed to ``model``. When ``None``, a random tensor of
            shape ``input_shape`` is created on ``device``.
        input_shape: Shape of the random input used when ``data`` is ``None``.
        device: Device the model (and the generated input) is placed on.
        verbose: When true, print the parameter count and the FLOPs estimate.

    Returns:
        A dict with keys ``'n_parameters'`` (number of trainable parameters),
        ``'n_flops'`` (MFLOPs per forward pass, as a float) and ``'info'``
        (the profiler's summary table as a string).
    """
    was_training = model.training

    # Mode toggling order is kept from the original implementation.
    model.train()
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    model.eval()
    model = model.to(device)

    # Report the shape that is actually profiled, not the unused default.
    reported_shape = input_shape if data is None else list(data.shape)
    if data is None:
        data = torch.rand(*input_shape, device=device)

    num_active = 2
    try:
        # Only forward passes are measured, so autograd bookkeeping is skipped.
        with torch.no_grad():
            with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                schedule=torch.profiler.schedule(
                    wait=1,
                    warmup=1,
                    active=num_active,
                    repeat=1,
                ),
                with_flops=True,
            ) as p:
                for _ in range(5):
                    _ = model(data)
                    p.step()
    finally:
        # Restore the caller's mode even when the forward pass fails.
        if was_training:
            model.train()

    averages = p.key_averages()
    info = averages.table(sort_by="self_cuda_time_total", row_limit=-1)

    # Sum the raw per-op FLOPs instead of parsing the rendered table: the
    # table's FLOPs column does not use a fixed unit, and may be absent.
    total_flops = sum((evt.flops or 0) for evt in averages)
    num_flops = total_flops / 1e6 / num_active

    if verbose:
        print(f'Total number of trainable parameters: {num_params}')
        print(f'Total number of flops: {int(num_flops)}M with {reported_shape}')

    return {'n_parameters': num_params, 'n_flops': num_flops, 'info': info}