# 简易代码 (simple version): count a model's parameters
def print_model_parm_nums(model=None):
    """Print a model's total parameter count in millions and return it.

    Args:
        model: Object whose ``parameters()`` yields tensors supporting
            ``nelement()`` (e.g. a ``torch.nn.Module``). When omitted,
            ``models.alexnet()`` is built, preserving the original
            behavior. ``models`` is assumed to be ``torchvision.models``
            imported elsewhere in the file — TODO confirm.

    Returns:
        The total number of parameter elements as an int.
    """
    if model is None:
        # Construct lazily so callers passing their own model never need
        # the ``models`` module to be importable.
        model = models.alexnet()
    # Generator expression: no intermediate list of per-tensor counts.
    total = sum(param.nelement() for param in model.parameters())
    print(' + Number of params: %.2fM' % (total / 1e6))
    return total
# 简易代码 (simple version): count a model's parameters
def print_model_parm_nums(model=None):
    """Print a model's total parameter count in millions and return it.

    NOTE(review): this name is also defined earlier in the file with the
    same body; at runtime this later definition replaces that one.
    Consider keeping only one copy.

    Args:
        model: Object whose ``parameters()`` yields tensors supporting
            ``nelement()`` (e.g. a ``torch.nn.Module``). When omitted,
            ``models.alexnet()`` is built, preserving the original
            behavior. ``models`` is assumed to be ``torchvision.models``
            imported elsewhere in the file — TODO confirm.

    Returns:
        The total number of parameter elements as an int.
    """
    if model is None:
        # Construct lazily so callers passing their own model never need
        # the ``models`` module to be importable.
        model = models.alexnet()
    # Generator expression: no intermediate list of per-tensor counts.
    total = sum(param.nelement() for param in model.parameters())
    print(' + Number of params: %.2fM' % (total / 1e6))
    return total