36#include <tensorflow/core/platform/env.h>
43namespace tf = tensorflow;
55 tf::GraphDef graph_def;
56 tf::Status status = tf::ReadBinaryProto(tf::Env::Default(), file, &graph_def);
58 throw std::runtime_error(
"Failed to load frozen graph: " +
75 const tf::GraphDef& graph_def) {
77 tf::Status status = session->Create(graph_def);
79 throw std::runtime_error(
"Failed to load graph into session: " +
98 const std::string& file,
const bool allow_growth =
true,
99 const double per_process_gpu_memory_fraction = 0,
100 const std::string& visible_device_list =
"") {
104 allow_growth, per_process_gpu_memory_fraction, visible_device_list);
119 const tf::GraphDef& graph_def) {
121 std::vector<std::string> input_nodes;
122 for (
const tf::NodeDef& node : graph_def.node()) {
123 if (node.op() ==
"Placeholder") input_nodes.push_back(node.name());
138 const tf::GraphDef& graph_def) {
140 std::vector<std::string> output_nodes;
141 std::vector<std::string> nodes_with_outputs;
142 std::unordered_set<std::string> unlikely_output_ops = {
"Const",
"Assign",
143 "NoOp",
"Placeholder",
145 for (
const tf::NodeDef& node : graph_def.node()) {
146 for (
const std::string& input_name : node.input())
147 nodes_with_outputs.push_back(input_name);
149 for (
const tf::NodeDef& node : graph_def.node()) {
150 if (std::find(nodes_with_outputs.begin(), nodes_with_outputs.end(),
151 node.name()) == nodes_with_outputs.end() &&
152 unlikely_output_ops.count(node.op()) == 0)
153 output_nodes.push_back(node.name());
169 const std::string& node_name) {
171 std::vector<int> node_shape;
172 for (
const tf::NodeDef& node : graph_def.node()) {
173 if (node.name() == node_name) {
174 if (node.attr().count(
"shape") == 0)
return node_shape;
175 auto shape = node.attr().at(
"shape").shape();
176 for (
int d = 0; d < shape.dim_size(); d++)
177 node_shape.push_back(shape.dim(d).size());
195 const std::string& node_name) {
197 tf::DataType type = tf::DT_INVALID;
198 for (
const tf::NodeDef& node : graph_def.node()) {
199 if (node.name() == node_name) {
200 if (node.attr().count(
"dtype") == 0)
return type;
201 type = node.attr().at(
"dtype").type();
223 std::stringstream ss;
224 ss <<
"FrozenGraph Info:" << std::endl;
229 ss <<
"Inputs: " << inputs.size() << std::endl;
230 for (
const auto& name : inputs) {
233 ss <<
" " << name << std::endl;
235 for (
int d = 0; d < shape.size(); d++) {
236 ss << shape[d] <<
", ";
238 ss <<
"]" << std::endl;
239 ss <<
" DataType: " << tf::DataTypeString(dtype) << std::endl;
242 ss <<
"Outputs: " << outputs.size() << std::endl;
243 for (
const auto& name : outputs) {
246 ss <<
" " << name << std::endl;
248 for (
int d = 0; d < shape.size(); d++) {
249 ss << shape[d] <<
", ";
251 ss <<
"]" << std::endl;
252 ss <<
" DataType: " << tf::DataTypeString(dtype) << std::endl;
Namespace for tensorflow_cpp library.
std::vector< std::string > getGraphOutputNames(const tf::GraphDef &graph_def)
Determines the names of all graph output nodes.
std::string getGraphInfoString(const tf::GraphDef &graph_def)
tf::GraphDef loadFrozenGraph(const std::string &file)
Loads a TensorFlow graph from a frozen graph file.
std::vector< int > getGraphNodeShape(const tf::GraphDef &graph_def, const std::string &node_name)
Determines the shape of a given graph node.
tf::Session * loadFrozenGraphIntoNewSession(const std::string &file, const bool allow_growth=true, const double per_process_gpu_memory_fraction=0, const std::string &visible_device_list="")
Loads a TensorFlow graph from a frozen graph file into a new session.
std::vector< std::string > getGraphInputNames(const tf::GraphDef &graph_def)
Determines the names of all graph input nodes.
bool loadGraphIntoSession(tf::Session *session, const tf::GraphDef &graph_def)
Loads a TensorFlow graph into an existing session.
tf::DataType getGraphNodeType(const tf::GraphDef &graph_def, const std::string &node_name)
Determines the datatype of a given graph node.
tf::Session * createSession(const bool allow_growth=true, const double per_process_gpu_memory_fraction=0, const std::string &visible_device_list="")
Creates a new TensorFlow session.
Utility functions for TensorFlow backend.