将训练好的keras模型转换为tensorflow模型的通用代码

・3 分钟阅读

  • 源代码名称: keras_to_tensorflow
  • 源代码网址: https://www.github.com/amir-abdi/keras_to_tensorflow
  • keras_to_tensorflow的文档
  • keras_to_tensorflow的源代码下载
  • Git URL:
    git://www.github.com/amir-abdi/keras_to_tensorflow.git
  • Git Clone代码到本地:
    git clone https://www.github.com/amir-abdi/keras_to_tensorflow
  • Subversion代码到本地:
    $ svn co --depth empty https://www.github.com/amir-abdi/keras_to_tensorflow
                              Checked out revision 1.
                              $ cd repo
                              $ svn up trunk
              
  • keras_to_tensorflow

    将训练的keras模型转换为tensorflow模型的通用代码

    keras_to_tensorflow是一个例子代码,它加载训练的keras模型,冻结节点(将所有TensorFlow变量转换为TensorFlow常量),并将inference graph和权重保存到protobuf文件(.pb )中,然后,可以使用此文件来部署训练模型。在冻结期间,网络的其他节点(包含输出预测的张量)会被修剪,这导致模型更小,更优化的网络。

    那时候,旧版本freeze_graph工具(/tensorflow/python/tools/freeze_graph.py )用于将变量转换为常量。 此功能新版本 graph_util.convert_variables_to_constants

    如何使用

    可以使用model.save('file_name.h5') (有关详细信息,请参阅Keras API文档)保存keras模型。

    你可以使用IPython notebook (kears_to_tensorflow.ipnyb ),或者在你的keras模型文件夹中运行如下的python脚本:

    
    python3 keras_to_tensorflow.py -input_model_file model.h5
    
    
    python keras_to_tensorflow.py -input_model_file model.h5 
    
    
    
    

    尝试python3 keras_to_tensorflow.py --help

    输入参数

    • num_output:此值与classes数量,batch_size等无关,如果网络是multi-stream network(具有多个输出的分叉网络),则它大致等于1,将值设置为输出数量。

    • quantize:如果设置为True,则使用Tensorflow()的量化功能[默认值:false ]

    • use_theano:Thaeno和Tensorflow以不同的方式实现卷积。当使用Keras与Theano后端时,次序设置为'channels_first '。此功能未经过完全测试,并且不适用于quantizization [默认值: false ]

    • input_fld:保存keras权重文件的目录[默认:]

    • output_fld:保存tensorflow文件的目标目录[默认:]

    • input_model_file:输入权重文件的名称[默认值:''model.h5']

    • output_model_file:输出权重文件的名称[默认值:args.input_model_file +'pb']

    • graph_def:如果设置为True,则将图形定义写为ascii文件[默认值:false ]

    • output_graphdef_file:如果graph_def设置为True,则图形定义的文件名为[default:model.ascii ]

    • output_node_prefix:用于输出节点的前缀。[默认:output_node]

    没有测试过的功能:

    Theano支持尚未经过全面测试。

    依赖项

    • Keras
    • Tensorflow
    • argparse
    • pathlib
    讨论
    Fansisi profile image