TrainLoop
class tfsnippet.TrainLoop(param_vars, var_groups=None, show_eta=True, print_func=<built-in function print>, max_epoch=None, max_step=None, metric_formatter=<tfsnippet.scaffold.logging_.DefaultMetricFormatter object>, checkpoint_dir=None, checkpoint_epoch_freq=None, checkpoint_max_to_keep=None, checkpoint_save_objects=None, restore_checkpoint=True, summary_dir=None, summary_writer=None, summary_graph=None, summary_metric_prefix='metrics/', summary_skip_pattern=<_sre.SRE_Pattern object>, summary_commit_freqs=None, valid_metric_name='valid_loss', valid_metric_smaller_is_better=None, early_stopping=False)
Bases: tfsnippet.utils.concepts.DisposableContext
Training loop object.
This class provides a set of convenient methods for writing training loops. It is useful for maintaining epoch and step counters, logging training metrics, memorizing the best parameters for early-stopping, etc. An example of using the TrainLoop:

```python
import tfsnippet as spt

# `param_vars`, `x`, `y`, `batch_size`, `loss`, `train_op`, `input_x`,
# `input_y`, `test_x`, `test_y` and `session` are assumed to be defined.
with spt.TrainLoop(param_vars, max_epoch=10, early_stopping=True) as loop:
    loop.print_training_summary()
    train_flow = spt.DataFlow.arrays([x, y], batch_size, shuffle=True)

    for epoch in loop.iter_epochs():
        for step, (batch_x, batch_y) in loop.iter_steps(train_flow):
            # run() returns [loss value, train_op result]; keep only the loss
            step_loss, _ = session.run(
                [loss, train_op],
                feed_dict={input_x: batch_x, input_y: batch_y}
            )
            loop.collect_metrics(loss=step_loss)

        # validate once per epoch, timing it as the `valid_time` metric
        with loop.timeit('valid_time'):
            valid_loss = session.run(
                loss, feed_dict={input_x: test_x, input_y: test_y})
            loop.collect_metrics(valid_loss=valid_loss)
        loop.print_logs()
```
The event schedule of a TrainLoop can be briefly described as:

```python
# the main training loop
events.fire(EventKeys.ENTER_LOOP, self)
for epoch in self.iter_epochs():
    events.fire(EventKeys.BEFORE_EPOCH, self)
    for step in self.iter_steps(...):
        events.fire(EventKeys.BEFORE_STEP, self)
        ...  # execute the step
        events.reverse_fire(EventKeys.AFTER_STEP, self)
    events.reverse_fire(EventKeys.AFTER_EPOCH, self)
events.fire(EventKeys.EXIT_LOOP, self)

# when metrics are fed into the loop by `collect_metrics()`
def collect_metrics(self, metrics_dict=None, **kwargs):
    metrics_dict = merge(metrics_dict, kwargs)
    events.fire(EventKeys.METRICS_COLLECTED, self, metrics_dict)

# when summaries are fed into the loop by `add_summary()`
def add_summary(self, summary):
    events.fire(EventKeys.SUMMARY_ADDED, self, summary)

# when metric statistics have been printed as log
def print_logs(self):
    ...
    events.fire(EventKeys.METRIC_STATS_PRINTED, self, metric_stats)
    events.fire(EventKeys.TIME_METRIC_STATS_PRINTED, self, time_metric_stats)
```
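To make the ordering in this schedule concrete, here is a minimal pure-Python sketch of an event source. `MiniEventSource`, its `on()` method, `run_schedule()` and the lowercase event names are hypothetical stand-ins, not tfsnippet's `EventSource` or `EventKeys`. The sketch assumes, as the name suggests, that `reverse_fire` calls handlers in the reverse order of registration, so "after" handlers unwind like nested context managers.

```python
from collections import defaultdict


class MiniEventSource:
    """Hypothetical, minimal event source (not tfsnippet's EventSource)."""

    def __init__(self):
        self._handlers = defaultdict(list)

    def on(self, key, handler):
        # handlers are stored in registration order
        self._handlers[key].append(handler)

    def fire(self, key, *args):
        # call handlers in registration order
        for handler in list(self._handlers[key]):
            handler(*args)

    def reverse_fire(self, key, *args):
        # call handlers in reverse registration order
        for handler in reversed(list(self._handlers[key])):
            handler(*args)


def run_schedule(events, max_epoch, steps_per_epoch):
    """Fire events following the schedule described above."""
    events.fire('enter_loop')
    for epoch in range(1, max_epoch + 1):
        events.fire('before_epoch', epoch)
        for step in range(1, steps_per_epoch + 1):
            events.fire('before_step', epoch, step)
            # ... execute the step here ...
            events.reverse_fire('after_step', epoch, step)
        events.reverse_fire('after_epoch', epoch)
    events.fire('exit_loop')


log = []
events = MiniEventSource()
events.on('after_epoch', lambda epoch: log.append(('A', epoch)))
events.on('after_epoch', lambda epoch: log.append(('B', epoch)))
events.on('before_epoch', lambda epoch: log.append(('before', epoch)))
run_schedule(events, max_epoch=2, steps_per_epoch=3)
```

Because `A` is registered before `B`, the reversed "after epoch" calls run `B` first, which lets a handler registered later clean up before the one it was layered on.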
Warning
If you use early-stopping together with checkpoints, one case is very dangerous: you have already finished a training loop successfully, and the early-stopping variables have been restored. But you then recover from the latest checkpoint and continue training. In this case, the param_vars (which are covered by early-stopping) are restored to the best validation step, while the other variables and the internal states of TrainLoop are recovered to the last step. The resulting state mismatch makes the behaviour after this recovery unpredictable.
Attributes Summary
- best_valid_metric: Get the best valid metric.
- epoch: Get the epoch counter (starting from 1).
- events: Get the event source.
- max_epoch: Get or set the maximum value of the epoch counter.
- max_step: Get or set the maximum value of the global step counter.
- param_vars: Get the trainable parameter variables.
- step: Get the global step counter (starting from 1).
- step_data: Get the data of the current step.
- summary_writer: Get the summary writer instance.
- use_early_stopping: Whether or not early-stopping is adopted.
- valid_metric_name: Get the name of the validation metric.
- valid_metric_smaller_is_better: Whether or not a smaller value of the validation metric is better.
- var_groups: Get the variable groups.
- within_epoch: Whether or not an epoch is open.
- within_step: Whether or not a step is open.
Methods Summary
- add_summary(summary): Add a summary object, with self.step as global_step.
- collect_metrics([metrics]): Add metric values.
- get_eta(): Get the estimated time ahead (ETA).
- get_progress(): Get the progress of training.
- iter_epochs(): Iterate through the epochs.
- iter_steps([data_generator]): Iterate through the steps.
- make_checkpoint(): Make a checkpoint.
- metric_collector(**kwds): Get a StatisticsCollector for a metric.
- print_logs(): Print the training logs.
- print_training_summary(): Print the training summary.
- println(message[, with_tag]): Print message via print_func.
- timeit(**kwds): Open a context for timing.
Attributes Documentation
best_valid_metric: Get the best valid metric.
epoch: Get the epoch counter (starting from 1).
events: Get the event source.
Returns: The event source.
Return type: EventSource
max_epoch: Get or set the maximum value of the epoch counter.
max_step: Get or set the maximum value of the global step counter.
param_vars: Get the trainable parameter variables.
step: Get the global step counter (starting from 1).
step_data: Get the data of the current step.
summary_writer: Get the summary writer instance.
use_early_stopping: Whether or not early-stopping is adopted.
valid_metric_name: Get the name of the validation metric.
valid_metric_smaller_is_better: Whether or not a smaller value of the validation metric is better.
var_groups: Get the variable groups.
within_epoch: Whether or not an epoch is open.
within_step: Whether or not a step is open.
Methods Documentation
add_summary(summary)
Add a summary object, with self.step as global_step.
Parameters: summary (tf.summary.Summary or bytes) – TensorFlow summary object, or serialized summary.
collect_metrics(metrics=None, **kwargs)
Add metric values.
This method must be called while at least one epoch loop is active. It adds the metrics to the epoch metrics collector and, if a step loop is active, also to the step metrics collector.
If summary_writer is configured, it will also write the metrics as summaries onto disk. Furthermore, if valid_metric_name is configured, it will also perform early-stopping.
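To make the early-stopping bookkeeping concrete, here is a minimal pure-Python sketch. `BestMetricTracker` is hypothetical and is not tfsnippet's implementation: a plain dict stands in for the values of param_vars, and the tracker only remembers the best value of the validation metric together with a snapshot of the parameters at that point, which is what early-stopping would later restore. It treats smaller as better unless told otherwise, whereas TrainLoop may decide the direction differently when valid_metric_smaller_is_better is None.

```python
class BestMetricTracker:
    """Hypothetical sketch of early-stopping bookkeeping (not tfsnippet's code)."""

    def __init__(self, valid_metric_name='valid_loss', smaller_is_better=True):
        self.valid_metric_name = valid_metric_name
        self.smaller_is_better = smaller_is_better
        self.best_valid_metric = None
        self.best_params = None  # snapshot of the parameters at the best metric

    def collect_metrics(self, params, metrics=None, **kwargs):
        """Merge the metrics; return True if the validation metric improved."""
        merged = dict(metrics or {})
        merged.update(kwargs)
        if self.valid_metric_name not in merged:
            return False  # no validation metric in this call
        value = merged[self.valid_metric_name]
        if self.best_valid_metric is None:
            improved = True
        elif self.smaller_is_better:
            improved = value < self.best_valid_metric
        else:
            improved = value > self.best_valid_metric
        if improved:
            self.best_valid_metric = value
            self.best_params = dict(params)  # copy, so later updates don't leak in
        return improved


# usage: a dict stands in for the parameter values being trained
params = {'w': 1.0}
tracker = BestMetricTracker()
results = [tracker.collect_metrics(params, valid_loss=0.9)]
params['w'] = 2.0
results.append(tracker.collect_metrics(params, {'valid_loss': 0.5}))
params['w'] = 3.0
results.append(tracker.collect_metrics(params, valid_loss=0.7, loss=0.1))
params.update(tracker.best_params)  # "restore" the best parameters
```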
Parameters:
get_eta()
Get the estimated time ahead (ETA).
Returns: The estimated time ahead in seconds, or None if not available.
Return type: float or None
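One common way to obtain such an estimate is linear extrapolation from the elapsed time and the training progress (as reported by get_progress()). The function below is a hypothetical sketch of that idea, assuming a constant training pace; it is not necessarily how TrainLoop computes its ETA.

```python
def estimate_eta(elapsed_seconds, progress):
    """Hypothetical ETA in seconds by linear extrapolation, or None."""
    if progress is None or progress <= 0:
        return None  # nothing to extrapolate from yet
    total_seconds = elapsed_seconds / progress  # assumes a constant pace
    return max(total_seconds - elapsed_seconds, 0.0)
```

For instance, 30 seconds spent on a quarter of the training extrapolates to 120 seconds in total, leaving 90 seconds ahead.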
get_progress()
Get the progress of training.
Returns: The progress in range [0, 1], or None if the progress cannot be estimated.
Return type: float or None
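As a rough illustration, progress can be derived from whichever limit is configured. The helper below is hypothetical: the precedence of max_step over max_epoch and the use of completed (rather than current) counters are assumptions of this sketch, not documented TrainLoop behaviour. When neither limit is set, it returns None, matching the "cannot be estimated" case above.

```python
def estimate_progress(completed_epochs=0, max_epoch=None,
                      completed_steps=0, max_step=None):
    """Hypothetical progress estimate in [0, 1], or None if unknown."""
    if max_step is not None:
        # assumption: a step limit, when configured, takes precedence
        return min(float(completed_steps) / max_step, 1.0)
    if max_epoch is not None:
        return min(float(completed_epochs) / max_epoch, 1.0)
    return None  # neither limit is configured, so progress is unknown
```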
iter_epochs()
Iterate through the epochs.
This method can only be called when no other epoch loop is being iterated. Furthermore, after exiting this loop, both the epoch metrics and the step metrics will be cleared.
If max_epoch is configured, the iteration stops at that epoch.
Yields: int – The epoch counter (starting from 1).
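The contract above (a counter starting from 1, an optional max_epoch bound, no nested epoch loops, metrics cleared on exit) can be sketched as a plain generator. `EpochLoopSketch` and its `epoch_metrics` list are hypothetical stand-ins for the real loop and its metrics collector.

```python
class EpochLoopSketch:
    """Hypothetical sketch of iter_epochs() semantics (not tfsnippet's code)."""

    def __init__(self, max_epoch=None):
        self.max_epoch = max_epoch
        self.epoch = 0
        self.epoch_metrics = []  # stands in for the epoch metrics collector
        self._iterating = False

    def iter_epochs(self):
        if self._iterating:
            raise RuntimeError('Another epoch loop is being iterated.')
        self._iterating = True
        try:
            # with max_epoch=None this loops until the caller breaks out
            while self.max_epoch is None or self.epoch < self.max_epoch:
                self.epoch += 1  # the counter starts from 1
                yield self.epoch
        finally:
            self._iterating = False
            self.epoch_metrics.clear()  # metrics are cleared on exit


sketch = EpochLoopSketch(max_epoch=3)
seen = []
for epoch in sketch.iter_epochs():
    sketch.epoch_metrics.append(epoch * 0.1)
    seen.append(epoch)
```

Using `try`/`finally` means the cleanup also runs when the caller breaks out early and the generator is closed.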
iter_steps(data_generator=None)
Iterate through the steps.
This method can only be called when an epoch loop is active and no other step loop is being iterated.
Parameters: data_generator – Optional iterable data to be yielded at every step. This is required if max_step is not configured, so as to prevent an infinite step loop.
Yields: int or (int, any) – The global step counter (starting from 1), or the tuple of (step counter, batch data) if data_generator is specified.
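A minimal sketch of these step semantics follows. `StepLoopSketch` is hypothetical: it keeps one global counter across calls, yields bare counters without data and `(step, batch)` tuples with data, and refuses to start an unbounded loop. Stopping at max_step even when data is supplied is an assumption of the sketch, and the "epoch loop must be active" check is omitted for brevity.

```python
class StepLoopSketch:
    """Hypothetical sketch of iter_steps() semantics (not tfsnippet's code)."""

    def __init__(self, max_step=None):
        self.max_step = max_step
        self.step = 0  # global step counter, shared across all epochs

    def _reached_max(self):
        return self.max_step is not None and self.step >= self.max_step

    def iter_steps(self, data_generator=None):
        if data_generator is None:
            if self.max_step is None:
                # without data or max_step the loop would never end
                raise ValueError(
                    '`data_generator` is required when `max_step` is not set.')
            while not self._reached_max():
                self.step += 1
                yield self.step
        else:
            for batch in data_generator:
                if self._reached_max():
                    break
                self.step += 1
                yield self.step, batch


steps = StepLoopSketch()
epoch_1 = list(steps.iter_steps(['a', 'b']))
epoch_2 = list(steps.iter_steps(['c']))  # the global counter keeps going
```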
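make_checkpoint()
Make a checkpoint.
This method must be called within an epoch or a step context. For example:

```python
for epoch in loop.iter_epochs():
    # iter_steps() yields (step, batch) when given data
    for step, [x] in loop.iter_steps(train_data):
        ...
    # checkpoint every 100 epochs, still inside the epoch context
    if epoch % 100 == 0:
        loop.make_checkpoint()
```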
metric_collector(**kwds)
Get a StatisticsCollector for a metric.
The mean value of the collected metrics will be added to the summary after exiting the context. Other statistics will be discarded.
Parameters: metric_name (str) – The name of this metric.
Yields: StatisticsCollector – The collector for metric values.
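The collect-then-summarize pattern can be sketched with `contextlib.contextmanager`. `MeanCollector`, its `collect()` method and the module-level `summaries` dict are hypothetical stand-ins for StatisticsCollector and the summary writer; as described above, only the mean is recorded, and only when the with-block exits normally.

```python
from contextlib import contextmanager


class MeanCollector:
    """Hypothetical stand-in for a statistics collector: keeps sum and count."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def collect(self, value):
        self.total += value
        self.count += 1

    @property
    def mean(self):
        return self.total / self.count if self.count else None


summaries = {}  # hypothetical sink standing in for the summary writer


@contextmanager
def metric_collector(metric_name):
    collector = MeanCollector()
    yield collector
    # on exit, keep only the mean; other statistics are discarded
    if collector.count:
        summaries[metric_name] = collector.mean


with metric_collector('valid_loss') as collector:
    for batch_loss in [0.5, 0.25, 0.75]:
        collector.collect(batch_loss)
```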
print_logs()
Print the training logs.
This method prints the collected metrics. If there is an active step loop, it prints metrics from the step metrics collector; otherwise, if there is only an epoch loop, it prints metrics from the epoch metrics accumulator.
Note that it must be called at the end of an epoch or a step, because the metrics of the corresponding loop context are cleared after the logs are printed. Moreover, the epoch or step timer is committed as a metric as soon as this method is called, before the logs are printed.
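The ordering described here (commit the timer, print, then clear) can be sketched in a few lines. `MiniLogger`, its `time` metric name and the `'; '`-joined output format are hypothetical; the real method formats metrics through the configured metric_formatter.

```python
import time


class MiniLogger:
    """Hypothetical sketch of the print-then-clear behaviour of print_logs()."""

    def __init__(self, print_func=print):
        self.print_func = print_func
        self.metrics = {}
        self.start_time = time.perf_counter()

    def collect_metrics(self, **kwargs):
        for name, value in kwargs.items():
            self.metrics.setdefault(name, []).append(value)

    def print_logs(self):
        # commit the elapsed time as a metric *before* printing
        self.collect_metrics(time=time.perf_counter() - self.start_time)
        parts = ['{}: {:.4g}'.format(name, sum(values) / len(values))
                 for name, values in sorted(self.metrics.items())]
        self.print_func('; '.join(parts))
        # the metrics of this loop context are cleared after printing
        self.metrics.clear()


lines = []
logger = MiniLogger(print_func=lines.append)
logger.collect_metrics(loss=1.0)
logger.collect_metrics(loss=2.0)
logger.print_logs()
```

Calling print_logs() in the middle of an epoch would therefore drop the metrics gathered so far, which is why the method belongs at the end of an epoch or step.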
print_training_summary()
Print the training summary.
The training summary includes the following content:
- Execution environment.
- Parameters to be optimized during training.
println(message, with_tag=False)
Print message via print_func.
Parameters: