PaddlePaddle 极简入门实践三:模型的保存与读取

每次预测前还要先训练?
把模型保存下来不就可以直接拿来预测了嘛

1、准备工作

打开官方文档 http://paddlepaddle.org/documentation/docs/zh/1.3/user_guides/howto/training/save_load_variables.html

2、开始实战

#参数初始化
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)


#定义数据
datatype = "float32"
train_data = numpy.array([[0],[1],[2],[3],[4],[5],[10]]).astype(datatype)#10倍缩放 此处数据类型尽可能与网格类型相似
y_true = numpy.array([[3],[13],[23],[33],[43],[53],[103]]).astype(datatype)


#定义网络
x = fluid.layers.data(name="x",shape=[1],dtype=datatype)
y = fluid.layers.data(name="y",shape=[1],dtype=datatype)
y_predict = fluid.layers.fc(input=x,size=1,act=None)#定义x与其有关系
#定义损失函数
cost = fluid.layers.square_error_cost(input=y_predict,label=y)
avg_cost = fluid.layers.mean(cost)
#定义优化方法
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_cost)

##开始训练,迭代100次
prog =f luid.default_startup_program()
exe.run(prog)
for i in range(500):
    outs = exe.run(
        feed={'x':train_data,'y':y_true},
        fetch_list=[y_predict.name,avg_cost.name])#feed为数据表 输入数据和标签数据
    print("正在训练第"+str(i+1)+"次")
#观察结果
    print(outs)
#保存预测模型
fluid.io.save_inference_model(params_dirname, ['x'],[y_predict], exe)

此处对照上一次笔记添加了fluid.io.save_inference_model(params_dirname, ['x'],[y_predict], exe)
上次笔记传送门 //www.greatytc.com/p/789acc391246
这里的params_dirname是保存模型的路径,注意要用"/"分隔开文件夹层次~如果是"\"就会被当成转义符了,接下来就是报错 哈哈哈

['x']是刚刚网络里面的name='x'
[y_predict]则为最后导出的层(因为该模型只有一层网络,此处参数为[y_predict]即可)
exe则为执行器

3、读取模型

datatype = "float32"
test_data = numpy.array([[input("请输入数值")]]).astype(datatype)#测试数为60
#初始化训练环境
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
# 加载模型
[program, feed_target,fetch] = fluid.io.load_inference_model(params_dirname, exe)
#执行预测
results = exe.run(program,
                  feed={feed_target[0]: test_data},
                  fetch_list=fetch)

params_dirname是读取模型的文件夹路径
exe是执行器
[program, feed_target,fetch]分别代表项目、输入数据表、输出数据表,在执行预测时可以使用到

4、测试一下

Pass!

OK,模型正常载入!

未完待续

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容