TrainLoop
-
class tfsnippet.TrainLoop(param_vars, var_groups=None, show_eta=True, print_func=<built-in function print>, summary_dir=None, summary_writer=None, summary_graph=None, summary_metric_prefix='metrics/', summary_skip_pattern=<_sre.SRE_Pattern object>, summary_commit_freqs=None, metric_formatter=<tfsnippet.scaffold.logs.DefaultMetricFormatter object>, valid_metric_name='valid_loss', initial_valid_metric=None, valid_metric_smaller_is_better=None, early_stopping=False, initial_epoch=0, initial_step=0, max_epoch=None, max_step=None)
Bases: tfsnippet.utils.concepts.DisposableContext
Training loop object.
This class provides a set of convenient methods for writing a training loop. It is useful for maintaining epoch and step counters, logging training metrics, memorizing the best parameters for early-stopping, etc. An example of using TrainLoop:

    from tfsnippet.dataflows import DataFlow
    from tfsnippet.scaffold import TrainLoop

    with TrainLoop(param_vars, max_epoch=10, early_stopping=True) as loop:
        loop.print_training_summary()
        train_flow = DataFlow.arrays([x, y], batch_size, shuffle=True)

        for epoch in loop.iter_epochs():
            for step, (batch_x, batch_y) in loop.iter_steps(train_flow):
                # session.run returns one value per fetch: keep the loss
                # and discard the result of train_op.
                step_loss, _ = session.run(
                    [loss, train_op],
                    feed_dict={input_x: batch_x, input_y: batch_y}
                )
                loop.collect_metrics(loss=step_loss)

            # Time the validation pass and record the validation metric.
            with loop.timeit('valid_time'):
                valid_loss = session.run(
                    loss, feed_dict={input_x: test_x, input_y: test_y})
                loop.collect_metrics(valid_loss=valid_loss)
            loop.print_logs()
Attributes Summary
- best_valid_metric: Get the best valid metric.
- epoch: Get the epoch counter (starting from 1).
- max_epoch: Get or set the max value for epoch counter.
- max_step: Get or set the max value for global step counter.
- param_vars: Get the trainable parameter variables.
- step: Get the global step counter (starting from 1).
- summary_metric_prefix: Get the prefix for the metrics committed to summary_writer.
- summary_writer: Get the summary writer instance.
- use_early_stopping: Whether or not to adopt early-stopping?
- valid_metric_name: Get the name of the validation metric.
- valid_metric_smaller_is_better: Whether or not the smaller value is better for validation metric?
- var_groups: Get the variable groups.
Methods Summary
- add_summary(summary): Add a summary object, with self.step as global_step.
- collect_metrics([metrics]): Add metric values.
- get_progress(): Get the progress of training.
- iter_epochs(): Iterate through the epochs.
- iter_steps([data_generator]): Iterate through the steps.
- metric_collector(**kwds): Get a StatisticsCollector for metric.
- print_logs(): Print the training logs.
- print_training_summary(): Print the training summary.
- println(message[, with_tag]): Print message via print_function.
- timeit(**kwds): Open a context for timing.
Attributes Documentation
-
best_valid_metric
Get the best valid metric.
-
epoch
Get the epoch counter (starting from 1).
-
max_epoch
Get or set the max value for epoch counter.
-
max_step
Get or set the max value for global step counter.
-
param_vars
Get the trainable parameter variables.
-
step
Get the global step counter (starting from 1).
-
summary_metric_prefix
Get the prefix for the metrics committed to summary_writer.
-
summary_writer
Get the summary writer instance.
-
use_early_stopping
Whether or not to adopt early-stopping?
-
valid_metric_name
Get the name of the validation metric.
-
valid_metric_smaller_is_better
Whether or not the smaller value is better for validation metric?
-
var_groups
Get the variable groups.
Methods Documentation
-
add_summary(summary)
Add a summary object, with self.step as global_step.
Parameters: summary (tf.summary.Summary or bytes) – TensorFlow summary object, or serialized summary.
-
collect_metrics(metrics=None, **kwargs)
Add metric values.
This method must be called while at least an epoch loop is active. It adds the metrics to the epoch metrics collector and, if a step loop is also active, to the step metrics collector.
If summary_writer is configured, the metrics are also written to disk as summaries. Furthermore, if valid_metric_name is configured, this method also performs early-stopping.
Parameters:
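The early-stopping behaviour mentioned for collect_metrics can be pictured with a small, self-contained sketch. It is not tfsnippet code: `BestMetricTracker`, its `update` method, and the plain dict standing in for model parameters are all hypothetical names chosen for illustration. The sketch only assumes what the attributes on this page suggest, namely that a best validation metric is remembered together with a "smaller is better" flag.

```python
class BestMetricTracker:
    """Hypothetical helper: remembers the best validation metric seen so far."""

    def __init__(self, smaller_is_better=True, initial=None):
        self.smaller_is_better = smaller_is_better
        self.best = initial
        self.best_params = None

    def update(self, value, params):
        """Return True (and snapshot `params`) if `value` improves on the best."""
        improved = (
            self.best is None
            or (self.smaller_is_better and value < self.best)
            or (not self.smaller_is_better and value > self.best)
        )
        if improved:
            self.best = value
            # Copy, so that later training updates do not alter the snapshot.
            self.best_params = dict(params)
        return improved


tracker = BestMetricTracker(smaller_is_better=True)
params = {"w": 0.0}
for epoch, valid_loss in enumerate([0.9, 0.7, 0.8], start=1):
    params["w"] = float(epoch)  # stand-in for a training update
    tracker.update(valid_loss, params)

# Restore the snapshot taken at the best validation metric.
params = tracker.best_params
```

In this sketch the second epoch has the lowest validation loss, so the restored parameters are the ones snapshotted at that point rather than the final ones.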
-
get_progress()
Get the progress of training.
Returns: The progress in range [0, 1], or None if the progress cannot be estimated.
Return type: float or None
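One way such a progress value could be derived is sketched below. This is an assumption about the general idea, not the library's actual formula: `estimate_progress` and `estimate_eta` are hypothetical helpers, and the choice to prefer the step limit over the epoch limit is made only for the example.

```python
def estimate_progress(step, max_step=None, epoch=None, max_epoch=None):
    """Hypothetical helper: fraction of training done, or None if unknown."""
    if max_step:
        return min(step / max_step, 1.0)
    if max_epoch and epoch is not None:
        return min(epoch / max_epoch, 1.0)
    # Neither limit is configured, so the progress cannot be estimated.
    return None


def estimate_eta(elapsed_seconds, progress):
    """Hypothetical helper: remaining seconds, extrapolated linearly."""
    if not progress:
        return None
    return elapsed_seconds * (1.0 - progress) / progress


progress = estimate_progress(step=50, max_step=200)
eta = estimate_eta(30.0, progress)
```

A value like this is what an ETA display would need: with a quarter of the steps done in 30 seconds, linear extrapolation leaves 90 seconds.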
-
iter_epochs()
Iterate through the epochs.
This method can only be called when no other epoch loop is being iterated. After this loop exits, both the epoch metrics and the step metrics are cleared.
If max_epoch is configured, the iteration stops at that epoch.
Yields: int – The epoch counter (starting from 1).
-
iter_steps(data_generator=None)
Iterate through the steps.
This method can only be called when an epoch loop is active and no other step loop is being iterated.
Parameters: data_generator – Optional iterable data to be yielded at every step. This is required if max_step is not configured, so as to prevent an infinite step loop.
Yields: int or (int, any) – The global step counter (starting from 1), or the tuple of (step counter, batch data) if data_generator is specified.
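The counter rules described for iter_epochs and iter_steps can be sketched with a toy class. `MiniLoop` is a hypothetical stand-in that reproduces only the behaviour stated on this page (counters starting from 1, a global step counter, no nested loops of the same kind, and a data source or max_step required to end the step loop); it is not the library's implementation.

```python
class MiniLoop:
    """Hypothetical stand-in for the counter logic only; not tfsnippet code."""

    def __init__(self, max_epoch=None, max_step=None):
        self.max_epoch, self.max_step = max_epoch, max_step
        self.epoch = self.step = 0
        self._in_epoch = self._in_step = False

    def iter_epochs(self):
        if self._in_epoch:
            raise RuntimeError("another epoch loop is already active")
        self._in_epoch = True
        try:
            while ((self.max_epoch is None or self.epoch < self.max_epoch) and
                   (self.max_step is None or self.step < self.max_step)):
                self.epoch += 1
                yield self.epoch
        finally:
            self._in_epoch = False

    def iter_steps(self, data_generator=None):
        if not self._in_epoch:
            raise RuntimeError("iter_steps() needs an active epoch loop")
        if self._in_step:
            raise RuntimeError("another step loop is already active")
        if data_generator is None and self.max_step is None:
            raise ValueError("need data_generator or max_step to end the loop")
        self._in_step = True
        try:
            batches = iter(data_generator) if data_generator is not None else None
            while self.max_step is None or self.step < self.max_step:
                if batches is None:
                    self.step += 1
                    yield self.step            # counter only
                else:
                    try:
                        batch = next(batches)
                    except StopIteration:
                        break                  # data exhausted: epoch ends
                    self.step += 1
                    yield self.step, batch     # (step counter, batch data)
        finally:
            self._in_step = False


loop = MiniLoop(max_epoch=2)
seen = []
for epoch in loop.iter_epochs():
    for step, batch in loop.iter_steps(["a", "b", "c"]):
        seen.append((epoch, step, batch))
```

Note that the step counter in this sketch keeps counting across epochs, which is one reading of "global step counter": the second epoch starts at step 4, not at step 1.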
-
metric_collector(**kwds)
Get a StatisticsCollector for metric.
The mean value of the collected metrics will be added to summary after exiting the context. Other statistics will be discarded.
Parameters: metric_name (str) – The name of this metric.
Yields: StatisticsCollector – The collector for metric values.
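The "collect inside the context, keep only the mean on exit" pattern can be illustrated with a minimal sketch. Everything here is hypothetical: `MeanCollector` merely stands in for a statistics collector, its `collect(value, weight)` method is an assumed interface, and the `sink` dict stands in for wherever the loop stores its metrics.

```python
from contextlib import contextmanager


class MeanCollector:
    """Hypothetical collector: keeps a weighted running sum of metric values."""

    def __init__(self):
        self.total = 0.0
        self.weight = 0.0

    def collect(self, value, weight=1.0):
        self.total += value * weight
        self.weight += weight

    @property
    def mean(self):
        return self.total / self.weight if self.weight else None


@contextmanager
def metric_collector(metric_name, sink):
    """Yield a collector; on exit, store only its mean under `metric_name`."""
    collector = MeanCollector()
    yield collector
    if collector.mean is not None:
        sink[metric_name] = collector.mean


metrics = {}
with metric_collector("valid_loss", metrics) as mc:
    # Weight each batch loss by its batch size so that the mean is per sample.
    for batch_loss, batch_size in [(0.5, 32), (1.0, 32), (2.0, 64)]:
        mc.collect(batch_loss, weight=batch_size)
```

This pattern is useful when a validation set is too large for a single batch: per-batch values are accumulated and only one aggregate number is reported.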
-
print_logs()
Print the training logs.
This method will print the collected metrics. If there’s an active step loop, it will print metrics from the step metrics collector. Otherwise, if there’s only an epoch loop, it will print metrics from the epoch metrics collector.
Note that it must be called at the end of an epoch or a step, because the metrics of the corresponding loop context are cleared after the logs are printed. Moreover, the epoch or step timer is committed as a metric immediately when this method is called, before the logs are printed.
-
print_training_summary()
Print the training summary.
The training summary includes the following content:
- Execution environment.
- Parameters to be optimized during training.
-
println(message, with_tag=False)
Print message via print_function.
Parameters:
-