Date: 2020/07/27
Coder: CW
Foreword:
相信诸位炼丹者在看paper时会注意到出现有 'FLOPs' 这个词,那么 FLOPs 究竟是什么意思呢?它是如何计算的?有哪些开源工具可供使用?
Outline
I. 什么是FLOPs?
II. 如何计算FLOPs?
III. 相关开源工具
1、torchstat
2、torchscan
3、其它
什么是FLOPs?
FLOPs(Floating Point Operations)即 浮点运算次数,常用于评估模型/算法的计算量(复杂度)。注意,此处s必须小写!因为大写S对应的是另一个概念——FLOPS(Floating Point Operations per Second),意为每秒浮点运算次数,代表的是一种运算速率,通常用于衡量硬件的性能指标。
如何计算FLOPs?
对于卷积层(不考虑激活函数的运算),令:
:表示卷积核大小;
:输入通道数;
:输出通道数;
:输出特征图的高和宽
那么生成输出特征图的一个单元(单通道)则需要:
-次乘法
-次加法
若算上偏置(bias),那么还需额外的一次加法,于是就有次加法,与乘法次数一致。
由于输出特征图尺寸为()且通道数为
,因此该卷积层的运算次数为:
这是一个卷积层的乘加运算次数(MAdd)。
有些时候会将一次乘法和一次加法合并对应一次FLOP,那么一个卷积层的FLOPs就为。在这种方式下,可将卷积核和输出特征图都分别看作立方体,那么FLOPs就恰好是两者体积的乘积。
同理,对于一个全连接层(不计bias),我们不难推出其FLOPs(MAdd)为:
其中分别对应输入、输出神经元个数。
相关开源工具
1、torchstat
官方介绍:
This is a lightweight neural network analyzer based on PyTorch.
This tools can show:
Total number of network parameters
Theoretical amount of floating point arithmetics (FLOPs)
Theoretical amount of multiply-adds (MAdd)
Memory usage
安装:
pip install torchstat
或者使用源码方式:
python3 setup.py install
使用:
from torchstat import stat
import torchvision.models as models
model = models.resnet18()
stat(model, (3, 224, 224))
效果(部分截图):
可看到这里区分了乘加运算(MAdd)与FLOPs。
2、torchscan
安装:
pip install torchscan
使用:
>>> import torchvision.models as models
>>> from torchscan import summary
>>> mod = models.resnet18()
>>> summary(mod, (3, 224, 224))
效果(部分截图):
注意,这里的FLOPs与torchstat的乘加运算次数(MAdd)对应,而MAC(Multiply-Accumulations)与torchstat的FLOPs对应。
3、其它
其它的还有 flops-counter、torchsummary等,感兴趣的可以自行摸索。