tensorflow_cpp 1.0.6
Loading...
Searching...
No Matches
Classes | Functions
tensorflow_cpp Namespace Reference

Namespace for tensorflow_cpp library. More...

Classes

class  Model
 Wrapper class for running TensorFlow SavedModels or FrozenGraphs. More...
 

Functions

tf::GraphDef loadFrozenGraph (const std::string &file)
 Loads a TensorFlow graph from a frozen graph file.
 
bool loadGraphIntoSession (tf::Session *session, const tf::GraphDef &graph_def)
 Loads a TensorFlow graph into an existing session.
 
tf::Session * loadFrozenGraphIntoNewSession (const std::string &file, const bool allow_growth=true, const double per_process_gpu_memory_fraction=0, const std::string &visible_device_list="")
 Loads a TensorFlow graph from a frozen graph file into a new session.
 
std::vector< std::string > getGraphInputNames (const tf::GraphDef &graph_def)
 Determines the names of all graph input nodes.
 
std::vector< std::string > getGraphOutputNames (const tf::GraphDef &graph_def)
 Determines the names of all graph output nodes.
 
std::vector< int > getGraphNodeShape (const tf::GraphDef &graph_def, const std::string &node_name)
 Determines the shape of a given graph node.
 
tf::DataType getGraphNodeType (const tf::GraphDef &graph_def, const std::string &node_name)
 Determines the datatype of a given graph node.
 
std::string getGraphInfoString (const tf::GraphDef &graph_def)
 Returns information about a FrozenGraph model.
 
tf::SavedModelBundleLite loadSavedModel (const std::string &dir, const bool allow_growth=true, const double per_process_gpu_memory_fraction=0, const std::string &visible_device_list="")
 Loads a TensorFlow SavedModel from a directory into a new session.
 
tf::Session * loadSavedModelIntoNewSession (const std::string &dir, const bool allow_growth=true, const double per_process_gpu_memory_fraction=0, const std::string &visible_device_list="")
 Loads a TensorFlow SavedModel from a directory into a new session.
 
tf::Session * getSessionFromSavedModel (const tf::SavedModelBundleLite &saved_model)
 Returns the session that a SavedModel is loaded in.
 
std::string getSavedModelNodeByLayerName (const tf::SavedModelBundleLite &saved_model, const std::string &layer_name, const std::string &signature="serving_default")
 Determines the node name from a SavedModel layer name.
 
std::string getSavedModelLayerByNodeName (const tf::SavedModelBundleLite &saved_model, const std::string &node_name, const std::string &signature="serving_default")
 Determines the layer name from a SavedModel node name.
 
std::vector< std::string > getSavedModelInputNames (const tf::SavedModelBundleLite &saved_model, const bool layer_names=false, const std::string &signature="serving_default")
 Determines the names of the SavedModel input nodes.
 
std::vector< std::string > getSavedModelOutputNames (const tf::SavedModelBundleLite &saved_model, const bool layer_names=false, const std::string &signature="serving_default")
 Determines the names of the SavedModel output nodes.
 
std::vector< int > getSavedModelNodeShape (const tf::SavedModelBundleLite &saved_model, const std::string &node_name, const std::string &signature="serving_default")
 Determines the shape of a given SavedModel node.
 
tf::DataType getSavedModelNodeType (const tf::SavedModelBundleLite &saved_model, const std::string &node_name, const std::string &signature="serving_default")
 Determines the datatype of a given SavedModel node.
 
std::string getSavedModelInfoString (const tf::SavedModelBundleLite &saved_model)
 Returns information about a SavedModel model.
 
tf::SessionOptions makeSessionOptions (const bool allow_growth=true, const double per_process_gpu_memory_fraction=0, const std::string &visible_device_list="")
 Helps to quickly create SessionOptions.
 
tf::Session * createSession (const bool allow_growth=true, const double per_process_gpu_memory_fraction=0, const std::string &visible_device_list="")
 Creates a new TensorFlow session.
 

Detailed Description

Namespace for tensorflow_cpp library.

Function Documentation

◆ createSession()

tf::Session * tensorflow_cpp::createSession ( const bool allow_growth = true,
const double per_process_gpu_memory_fraction = 0,
const std::string & visible_device_list = "" )
inline

Creates a new TensorFlow session.

Parameters
[in] allow_growth  dynamically grow GPU usage
[in] per_process_gpu_memory_fraction  maximum GPU memory fraction
[in] visible_device_list  list of GPUs to use, e.g. "0,1"
Returns
tf::Session* session

Definition at line 78 of file utils.h.

81 {
82
83 tf::Session* session;
84 tf::SessionOptions options = makeSessionOptions(
85 allow_growth, per_process_gpu_memory_fraction, visible_device_list);
86 tf::Status status = tf::NewSession(options, &session);
87 if (!status.ok())
88 throw std::runtime_error("Failed to create new session: " +
89 status.ToString());
90
91 return session;
92}
tf::SessionOptions makeSessionOptions(const bool allow_growth=true, const double per_process_gpu_memory_fraction=0, const std::string &visible_device_list="")
Helps to quickly create SessionOptions.
Definition utils.h:52

◆ getGraphInfoString()

std::string tensorflow_cpp::getGraphInfoString ( const tf::GraphDef & graph_def)
inline

Returns information about a FrozenGraph model.

Returns a formatted message containing information about the shape and type of all inputs/outputs of a FrozenGraph.

Currently limited to single-output graphs.

Parameters
[in] graph_def  graph
Returns
std::string formatted info message

Definition at line 221 of file graph_utils.h.

221 {
222
223 std::stringstream ss;
224 ss << "FrozenGraph Info:" << std::endl;
225
226 const std::vector<std::string> inputs = getGraphInputNames(graph_def);
227 const std::vector<std::string> outputs = getGraphOutputNames(graph_def);
228
229 ss << "Inputs: " << inputs.size() << std::endl;
230 for (const auto& name : inputs) {
231 const auto& shape = getGraphNodeShape(graph_def, name);
232 const auto& dtype = getGraphNodeType(graph_def, name);
233 ss << " " << name << std::endl;
234 ss << " Shape: [ ";
235 for (int d = 0; d < shape.size(); d++) {
236 ss << shape[d] << ", ";
237 }
238 ss << "]" << std::endl;
239 ss << " DataType: " << tf::DataTypeString(dtype) << std::endl;
240 }
241
242 ss << "Outputs: " << outputs.size() << std::endl;
243 for (const auto& name : outputs) {
244 const auto& shape = getGraphNodeShape(graph_def, name);
245 const auto& dtype = getGraphNodeType(graph_def, name);
246 ss << " " << name << std::endl;
247 ss << " Shape: [ ";
248 for (int d = 0; d < shape.size(); d++) {
249 ss << shape[d] << ", ";
250 }
251 ss << "]" << std::endl;
252 ss << " DataType: " << tf::DataTypeString(dtype) << std::endl;
253 }
254
255 return ss.str();
256}
std::vector< std::string > getGraphOutputNames(const tf::GraphDef &graph_def)
Determines the names of all graph output nodes.
std::vector< int > getGraphNodeShape(const tf::GraphDef &graph_def, const std::string &node_name)
Determines the shape of a given graph node.
std::vector< std::string > getGraphInputNames(const tf::GraphDef &graph_def)
Determines the names of all graph input nodes.
tf::DataType getGraphNodeType(const tf::GraphDef &graph_def, const std::string &node_name)
Determines the datatype of a given graph node.

◆ getGraphInputNames()

std::vector< std::string > tensorflow_cpp::getGraphInputNames ( const tf::GraphDef & graph_def)
inline

Determines the names of all graph input nodes.

Parameters
[in] graph_def  graph
Returns
std::vector<std::string> list of input node names

Definition at line 118 of file graph_utils.h.

119 {
120
121 std::vector<std::string> input_nodes;
122 for (const tf::NodeDef& node : graph_def.node()) {
123 if (node.op() == "Placeholder") input_nodes.push_back(node.name());
124 }
125
126 return input_nodes;
127}

◆ getGraphNodeShape()

std::vector< int > tensorflow_cpp::getGraphNodeShape ( const tf::GraphDef & graph_def,
const std::string & node_name )
inline

Determines the shape of a given graph node.

Parameters
[in] graph_def  graph
[in] node_name  node name
Returns
std::vector<int> node shape

Definition at line 168 of file graph_utils.h.

169 {
170
171 std::vector<int> node_shape;
172 for (const tf::NodeDef& node : graph_def.node()) {
173 if (node.name() == node_name) {
174 if (node.attr().count("shape") == 0) return node_shape;
175 auto shape = node.attr().at("shape").shape();
176 for (int d = 0; d < shape.dim_size(); d++)
177 node_shape.push_back(shape.dim(d).size());
178 break;
179 }
180 }
181
182 return node_shape;
183}

◆ getGraphNodeType()

tf::DataType tensorflow_cpp::getGraphNodeType ( const tf::GraphDef & graph_def,
const std::string & node_name )
inline

Determines the datatype of a given graph node.

Parameters
[in] graph_def  graph
[in] node_name  node name
Returns
tf::DataType node datatype

Definition at line 194 of file graph_utils.h.

195 {
196
197 tf::DataType type = tf::DT_INVALID;
198 for (const tf::NodeDef& node : graph_def.node()) {
199 if (node.name() == node_name) {
200 if (node.attr().count("dtype") == 0) return type;
201 type = node.attr().at("dtype").type();
202 break;
203 }
204 }
205 return type;
206}

◆ getGraphOutputNames()

std::vector< std::string > tensorflow_cpp::getGraphOutputNames ( const tf::GraphDef & graph_def)
inline

Determines the names of all graph output nodes.

Parameters
[in] graph_def  graph
Returns
std::vector<std::string> list of output node names

Definition at line 137 of file graph_utils.h.

138 {
139
140 std::vector<std::string> output_nodes;
141 std::vector<std::string> nodes_with_outputs;
142 std::unordered_set<std::string> unlikely_output_ops = {"Const", "Assign",
143 "NoOp", "Placeholder",
144 "Assert"};
145 for (const tf::NodeDef& node : graph_def.node()) {
146 for (const std::string& input_name : node.input())
147 nodes_with_outputs.push_back(input_name);
148 }
149 for (const tf::NodeDef& node : graph_def.node()) {
150 if (std::find(nodes_with_outputs.begin(), nodes_with_outputs.end(),
151 node.name()) == nodes_with_outputs.end() &&
152 unlikely_output_ops.count(node.op()) == 0)
153 output_nodes.push_back(node.name());
154 }
155
156 return output_nodes;
157}

◆ getSavedModelInfoString()

std::string tensorflow_cpp::getSavedModelInfoString ( const tf::SavedModelBundleLite & saved_model)
inline

Returns information about a SavedModel model.

Returns a formatted message containing information about the shape and type of all inputs/outputs of all SavedModel signatures.

Parameters
[in] saved_model  SavedModel
Returns
std::string formatted info message

Definition at line 340 of file saved_model_utils.h.

341 {
342
343 std::stringstream ss;
344 ss << "SavedModel Info:" << std::endl;
345
346 ss << "Signatures:" << std::endl;
347 const auto& signatures = saved_model.GetSignatures();
348 for (const auto& sig : signatures) {
349
350 ss << " " << sig.first << std::endl;
351 const auto& def = sig.second;
352
353 ss << " Inputs: " << def.inputs_size() << std::endl;
354 for (const auto& node : def.inputs()) {
355 ss << " " << node.first << ": " << node.second.name() << std::endl;
356 ss << " Shape: [ ";
357 for (int d = 0; d < node.second.tensor_shape().dim_size(); d++) {
358 ss << node.second.tensor_shape().dim(d).size() << ", ";
359 }
360 ss << "]" << std::endl;
361 ss << " DataType: " << tf::DataTypeString(node.second.dtype())
362 << std::endl;
363 }
364
365 ss << " Outputs: " << def.outputs_size() << std::endl;
366 for (const auto& node : def.outputs()) {
367 ss << " " << node.first << ": " << node.second.name() << std::endl;
368 ss << " Shape: [ ";
369 for (int d = 0; d < node.second.tensor_shape().dim_size(); d++) {
370 ss << node.second.tensor_shape().dim(d).size() << ", ";
371 }
372 ss << "]" << std::endl;
373 ss << " DataType: " << tf::DataTypeString(node.second.dtype())
374 << std::endl;
375 }
376 }
377
378 return ss.str();
379}

◆ getSavedModelInputNames()

std::vector< std::string > tensorflow_cpp::getSavedModelInputNames ( const tf::SavedModelBundleLite & saved_model,
const bool layer_names = false,
const std::string & signature = "serving_default" )
inline

Determines the names of the SavedModel input nodes.

These are the names that need to be passed to session->Run. Alternatively, using layer_names, the layer names can be returned.

Returned names are sorted alphabetically, since their order is not deterministic in general. The sorting is always based on the actual node names, even when returning layer names.

Parameters
[in] saved_model  SavedModel
[in] layer_names  whether to return layer names
[in] signature  SavedModel signature to query
Returns
std::vector<std::string> input names

Definition at line 198 of file saved_model_utils.h.

200 {
201
202 std::vector<std::string> names;
203 const tf::SignatureDef& model_def = saved_model.GetSignatures().at(signature);
204 for (const auto& node : model_def.inputs()) {
205 const std::string& key = node.first;
206 const tf::TensorInfo& info = node.second;
207 names.push_back(info.name());
208 }
209 std::sort(names.begin(), names.end());
210
211 if (layer_names) {
212 std::vector<std::string> node_names = names;
213 names = {};
214 for (const auto& node_name : node_names)
215 names.push_back(
216 getSavedModelLayerByNodeName(saved_model, node_name, signature));
217 }
218
219 return names;
220}
std::string getSavedModelLayerByNodeName(const tf::SavedModelBundleLite &saved_model, const std::string &node_name, const std::string &signature="serving_default")
Determines the layer name from a SavedModel node name.

◆ getSavedModelLayerByNodeName()

std::string tensorflow_cpp::getSavedModelLayerByNodeName ( const tf::SavedModelBundleLite & saved_model,
const std::string & node_name,
const std::string & signature = "serving_default" )
inline

Determines the layer name from a SavedModel node name.

Layer names are specified during model construction, node names must be passed to session->Run.

Parameters
[in] saved_model  SavedModel
[in] node_name  node name
[in] signature  SavedModel signature to query
Returns
std::string layer name

Definition at line 159 of file saved_model_utils.h.

161 {
162
163 std::string layer_name;
164 const tf::SignatureDef& model_def = saved_model.GetSignatures().at(signature);
165 auto inputs = model_def.inputs();
166 auto outputs = model_def.outputs();
167 auto& nodes = inputs;
168 nodes.insert(outputs.begin(), outputs.end());
169 for (const auto& node : nodes) {
170 const std::string& key = node.first;
171 const tf::TensorInfo& info = node.second;
172 if (info.name() == node_name) {
173 layer_name = key;
174 break;
175 }
176 }
177
178 return layer_name;
179}

◆ getSavedModelNodeByLayerName()

std::string tensorflow_cpp::getSavedModelNodeByLayerName ( const tf::SavedModelBundleLite & saved_model,
const std::string & layer_name,
const std::string & signature = "serving_default" )
inline

Determines the node name from a SavedModel layer name.

Layer names are specified during model construction, node names must be passed to session->Run.

Parameters
[in] saved_model  SavedModel
[in] layer_name  layer name
[in] signature  SavedModel signature to query
Returns
std::string node name

Definition at line 124 of file saved_model_utils.h.

126 {
127
128 std::string node_name;
129 const tf::SignatureDef& model_def = saved_model.GetSignatures().at(signature);
130 auto inputs = model_def.inputs();
131 auto outputs = model_def.outputs();
132 auto& nodes = inputs;
133 nodes.insert(outputs.begin(), outputs.end());
134 for (const auto& node : nodes) {
135 const std::string& key = node.first;
136 const tf::TensorInfo& info = node.second;
137 if (key == layer_name) {
138 node_name = info.name();
139 break;
140 }
141 }
142
143 return node_name;
144}

◆ getSavedModelNodeShape()

std::vector< int > tensorflow_cpp::getSavedModelNodeShape ( const tf::SavedModelBundleLite & saved_model,
const std::string & node_name,
const std::string & signature = "serving_default" )
inline

Determines the shape of a given SavedModel node.

Parameters
[in] saved_model  SavedModel
[in] node_name  node name
[in] signature  SavedModel signature to query
Returns
std::vector<int> node shape

Definition at line 273 of file saved_model_utils.h.

275 {
276
277 std::vector<int> node_shape;
278 const tf::SignatureDef& model_def = saved_model.GetSignatures().at(signature);
279 auto inputs = model_def.inputs();
280 auto outputs = model_def.outputs();
281 auto& nodes = inputs;
282 nodes.insert(outputs.begin(), outputs.end());
283 for (const auto& node : nodes) {
284 const std::string& key = node.first;
285 const tf::TensorInfo& info = node.second;
286 if (info.name() == node_name) {
287 const auto& shape = info.tensor_shape();
288 for (int d = 0; d < shape.dim_size(); d++)
289 node_shape.push_back(shape.dim(d).size());
290 break;
291 }
292 }
293
294 return node_shape;
295}

◆ getSavedModelNodeType()

tf::DataType tensorflow_cpp::getSavedModelNodeType ( const tf::SavedModelBundleLite & saved_model,
const std::string & node_name,
const std::string & signature = "serving_default" )
inline

Determines the datatype of a given SavedModel node.

Parameters
[in] saved_model  SavedModel
[in] node_name  node name
[in] signature  SavedModel signature to query
Returns
tf::DataType node datatype

Definition at line 307 of file saved_model_utils.h.

309 {
310
311 tf::DataType type = tf::DT_INVALID;
312 const tf::SignatureDef& model_def = saved_model.GetSignatures().at(signature);
313 auto inputs = model_def.inputs();
314 auto outputs = model_def.outputs();
315 auto& nodes = inputs;
316 nodes.insert(outputs.begin(), outputs.end());
317 for (const auto& node : nodes) {
318 const std::string& key = node.first;
319 const tf::TensorInfo& info = node.second;
320 if (info.name() == node_name) {
321 type = info.dtype();
322 break;
323 }
324 }
325
326 return type;
327}

◆ getSavedModelOutputNames()

std::vector< std::string > tensorflow_cpp::getSavedModelOutputNames ( const tf::SavedModelBundleLite & saved_model,
const bool layer_names = false,
const std::string & signature = "serving_default" )
inline

Determines the names of the SavedModel output nodes.

These are the names that need to be passed to session->Run. Alternatively, using layer_names, the layer names can be returned.

Returned names are sorted alphabetically, since their order is not deterministic in general. The sorting is always based on the actual node names, even when returning layer names.

Parameters
[in] saved_model  SavedModel
[in] layer_names  whether to return layer names
[in] signature  SavedModel signature to query
Returns
std::vector<std::string> output names

Definition at line 239 of file saved_model_utils.h.

241 {
242
243 std::vector<std::string> names;
244 const tf::SignatureDef& model_def = saved_model.GetSignatures().at(signature);
245 for (const auto& node : model_def.outputs()) {
246 const std::string& key = node.first;
247 const tf::TensorInfo& info = node.second;
248 names.push_back(info.name());
249 }
250 std::sort(names.begin(), names.end());
251
252 if (layer_names) {
253 std::vector<std::string> node_names = names;
254 names = {};
255 for (const auto& node_name : node_names)
256 names.push_back(
257 getSavedModelLayerByNodeName(saved_model, node_name, signature));
258 }
259
260 return names;
261}

◆ getSessionFromSavedModel()

tf::Session * tensorflow_cpp::getSessionFromSavedModel ( const tf::SavedModelBundleLite & saved_model)
inline

Returns the session that a SavedModel is loaded in.

Parameters
[in] saved_model  SavedModel
Returns
tf::Session* session

Definition at line 105 of file saved_model_utils.h.

106 {
107
108 return saved_model.GetSession();
109}

◆ loadFrozenGraph()

tf::GraphDef tensorflow_cpp::loadFrozenGraph ( const std::string & file)
inline

Loads a TensorFlow graph from a frozen graph file.

Parameters
[in] file  frozen graph file
Returns
tf::GraphDef graph

Definition at line 53 of file graph_utils.h.

53 {
54
55 tf::GraphDef graph_def;
56 tf::Status status = tf::ReadBinaryProto(tf::Env::Default(), file, &graph_def);
57 if (!status.ok())
58 throw std::runtime_error("Failed to load frozen graph: " +
59 status.ToString());
60
61 return graph_def;
62}

◆ loadFrozenGraphIntoNewSession()

tf::Session * tensorflow_cpp::loadFrozenGraphIntoNewSession ( const std::string & file,
const bool allow_growth = true,
const double per_process_gpu_memory_fraction = 0,
const std::string & visible_device_list = "" )
inline

Loads a TensorFlow graph from a frozen graph file into a new session.

Parameters
[in] file  frozen graph file
[in] allow_growth  dynamically grow GPU usage
[in] per_process_gpu_memory_fraction  maximum GPU memory fraction
[in] visible_device_list  list of GPUs to use, e.g. "0,1"
Returns
tf::Session* session

Definition at line 97 of file graph_utils.h.

100 {
101
102 tf::GraphDef graph_def = loadFrozenGraph(file);
103 tf::Session* session = createSession(
104 allow_growth, per_process_gpu_memory_fraction, visible_device_list);
105 if (!loadGraphIntoSession(session, graph_def)) return nullptr;
106
107 return session;
108}
tf::GraphDef loadFrozenGraph(const std::string &file)
Loads a TensorFlow graph from a frozen graph file.
Definition graph_utils.h:53
bool loadGraphIntoSession(tf::Session *session, const tf::GraphDef &graph_def)
Loads a TensorFlow graph into an existing session.
Definition graph_utils.h:74
tf::Session * createSession(const bool allow_growth=true, const double per_process_gpu_memory_fraction=0, const std::string &visible_device_list="")
Creates a new TensorFlow session.
Definition utils.h:78

◆ loadGraphIntoSession()

bool tensorflow_cpp::loadGraphIntoSession ( tf::Session * session,
const tf::GraphDef & graph_def )
inline

Loads a TensorFlow graph into an existing session.

Parameters
[in] session  session
[in] graph_def  graph
Returns
true if operation succeeded
false if operation failed

Definition at line 74 of file graph_utils.h.

75 {
76
77 tf::Status status = session->Create(graph_def);
78 if (!status.ok())
79 throw std::runtime_error("Failed to load graph into session: " +
80 status.ToString());
81
82 return true;
83}

◆ loadSavedModel()

tf::SavedModelBundleLite tensorflow_cpp::loadSavedModel ( const std::string & dir,
const bool allow_growth = true,
const double per_process_gpu_memory_fraction = 0,
const std::string & visible_device_list = "" )
inline

Loads a TensorFlow SavedModel from a directory into a new session.

Parameters
[in] dir  SavedModel directory
[in] allow_growth  dynamically grow GPU usage
[in] per_process_gpu_memory_fraction  maximum GPU memory fraction
[in] visible_device_list  list of GPUs to use, e.g. "0,1"
Returns
tf::SavedModelBundleLite SavedModel

Definition at line 57 of file saved_model_utils.h.

60 {
61
62 tf::SavedModelBundleLite saved_model;
63 tf::SessionOptions session_options = makeSessionOptions(
64 allow_growth, per_process_gpu_memory_fraction, visible_device_list);
65 tf::Status status =
66 tf::LoadSavedModel(session_options, tf::RunOptions(), dir,
67 {tf::kSavedModelTagServe}, &saved_model);
68 if (!status.ok())
69 throw std::runtime_error("Failed to load SavedModel: " + status.ToString());
70
71 return saved_model;
72}

◆ loadSavedModelIntoNewSession()

tf::Session * tensorflow_cpp::loadSavedModelIntoNewSession ( const std::string & dir,
const bool allow_growth = true,
const double per_process_gpu_memory_fraction = 0,
const std::string & visible_device_list = "" )
inline

Loads a TensorFlow SavedModel from a directory into a new session.

Parameters
[in] dir  SavedModel directory
[in] allow_growth  dynamically grow GPU usage
[in] per_process_gpu_memory_fraction  maximum GPU memory fraction
[in] visible_device_list  list of GPUs to use, e.g. "0,1"
Returns
tf::Session* session

Definition at line 85 of file saved_model_utils.h.

88 {
89
90 tf::SavedModelBundleLite saved_model = loadSavedModel(
91 dir, allow_growth, per_process_gpu_memory_fraction, visible_device_list);
92 tf::Session* session = saved_model.GetSession();
93
94 return session;
95}
tf::SavedModelBundleLite loadSavedModel(const std::string &dir, const bool allow_growth=true, const double per_process_gpu_memory_fraction=0, const std::string &visible_device_list="")
Loads a TensorFlow SavedModel from a directory into a new session.

◆ makeSessionOptions()

tf::SessionOptions tensorflow_cpp::makeSessionOptions ( const bool allow_growth = true,
const double per_process_gpu_memory_fraction = 0,
const std::string & visible_device_list = "" )
inline

Helps to quickly create SessionOptions.

Parameters
[in] allow_growth  dynamically grow GPU usage
[in] per_process_gpu_memory_fraction  maximum GPU memory fraction
[in] visible_device_list  list of GPUs to use, e.g. "0,1"
Returns
tf::SessionOptions session options

Definition at line 52 of file utils.h.

55 {
56
57 tf::SessionOptions options = tf::SessionOptions();
58 tf::ConfigProto* config = &options.config;
59 tf::GPUOptions* gpu_options = config->mutable_gpu_options();
60 gpu_options->set_allow_growth(allow_growth);
61 gpu_options->set_per_process_gpu_memory_fraction(
62 per_process_gpu_memory_fraction);
63 gpu_options->set_visible_device_list(visible_device_list);
64
65 return options;
66}