-
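Background: converting an in-house model, policy_model, to OM failed. The cause turned out to be control-flow operators in the model. The summarize_graph and xlacompile tools therefore have to be used together to convert it into a V2 network model that uses function-style operators.
Goal: build the summarize_graph and xlacompile tools, convert the customer model to V2, and finally convert it to an OM model with ATC.
Notes:
1. The first four sections of this article mainly record how to build summarize_graph and xlacompile. Skip them if you already have a working environment.
2. In the original post, text with a gray background marks the key commands.

I. Environment requirements
Install the Bazel build tool. For details, see the official Bazel installation guide.
Install Python 3.7.5. For details, see the CANN installation guide.
Install future, patch, and numpy:
pip3.7.5 install future
pip3.7.5 install patch
pip3.7.5 install numpy

II. Preparing the TensorFlow source code
wget -O tensorflow-1.15.0.tar.gz https://github.com/tensorflow/tensorflow/archive/v1.15.0.tar.gz
tar -zxvf tensorflow-1.15.0.tar.gz
cd tensorflow-1.15.0

III. Building xlacompile
1. Create the xlacompile.patch file
vim xlacompile.patch
The file is too long to include here. See the attachment xlacompile.rar at the bottom of the original post, which can be used as-is.
2. Apply the patch
Put xlacompile.patch in the tensorflow-1.15.0 directory and run:
patch -p1 < xlacompile.patch
3. Build the xlacompile tool
Change to the tensorflow-1.15.0 directory and run:
cd tensorflow-1.15.0
bazel build --config=monolithic //tensorflow/compiler/aot:xlacompile
The build takes roughly 5 to 10 minutes. If it fails, see the FAQ at the end of this document or analyze the specific error message.
Screenshots of a successful build (not reproduced here):
… the intermediate output is very long and is omitted …
4. Locate the xlacompile executable
After the build finishes, xlacompile is placed under tensorflow-1.15.0/bazel-out/:
ls -l bazel-out/k8-opt/bin/tensorflow/compiler/aot/
You can also search for it under $HOME/.cache:
find $HOME/.cache -name xlacompile
xlacompile has now been built successfully.

IV. Building summarize_graph
1. Build the summarize_graph tool
Change to the tensorflow-1.15.0 directory and run:
bazel build --config=monolithic -c opt //tensorflow/tools/graph_transforms:summarize_graph
Note: once xlacompile has built successfully, summarize_graph usually builds without problems.
2. Locate the summarize_graph executable
After the build finishes, summarize_graph is placed under tensorflow-1.15.0/bazel-out/:
ls -l bazel-out/k8-opt/bin/tensorflow/tools/graph_transforms/
You can also search for it under $HOME/.cache:
find $HOME/.cache -name summarize_graph
summarize_graph has now been built successfully.

V. V1 -> V2 -> OM conversion
Copy the two tools built above and the V1 network model to be converted into any directory on the server, for example /home/tensorflow_tool/.
The network model used in this case cannot be shared. For practice, you can write your own single-operator control-flow model by following the TensorFlow documentation.
1. Get the output node name of the V1 model
./summarize_graph --in_graph=policy_model.pb
The output shows that the output node of policy_model is pred.
2. Create the config.pbtxt output configuration file
Fill it in with the output node name from the previous step, using the format below:
vim config.pbtxt
Content (for copy and paste; replace xxx with your model's actual output node name):
--------------------------------------------------
fetch {
  id {
    node_name: "xxx"
  }
}
--------------------------------------------------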
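If a model has several outputs, the same block is repeated once per output node. The sketch below assumes that the tf2xla `Config` text format in the template above accepts one `fetch { id { node_name: ... } }` block per output. It generates the file content from a list of node names. The function name `make_fetch_config` is hypothetical, and the script prints the text so that it can be redirected into config.pbtxt.

```python
def make_fetch_config(output_nodes):
    """Build config.pbtxt text with one fetch block per output node name.

    Assumes the tf2xla Config text format used in the template above.
    """
    blocks = []
    for name in output_nodes:
        blocks.append(
            "fetch {\n"
            "  id {\n"
            f'    node_name: "{name}"\n'
            "  }\n"
            "}\n"
        )
    return "".join(blocks)


if __name__ == "__main__":
    # "pred" is the output node reported by summarize_graph for policy_model
    print(make_fetch_config(["pred"]), end="")
```

Usage: `python3 make_config.py > config.pbtxt`. The file name make_config.py is only an example.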
After editing, save and exit with :wq.
3. Convert to a V2 network model with xlacompile
Enable TF debug logging by setting the following environment variables:
export TF_CPP_MIN_LOG_LEVEL=0
export TF_CPP_MIN_VLOG_LEVEL=1
./xlacompile --graph=policy_model.pb --config=config.pbtxt --output=policy_model_V2
The files policy_model_V2.pb and policy_model_V2.pbtxt are generated.
4. Get the graph subgraphs of the V2 model
ATC needs the graph subgraphs when it converts a V2 network to OM. Use the func2graph.py script in the atc directory to extract them:
python3.7.5 /usr/local/Ascend/ascend-toolkit/3.3.0.alpha002/x86_64-linux/atc/python/func2graph/func2graph.py -m policy_model_V2.pb
The file graph_def_library.pbtxt is generated.
5. Convert the V2 model to OM with ATC
atc --model=policy_model_V2.pb --framework=3 --output=./atc_out/policy_model_V2 --soc_version=Ascend310 --input_shape="input_ids_1:1,510;input_mask_1:1,510;segment_ids_1:1,510"
The TensorFlow V2 network model has now been converted into an offline model for the Ascend AI processor.

FAQ
Problem 1: Building xlacompile with Bazel fails
Error 1: ERROR: An error occurred during the fetch of repository 'io_bazel_rules_docker':
Solution 1: Check whether the server can access the internet. If the network works and the same error still appears, use Solution 2.
Solution 2:
1) On a PC or another server, download the following packages:
https://github.com/bazelbuild/rules_docker/releases/download/v0.14.3/rules_docker-v0.14.3.tar.gz
https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz
https://github.com/bazelbuild/rules_swift/releases/download/0.11.1/rules_swift.0.11.1.tar.gz
https://github.com/llvm-mirror/llvm/archive/7a7e03f906aada0cf4b749b51213fe5784eeff84.tar.gz
Upload them to any directory on the server, for example /home/bazel_build/. Make sure the file names match the file:// paths used below.
2) Change to the tensorflow-1.15.0 directory and edit the WORKSPACE file:
cd tensorflow-1.15.0
Add or modify the following content in WORKSPACE (shown as a screenshot in the original post):
vim WORKSPACE
Content (for copy and paste):
------------------------------------------------------------------------------------------
http_archive(
    name = "io_bazel_rules_docker",
    sha256 = "6287241e033d247e9da5ff705dd6ef526bac39ae82f3d17de1b69f8cb313f9cd",
    strip_prefix = "rules_docker-0.14.3",
    urls = ["file:///home/bazel_build/rules_docker-v0.14.3.tar.gz"],
)

# https://github.com/bazelbuild/rules_swift/releases
http_archive(
    name = "build_bazel_rules_swift",
    sha256 = "96a86afcbdab215f8363e65a10cf023b752e90b23abf02272c4fc668fcb70311",
    urls = ["file:///home/bazel_build/rules_swift.0.11.1.tar.gz"],
)
------------------------------------------------------------------------------------------
After editing, save and exit with :wq. An example file is in the attachment WORKSPACE.rar. It may not work as-is and is for reference only.
3) Edit the workspace.bzl file under tensorflow-1.15.0/tensorflow:
vim tensorflow/workspace.bzl
Content (for copy and paste):
------------------------------------------------------------------------------------------
    # TODO(phawkins): currently, this rule uses an unofficial LLVM mirror.
    # Switch to an official source of snapshots if/when possible.
    tf_http_archive(
        name = "llvm",
        build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
        sha256 = "599b89411df88b9e2be40b019e7ab0f7c9c10dd5ab1c948cd22e678cc8f8f352",
        strip_prefix = "llvm-7a7e03f906aada0cf4b749b51213fe5784eeff84",
        urls = [
            "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7a7e03f906aada0cf4b749b51213fe5784eeff84.tar.gz",
            "file:///home/bazel_build/llvm-7a7e03f906aada0cf4b749b51213fe5784eeff84.tar.gz",
        ],
    )
------------------------------------------------------------------------------------------
After editing, save and exit with :wq. An example file is in the attachment workspace_bzl.rar. It may not work as-is and is for reference only.
4) Change back to the tensorflow-1.15.0 directory and build xlacompile again:
bazel build --config=monolithic //tensorflow/compiler/aot:xlacompile
Error 2: ERROR: An error occurred during the fetch of repository 'bazel_skylib'
Cause: the bazel-skylib package could not be downloaded.
Solution: find the file under $HOME/.cache that references bazel-skylib.0.8.0.tar.gz, modify it, and run the build command again.
Use grep to find the file that references bazel-skylib.0.8.0:
grep -r bazel-skylib.0.8.0 $HOME/.cache/
Edit the file external/io_bazel_rules_closure/closure/repositories.bzl. The path below comes from the author's machine, so use the one that grep reports on yours:
vim /root/.cache/bazel/_bazel_root/cd6bc27fdb30c52d70cab0c885928e4b/external/io_bazel_rules_closure/closure/repositories.bzl
Content (for copy and paste):
------------------------------------------------------------------------------------------
def bazel_skylib():
    http_archive(
        name = "bazel_skylib",
        sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
        urls = ["file:///home/bazel_build/bazel-skylib.0.8.0.tar.gz"],
    )
------------------------------------------------------------------------------------------
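Bazel checks each downloaded archive against the `sha256` attribute of `http_archive`. If the files uploaded to /home/bazel_build/ are not byte-identical to the ones those hashes were computed for, the fetch fails with a checksum error. The following minimal stdlib sketch prints the SHA-256 of local archives so you can compare them with the `sha256` fields in the snippets above. It produces the same digest as `sha256sum`. The script name in the usage comment is only an example.

```python
import hashlib
import sys


def file_sha256(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b""
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    # Example: python3 sha256_check.py /home/bazel_build/bazel-skylib.0.8.0.tar.gz
    for p in sys.argv[1:]:
        print(p, file_sha256(p))
```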
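After editing, save and exit with :wq.
Error 3: in cc_binary: cc_binary rule 'xlacompile' in package 'tensorflow/compiler/aot' conflicts with existing cc_binary rule
Cause: patch -p1 < xlacompile.patch was run more than once.
Solution 1: edit tensorflow/compiler/aot/BUILD and delete the duplicated xlacompile content.
Solution 2: extract the TensorFlow source again and run patch -p1 < xlacompile.patch only once in the fresh source directory.
Error 4: ImportError: No module named builtins
Cause: future is not installed or cannot be found.
Solution: pip install future
Note: check carefully whether your environment has more than one Python installation. Bazel invokes the "python" command by default, not "python2" or "python3", so future must be installed for the interpreter that "python" points to.

Problem 2: Converting to V2 with xlacompile fails
Error 1: ERROR: This command does not take any arguments other than flags
Cause: the command was wrong because there was an extra space after "--output=".
Solution: check the command carefully for missing or extra characters. All three parameters, --graph, --config, and --output, are required.

Problem 3: ATC model conversion fails
Error 1: E19000: Path[graph_def_library.pbtxt]'s realpath is empty, errmsg[No such file or directory]
Cause: converting a V2 network model to OM requires the graph subgraphs, and atc cannot find the subgraph file graph_def_library.pbtxt.
Solution: extract the subgraphs with the func2graph.py script that ships with atc. The script is usually located under the atc installation directory:
find / -name func2graph.py
python3.7.5 /usr/local/Ascend/ascend-toolkit/3.3.0.alpha002/x86_64-linux/atc/python/func2graph/func2graph.py -m policy_model_V2.pb
After extracting the subgraphs, run the ATC conversion again.

Appendix 1: TF_CPP_MIN_LOG_LEVEL log levels

| TF_CPP_MIN_LOG_LEVEL | base_logging | Suppressed | Printed |
|---|---|---|---|
| 0 | INFO | none | INFO + WARNING + ERROR + FATAL |
| 1 | WARNING | INFO | WARNING + ERROR + FATAL |
| 2 | ERROR | INFO + WARNING | ERROR + FATAL |
| 3 | FATAL | INFO + WARNING + ERROR | FATAL |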
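The table follows a simple threshold rule: any message whose severity is below the configured level is suppressed. The small sketch below mirrors that rule as stated in the table. It does not call TensorFlow, and the function name is hypothetical.

```python
SEVERITIES = ["INFO", "WARNING", "ERROR", "FATAL"]


def visible_severities(tf_cpp_min_log_level):
    """Severities still printed for a TF_CPP_MIN_LOG_LEVEL value (0-3), per the table above."""
    level = int(tf_cpp_min_log_level)
    if not 0 <= level <= 3:
        raise ValueError("TF_CPP_MIN_LOG_LEVEL is expected to be 0, 1, 2 or 3")
    # Everything at or above the configured level is printed
    return SEVERITIES[level:]


if __name__ == "__main__":
    for lvl in range(4):
        print(lvl, " + ".join(visible_severities(lvl)))
```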
-
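[Module] Creating a new project
[Steps & symptoms]
1. Created a new operator from TensorFlow.
2. The error shown in the screenshot below appears as soon as the project view loads.
[Screenshot]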
-
Can a TensorFlow .pb or ckpt model be run directly?
-
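1. On aarch64, TensorFlow depends on h5py, and h5py depends on HDF5. HDF5 must therefore be compiled and installed first, otherwise installing h5py with pip fails. The following steps are performed as the root user.

### Build and install HDF5
```
wget https://support.hdfgroup.org/ftp/HDF5/releases/hdf5-1.10/hdf5-1.10.5/src/hdf5-1.10.5.tar.gz --no-check-certificate
tar -zxvf hdf5-1.10.5.tar.gz
cd hdf5-1.10.5/
./configure --prefix=/usr/include/hdf5
make
make install
```
Configure the environment variables and create symbolic links for the shared libraries:
```
vi ~/.bashrc
# add this line to ~/.bashrc:
export CPATH="/usr/include/hdf5/include/:/usr/include/hdf5/lib/"
source ~/.bashrc
```
As root, create the symbolic links with the commands below. Non-root users need to prefix them with sudo:
```
ln -s /usr/include/hdf5/lib/libhdf5.so /usr/lib/libhdf5.so
ln -s /usr/include/hdf5/lib/libhdf5_hl.so /usr/lib/libhdf5_hl.so
```
### Install h5py
```
pip3.7 install Cython
pip3.7 install h5py==2.8.0
```
### TensorFlow package
Download link: https://bbs.huaweicloud.com/forum/thread-81557-1-1.html
Download all the .whl archive parts attached to that post and put them in the same directory. Extracting any one of them yields tensorflow-1.15.0-cp37-cp37m-linux_aarch64.whl. Upload that file to the environment and install it with pip:
```
pip3 install tensorflow-1.15.0-cp37-cp37m-linux_aarch64.whl
```
The installation pulls in many dependencies. Install whatever is reported missing, then install TensorFlow again. In the author's case it eventually installed successfully.
### Install the TensorFlow Adapter
Download the tfplugin package that matches your CANN package:
https://www.hiascend.com/software/ai-frameworks/commercial-tf
Then run:
./Ascend-cann-tfplugin_5.0.2_linux-aarch64.run --install
After installation, add the plugin paths to PYTHONPATH in your environment variables. The original post shows this in a screenshot that is not reproduced here. Then run:
source ~/.bashrc
This completes the TensorFlow training environment for Ascend 910, and TF training can now begin.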
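The wheel name encodes where the wheel can be installed. `cp37-cp37m` means CPython 3.7 and `linux_aarch64` means 64-bit ARM Linux, so this file has to be installed with a Python 3.7 pip on an aarch64 machine. The following minimal stdlib sketch splits such a name and roughly compares it with the running interpreter. It only handles the simple `name-version-python-abi-platform.whl` form (no build tag) and is not a substitute for pip's own compatibility checks. The function names are hypothetical.

```python
import platform
import sys


def parse_wheel_name(filename):
    """Split 'name-version-pytag-abitag-platform.whl' into its parts (no build tag)."""
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel file: " + filename)
    parts = filename[:-len(".whl")].split("-")
    if len(parts) != 5:
        raise ValueError("unexpected wheel name: " + filename)
    name, version, py_tag, abi_tag, plat_tag = parts
    return {"name": name, "version": version, "python": py_tag,
            "abi": abi_tag, "platform": plat_tag}


def roughly_compatible(filename):
    """Rough check: CPython major/minor and CPU architecture must match."""
    tags = parse_wheel_name(filename)
    py_ok = tags["python"] == "cp%d%d" % sys.version_info[:2]
    arch_ok = tags["platform"].endswith(platform.machine())
    return py_ok and arch_ok


if __name__ == "__main__":
    whl = "tensorflow-1.15.0-cp37-cp37m-linux_aarch64.whl"
    print(parse_wheel_name(whl))
    print("installable here (rough check):", roughly_compatible(whl))
```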
-
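[Module] Installing TensorFlow
[Steps & symptoms]
1. Created a public image with MindSpore 1.1.1 and TensorFlow 1.15.0.
2. Upgraded pip and installed the dependencies it prompted for.
3. Installing tensorflow fails with: Could not find a version that satisfies the requirement tensorflow
4. Switching to other pip mirrors gives the same error.
[Screenshot]
[Log] (optional; upload log content or an attachment)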
-
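[Module]
[Steps & symptoms]
1. I am following the SDK sample project guide in the MindStudio 2.0.0 (beta3) documentation, which says to download a ResNet50 model for the TensorFlow framework. The link leads to a model for the MindSpore (1.2.0) framework on the official site, so I don't know which one to download.
2.
[Screenshot]
[Log] (optional; upload log content or an attachment)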
-
First I converted the .weights file into a TensorFlow 2.x checkpoint folder with https://github.com/hunglc007/tensorflow-yolov4-tflite. Next I converted the checkpoint folder into a pb file following section 2.2 of https://bbs.huaweicloud.com/forum/thread-92742-1-1.html. Finally I followed the tutorial at https://gitee.com/ascend/samples/tree/master/cplusplus/level2_simple_inference/2_object_detection/YOLOV4_coco_detection_picture to convert the pb to om, using this command:
atc --input_shape="Input:1,416,416,3" --output=./yolov4 --insert_op_conf=./insert_op.cfg --framework=3 --model=./WRN.pb --soc_version=Ascend310
Error:
ATC start working now, please wait for a moment.
[WARNING] TBE(4561,python3):2021-06-26-17:51:44.079.461 /home/ascend/Ascend/ascend-toolkit/5.0.2.alpha003/x86_64-linux/atc/python/site-packages/te/tvm/contrib/ccec.py:581: DeprecationWarning: invalid escape sequence \L
[... the same ccec.py warning from TBE processes 4554 to 4560 omitted ...]
[WARNING] TBE(4561,python3):2021-06-26-17:51:44.332.412 /home/ascend/Ascend/ascend-toolkit/latest/atc/python/site-packages/te/platform/cce_policy.py:42: DeprecationWarning: te.platform.cce_policy.OpImplPolicy is deprecated, use te_fusion.fusion_util.OpImplPolicy instead.
[WARNING] TBE(4561,python3):2021-06-26-17:51:44.343.208 /home/ascend/Ascend/ascend-toolkit/latest/atc/python/site-packages/te/platform/cce_conf.py:454: DeprecationWarning: 'from te.platform import te_set_version' is expired, please replace it with 'from tbe.common.platform import set_current_compile_soc_info'.
[... the same cce_policy.py and cce_conf.py warnings from TBE processes 4554 to 4560 omitted ...]
ATC run failed, Please check the detail log, Try 'atc --help' for more information
E12012: Input op[Input] not found in graph.
The pb file downloaded with wget https://nkxiaolei88.obs.cn-north-1.myhuaweicloud.com/ATC%20Model/YoloV4/yolov4_no_postprocess.pb converts to om without problems.
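E12012 means that a node name given in `--input_shape`, here `Input`, does not exist in the graph being converted. A pb exported through a different pipeline can name its input placeholder differently, so the name has to match the model actually passed to `--model`. The stdlib sketch below builds the `name:dims;name:dims` string used by the atc commands in this thread and rejects names that are not in a list of known input nodes. That list has to come from inspecting your own model, for example with summarize_graph as in the first post. The node names and the function name here are hypothetical.

```python
def build_input_shape(shapes, graph_inputs):
    """Build an atc --input_shape value such as "Input:1,416,416,3".

    shapes: mapping of input node name -> list of dimensions (order is kept).
    graph_inputs: node names actually present in the model.
    Raises ValueError for names missing from graph_inputs (the E12012 situation).
    """
    missing = [name for name in shapes if name not in graph_inputs]
    if missing:
        raise ValueError("input node(s) not in graph: " + ", ".join(missing))
    return ";".join(
        "%s:%s" % (name, ",".join(str(d) for d in dims))
        for name, dims in shapes.items()
    )


if __name__ == "__main__":
    # Hypothetical node list; replace it with the inputs reported for your own .pb
    known = {"input_1"}
    print(build_input_shape({"input_1": [1, 416, 416, 3]}, known))
```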
-
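[Module] InvertPermutation operator: computes the inverse permutation of a tensor using the rule y[x[i]] = i for i in [0, 1, ..., len(x) - 1]. For example, the input [3, 4, 0, 2, 1] gives the output [2, 4, 3, 0, 1].
[Steps & symptoms]
1. In performance tests, after adding multithreading, my operator matches the corresponding TensorFlow operator on int32 data. For int64 data, however, it runs 3% to 6% slower than the TensorFlow operator as the data size grows (64K, 1M, 2M, 8M, 128M).
2. How can I improve the operator's performance?
[Screenshot] The attachment contains the operator implementation files, the performance data, and a comparison between my code and the TF code.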
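A plain reference implementation of the y[x[i]] = i rule can serve as a correctness baseline when testing an optimized kernel. The sketch below is in Python. It also checks that the input is a valid permutation of 0..n-1. That check is an assumption about the desired error behavior and does not describe TensorFlow's kernel.

```python
def invert_permutation(x):
    """Return y with y[x[i]] = i for every i; x must be a permutation of 0..len(x)-1."""
    n = len(x)
    y = [-1] * n
    for i, v in enumerate(x):
        if not 0 <= v < n:
            raise ValueError("value %d out of range [0, %d)" % (v, n))
        if y[v] != -1:
            raise ValueError("duplicate value %d" % v)
        y[v] = i  # scatter: position v receives index i
    return y


if __name__ == "__main__":
    print(invert_permutation([3, 4, 0, 2, 1]))  # [2, 4, 3, 0, 1]
```

Because x is a permutation, each y[x[i]] is written exactly once. The index range of i can therefore be split across threads without write conflicts.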
-
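[Module] I am creating a new operator project in MindStudio with the "Tensorflow" framework selected, and I don't know what to choose for "operator type". The operator is a custom one that I am writing myself, so which operator type should I select? I read the MindStudio user guide but did not understand its explanation.
[Screenshot]
[Log] (optional; upload log content or an attachment)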
-
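[Module] TensorFlow has this operator, but I cannot find a functional description or reference code for it, so I can't model my implementation on TensorFlow's code. Could you provide a functional description of the operator to be developed and the reference code to benchmark against?
[Steps & symptoms]
1.
2.
[Screenshot]
[Log] (optional; upload log content or an attachment)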
-
[Module] contrib
[Steps & symptoms]
Error output:
root@ubuntu:~/cby/sj# python3.7.5 resnet.py
/usr/local/python3.7.5/lib/python3.7/site-packages/pandas/compat/__init__.py:97: UserWarning: Could not import the lzma module. Your installed Python is incomplete. Attempting to use lzma compression will result in a RuntimeError.
  warnings.warn(msg)
WARNING:tensorflow:
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Traceback (most recent call last):
  File "resnet.py", line 152, in <module>
    main()
  File "resnet.py", line 65, in main
    processed_data = np.load(INPUT_DATA)
  File "/usr/local/python3.7.5/lib/python3.7/site-packages/numpy/lib/npyio.py", line 453, in load
    pickle_kwargs=pickle_kwargs)
  File "/usr/local/python3.7.5/lib/python3.7/site-packages/numpy/lib/format.py", line 722, in read_array
    raise ValueError("Object arrays cannot be loaded when "
ValueError: Object arrays cannot be loaded when allow_pickle=False
root@ubuntu:~/cby/sj# pip3.7.5 list | grep tensorflow
tensorflow           1.15.0
tensorflow-estimator 1.15.1
root@ubuntu:~/cby/sj#

Code:
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
# Load the resnet_v1 model defined with slim
import tensorflow.contrib.slim.python.slim.nets.resnet_v1 as resnet_v1

# Data file
INPUT_DATA = "./flower_processed_data.npy"
# Where to save the trained model
TRAIN_FILE = "./save_model/my_model"
# Pretrained model provided by Google
CKPT_FILE = "./resnet_v1_50.ckpt"

# Training parameters
LEARNING_RATE = 0.0001
STEPS = 500
BATCH = 32
N_CLASSES = 5

# Parameters that should NOT be loaded from the pretrained model, i.e. the custom final fully connected layer
CHECKPOINT_EXCLUDE_SCOPES = 'Logits'
# Make the final fully connected layer trainable
TRAINABLE_SCOPES = 'Logits'


# Collect all parameters to be loaded from the pretrained model
def get_tuned_variables():
    # Scopes that should not be loaded
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")]
    # Parameters to be loaded
    variables_to_restore = []
    # Iterate over all parameters in the model
    for var in slim.get_model_variables():
        # Not excluded by default
        excluded = False
        # Mark the variable as excluded if it falls under one of the exclusions
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        # Keep the variable if it was not excluded
        if not excluded:
            variables_to_restore.append(var)
    return variables_to_restore


# Collect all parameters to be trained
def get_trainable_variables():
    # Same as above
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")]
    variables_to_train = []
    # Find all trainable variables under each scope prefix
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train


def main():
    # Load the data
    processed_data = np.load(INPUT_DATA)
    training_images = processed_data[0]
    n_training_example = len(training_images)
    training_labels = processed_data[1]
    validation_images = processed_data[2]
    validation_labels = processed_data[3]
    testing_images = processed_data[4]
    testing_labels = processed_data[5]
    print("there is %d training examples, %d validation examples, %d testing examples" % (
        n_training_example, len(validation_labels), len(testing_labels)))

    # Input placeholders
    images = tf.placeholder(tf.float32, [None, 300, 300, 3], name='input_images')
    labels = tf.placeholder(tf.int64, [None], name='labels')

    # Define the model: the checkpoint only holds parameters, so the model structure must be built here
    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        # images is the input; num_classes=None disables the ResNet's own output layer
        logits, _ = resnet_v1.resnet_v1_50(images, num_classes=None)

    # Custom output layer
    with tf.variable_scope("Logits"):
        # The pooled ResNet output has shape [batch, 1, 1, 2048];
        # squeeze out the two size-1 spatial axes (1 and 2) to get [batch, 2048]
        net = tf.squeeze(logits, axis=[1, 2])
        # Add a dropout layer
        net = slim.dropout(net, keep_prob=0.5, scope='dropout_scope')
        # Add a fully connected layer with the final output size
        logits = slim.fully_connected(net, num_outputs=N_CLASSES, scope='fc')

    # Variables to train
    trainable_variables = get_trainable_variables()
    # Loss; regularization was already taken into account when the model was defined
    tf.losses.softmax_cross_entropy(tf.one_hot(labels, N_CLASSES), logits, weights=1.0)
    # Training op
    train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss())

    # Evaluation for validation and testing
    with tf.name_scope('evaluation'):
        correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
        evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Function that restores the selected variables from the checkpoint file,
    # skipping variables that are missing from the checkpoint
    load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE, get_tuned_variables(), ignore_missing_vars=True)

    # Saver for the newly trained model
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Initialize the variables first; this must happen before the checkpoint is loaded,
        # otherwise the pretrained values would be overwritten
        init = tf.global_variables_initializer()
        sess.run(init)

        # Load the pretrained model
        print("Loading the model pretrained by Google...")
        load_fn(sess)

        start = 0
        end = BATCH
        for i in range(STEPS):
            # Train
            sess.run(train_step, feed_dict={images: training_images[start:end], labels: training_labels[start:end]})
            # Periodically save the model and evaluate on the validation set
            if i % 50 == 0 or i + 1 == STEPS:
                saver.save(sess, TRAIN_FILE, global_step=i)
                validation_accuracy = sess.run(evaluation_step, feed_dict={images: validation_images, labels: validation_labels})
                print("Validation accuracy after %d steps: %.3f" % (i, validation_accuracy))
            # Advance the batch window
            start = end
            if start == n_training_example:
                start = 0
            end = start + BATCH
            if end > n_training_example:
                end = n_training_example

        # After training, evaluate on the test set
        testing_accuracy = sess.run(evaluation_step, feed_dict={images: testing_images, labels: testing_labels})
        print("Final test accuracy: %.3f" % testing_accuracy)


if __name__ == '__main__':
    main()

Do I need to update the code? If so, how exactly? If I should downgrade TensorFlow instead, which version should I use, and where can I download it? The CPU is ARM-based.
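For reference, the ValueError in the traceback is raised by NumPy, not TensorFlow. Since NumPy 1.16.3, `np.load` defaults to `allow_pickle=False`, and an object array such as the one in this .npy file (the script reads six sub-arrays from it) can only be loaded with pickling enabled. The minimal sketch below reproduces the error and then applies the usual fix of passing `allow_pickle=True`. Only do this for files you trust, because unpickling can execute arbitrary code. The temporary demo.npy file stands in for flower_processed_data.npy, and the helper name is hypothetical.

```python
import os
import tempfile

import numpy as np


def load_object_npy(path):
    # allow_pickle=True is required for object arrays; only use it on trusted
    # files, because unpickling can execute arbitrary code.
    return np.load(path, allow_pickle=True)


if __name__ == "__main__":
    # Stand-in for flower_processed_data.npy: an object array of differently shaped arrays
    data = np.empty(2, dtype=object)
    data[0] = np.zeros((3, 2))
    data[1] = np.arange(5)
    path = os.path.join(tempfile.mkdtemp(), "demo.npy")
    np.save(path, data)  # np.save pickles object arrays by default

    try:
        np.load(path)  # default allow_pickle=False in NumPy >= 1.16.3
    except ValueError as err:
        print("without allow_pickle:", err)

    loaded = load_object_npy(path)
    print(loaded[0].shape, loaded[1].shape)  # (3, 2) (5,)
```

In resnet.py this corresponds to changing the load call to `np.load(INPUT_DATA, allow_pickle=True)`.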
-
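[Module] ST testing
[Steps & symptoms]
1. Running the ST test shows the prompt below (screenshot not reproduced).
2. When I checked the operator run time, I found that the TensorFlow operator had been executed instead of mine. How can I fix this?
[Log] (optional; upload log content or an attachment) The log is in the attachment.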
-
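[Module] A310 inference card
[Steps & symptoms]
1. Which TensorFlow version does the A310 inference card currently support?
[Screenshot]
[Log] (optional; upload log content or an attachment)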
-
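I previously trained YOLOv3 with Keras on Linux x86, which produced an .h5 file. I first converted the .h5 file into a TensorFlow .pb file and then used MindStudio to convert the .pb file into an .om file. The MindStudio conversion fails with the error below. Could an expert please help?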
-
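[Module] Where can I find the TensorFlow/PyTorch implementation code for the backward (gradient) operators? PyTorch only has the forward operator implementations.
[Screenshot] This is the backward-function implementation code from the reference documentation.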