tensorflow_cpp 1.0.6
graph_utils.h
/*
==============================================================================
MIT License
Copyright 2022 Institute for Automotive Engineering of RWTH Aachen University.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================
*/

#pragma once

#include <algorithm>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>

#include <tensorflow/core/platform/env.h>
#include <tensorflow_cpp/utils.h>  // utility functions for TensorFlow backend, provides createSession()

/**
 * @brief Namespace for tensorflow_cpp library.
 */
namespace tensorflow_cpp {


namespace tf = tensorflow;

/**
 * @brief Loads a TensorFlow graph from a frozen graph file.
 */
inline tf::GraphDef loadFrozenGraph(const std::string& file) {

  tf::GraphDef graph_def;
  tf::Status status = tf::ReadBinaryProto(tf::Env::Default(), file, &graph_def);
  if (!status.ok())
    throw std::runtime_error("Failed to load frozen graph: " +
                             status.ToString());

  return graph_def;
}

/**
 * @brief Loads a TensorFlow graph into an existing session.
 */
inline bool loadGraphIntoSession(tf::Session* session,
                                 const tf::GraphDef& graph_def) {

  tf::Status status = session->Create(graph_def);
  if (!status.ok())
    throw std::runtime_error("Failed to load graph into session: " +
                             status.ToString());

  return true;
}

/**
 * @brief Loads a TensorFlow graph from a frozen graph file into a new session.
 *
 * The session is created via createSession() from utils.h.
 */
inline tf::Session* loadFrozenGraphIntoNewSession(
    const std::string& file, const bool allow_growth = true,
    const double per_process_gpu_memory_fraction = 0,
    const std::string& visible_device_list = "") {

  tf::GraphDef graph_def = loadFrozenGraph(file);
  tf::Session* session = createSession(
      allow_growth, per_process_gpu_memory_fraction, visible_device_list);
  if (!loadGraphIntoSession(session, graph_def)) return nullptr;

  return session;
}

/**
 * @brief Determines the names of all graph input nodes.
 *
 * Input nodes are identified by their op-type "Placeholder".
 */
inline std::vector<std::string> getGraphInputNames(
    const tf::GraphDef& graph_def) {

  std::vector<std::string> input_nodes;
  for (const tf::NodeDef& node : graph_def.node()) {
    if (node.op() == "Placeholder") input_nodes.push_back(node.name());
  }

  return input_nodes;
}

/**
 * @brief Determines the names of all graph output nodes.
 *
 * Output nodes are identified as nodes that do not serve as input to any
 * other node and whose op-type is not an unlikely output op such as "Const".
 */
inline std::vector<std::string> getGraphOutputNames(
    const tf::GraphDef& graph_def) {

  std::vector<std::string> output_nodes;
  std::vector<std::string> nodes_with_outputs;
  std::unordered_set<std::string> unlikely_output_ops = {"Const", "Assign",
                                                         "NoOp", "Placeholder",
                                                         "Assert"};
  for (const tf::NodeDef& node : graph_def.node()) {
    for (const std::string& input_name : node.input())
      nodes_with_outputs.push_back(input_name);
  }
  for (const tf::NodeDef& node : graph_def.node()) {
    if (std::find(nodes_with_outputs.begin(), nodes_with_outputs.end(),
                  node.name()) == nodes_with_outputs.end() &&
        unlikely_output_ops.count(node.op()) == 0)
      output_nodes.push_back(node.name());
  }

  return output_nodes;
}

/**
 * @brief Determines the shape of a given graph node.
 *
 * Returns an empty vector if the node is not found or has no "shape"
 * attribute; unknown dimensions are reported as -1.
 */
inline std::vector<int> getGraphNodeShape(const tf::GraphDef& graph_def,
                                          const std::string& node_name) {

  std::vector<int> node_shape;
  for (const tf::NodeDef& node : graph_def.node()) {
    if (node.name() == node_name) {
      if (node.attr().count("shape") == 0) return node_shape;
      auto shape = node.attr().at("shape").shape();
      for (int d = 0; d < shape.dim_size(); d++)
        node_shape.push_back(shape.dim(d).size());
      break;
    }
  }

  return node_shape;
}

/**
 * @brief Determines the datatype of a given graph node.
 *
 * Returns tf::DT_INVALID if the node is not found or has no "dtype" attribute.
 */
inline tf::DataType getGraphNodeType(const tf::GraphDef& graph_def,
                                     const std::string& node_name) {

  tf::DataType type = tf::DT_INVALID;
  for (const tf::NodeDef& node : graph_def.node()) {
    if (node.name() == node_name) {
      if (node.attr().count("dtype") == 0) return type;
      type = node.attr().at("dtype").type();
      break;
    }
  }
  return type;
}

/**
 * @brief Builds a human-readable string listing all graph inputs and outputs
 * together with their shapes and datatypes.
 */
inline std::string getGraphInfoString(const tf::GraphDef& graph_def) {

  std::stringstream ss;
  ss << "FrozenGraph Info:" << std::endl;

  const std::vector<std::string> inputs = getGraphInputNames(graph_def);
  const std::vector<std::string> outputs = getGraphOutputNames(graph_def);

  ss << "Inputs: " << inputs.size() << std::endl;
  for (const auto& name : inputs) {
    const auto& shape = getGraphNodeShape(graph_def, name);
    const auto& dtype = getGraphNodeType(graph_def, name);
    ss << "  " << name << std::endl;
    ss << "    Shape: [ ";
    for (int d = 0; d < shape.size(); d++) {
      ss << shape[d] << ", ";
    }
    ss << "]" << std::endl;
    ss << "    DataType: " << tf::DataTypeString(dtype) << std::endl;
  }

  ss << "Outputs: " << outputs.size() << std::endl;
  for (const auto& name : outputs) {
    const auto& shape = getGraphNodeShape(graph_def, name);
    const auto& dtype = getGraphNodeType(graph_def, name);
    ss << "  " << name << std::endl;
    ss << "    Shape: [ ";
    for (int d = 0; d < shape.size(); d++) {
      ss << shape[d] << ", ";
    }
    ss << "]" << std::endl;
    ss << "    DataType: " << tf::DataTypeString(dtype) << std::endl;
  }

  return ss.str();
}


}  // namespace tensorflow_cpp
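
Usage sketch: the snippet below shows how the functions above are typically combined to inspect a frozen graph and load it into a session. The include path and the model file name "model.pb" are illustrative assumptions; everything else uses only functions defined in this header.

#include <iostream>
#include <string>

#include <tensorflow/core/public/session.h>

#include <tensorflow_cpp/graph_utils.h>  // include path assumed for this sketch

int main() {
  const std::string file = "model.pb";  // hypothetical frozen graph file

  // read the graph definition and print a summary of its inputs and outputs
  tensorflow::GraphDef graph_def = tensorflow_cpp::loadFrozenGraph(file);
  std::cout << tensorflow_cpp::getGraphInfoString(graph_def);

  // load the graph into a new session (the file is read again internally)
  tensorflow::Session* session = tensorflow_cpp::loadFrozenGraphIntoNewSession(file);
  if (session == nullptr) return 1;

  // ... run inference with the session (see the next sketch) ...

  session->Close();
  delete session;
  return 0;
}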
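
The node names discovered by getGraphInputNames() and getGraphOutputNames() can be passed straight to tensorflow::Session::Run(). A minimal sketch, assuming a model with a single float input of shape [1, 3] and a single output node; the file name, shape, and datatype are assumptions about the model, not part of this header.

#include <iostream>
#include <string>
#include <vector>

#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/public/session.h>

#include <tensorflow_cpp/graph_utils.h>  // include path assumed for this sketch

int main() {
  const std::string file = "model.pb";  // hypothetical frozen graph file

  tensorflow::GraphDef graph_def = tensorflow_cpp::loadFrozenGraph(file);
  tensorflow::Session* session = tensorflow_cpp::loadFrozenGraphIntoNewSession(file);

  // discover input/output node names from the graph itself
  const std::string input_name = tensorflow_cpp::getGraphInputNames(graph_def)[0];
  const std::string output_name = tensorflow_cpp::getGraphOutputNames(graph_def)[0];

  // dummy input tensor; shape [1, 3] and DT_FLOAT are assumptions about the model
  tensorflow::Tensor input(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 3}));
  input.flat<float>().setZero();

  // feed the input node, fetch the output node
  std::vector<tensorflow::Tensor> outputs;
  tensorflow::Status status =
      session->Run({{input_name, input}}, {output_name}, {}, &outputs);
  if (!status.ok()) {
    std::cerr << "Inference failed: " << status.ToString() << std::endl;
    return 1;
  }
  std::cout << outputs[0].DebugString() << std::endl;

  session->Close();
  delete session;
  return 0;
}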