TrainLoop

class tfsnippet.TrainLoop(param_vars, var_groups=None, show_eta=True, print_func=<built-in function print>, max_epoch=None, max_step=None, metric_formatter=<tfsnippet.scaffold.logging_.DefaultMetricFormatter object>, checkpoint_dir=None, checkpoint_epoch_freq=None, checkpoint_max_to_keep=None, checkpoint_save_objects=None, restore_checkpoint=True, summary_dir=None, summary_writer=None, summary_graph=None, summary_metric_prefix='metrics/', summary_skip_pattern=<_sre.SRE_Pattern object>, summary_commit_freqs=None, valid_metric_name='valid_loss', valid_metric_smaller_is_better=None, early_stopping=False)

Bases: tfsnippet.utils.concepts.DisposableContext
Training loop object.

This class provides a set of convenient methods for writing training loops. It is useful for maintaining epoch and step counters, logging training metrics, memorizing the best parameters for early-stopping, etc. An example of using the TrainLoop follows; session, loss, train_op, input_x, input_y, x, y, test_x, test_y and batch_size are assumed to be defined elsewhere:

```python
import tfsnippet as spt

with spt.TrainLoop(param_vars, max_epoch=10, early_stopping=True) as loop:
    loop.print_training_summary()
    train_flow = spt.DataFlow.arrays([x, y], batch_size, shuffle=True)

    for epoch in loop.iter_epochs():
        for step, (batch_x, batch_y) in loop.iter_steps(train_flow):
            # run one optimization step; the train_op result is discarded
            step_loss, _ = session.run(
                [loss, train_op],
                feed_dict={input_x: batch_x, input_y: batch_y}
            )
            loop.collect_metrics(loss=step_loss)

        # time the validation pass and record its loss
        with loop.timeit('valid_time'):
            valid_loss = session.run(
                loss, feed_dict={input_x: test_x, input_y: test_y})
            loop.collect_metrics(valid_loss=valid_loss)
        loop.print_logs()
```
The event schedule of a TrainLoop can be briefly described as:

```python
# the main training loop
events.fire(EventKeys.ENTER_LOOP, self)
for epoch in self.iter_epochs():
    events.fire(EventKeys.BEFORE_EPOCH, self)
    for step in self.iter_steps(...):
        events.fire(EventKeys.BEFORE_STEP, self)
        ...  # execute the step
        events.reverse_fire(EventKeys.AFTER_STEP, self)
    events.reverse_fire(EventKeys.AFTER_EPOCH, self)
events.fire(EventKeys.EXIT_LOOP, self)

# when metrics are fed into the loop by collect_metrics()
def collect_metrics(self, metrics_dict=None, **kwargs):
    metrics_dict = merge(metrics_dict, kwargs)
    events.fire(EventKeys.METRICS_COLLECTED, self, metrics_dict)

# when summaries are fed into the loop by add_summary()
def add_summary(self, summary):
    events.fire(EventKeys.SUMMARY_ADDED, self, summary)

# when metric statistics have been printed as log
def print_logs(self):
    ...
    events.fire(EventKeys.METRIC_STATS_PRINTED, self, metric_stats)
    events.fire(EventKeys.TIME_METRIC_STATS_PRINTED, self, time_metric_stats)
```
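In the schedule above, BEFORE_* events are dispatched with fire while AFTER_* events use reverse_fire. A plausible reading is that after-handlers run in reverse registration order, so that handlers nest like context managers. The pure-Python sketch below illustrates only that ordering; MiniEventSource, its on method and the string event keys are hypothetical stand-ins, not tfsnippet's EventSource API.

```python
from collections import defaultdict


class MiniEventSource:
    """Hypothetical, minimal event source (not tfsnippet's EventSource)."""

    def __init__(self):
        self._handlers = defaultdict(list)

    def on(self, key, handler):
        # register ``handler`` to be called whenever ``key`` is fired
        self._handlers[key].append(handler)

    def fire(self, key, *args):
        # call handlers in registration order
        for handler in self._handlers[key]:
            handler(*args)

    def reverse_fire(self, key, *args):
        # call handlers in reverse registration order
        for handler in reversed(self._handlers[key]):
            handler(*args)


calls = []
events = MiniEventSource()
for name in ('logger', 'checkpointer'):
    # default argument binds the current ``name`` into each lambda
    events.on('before_step', lambda n=name: calls.append('enter ' + n))
    events.on('after_step', lambda n=name: calls.append('exit ' + n))

events.fire('before_step')
events.reverse_fire('after_step')
print(calls)
# ['enter logger', 'enter checkpointer', 'exit checkpointer', 'exit logger']
```

With this ordering, the handler registered first both opens first and closes last, mirroring nested `with` blocks.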
Warning
If you use early-stopping along with checkpointing, one case is very dangerous: you have already completed a training loop successfully, and the early-stopping variables have been restored, but you then recover from the latest checkpoint and continue training. In this case, the param_vars (which are covered by early-stopping) are restored to the best validation step, while the other variables and the internal states of TrainLoop are recovered to the last step. The result is a state mismatch, and the behaviour after this recovery is unpredictable.

Attributes Summary
best_valid_metric: Get the best validation metric value.
epoch: Get the epoch counter (starting from 1).
events: Get the event source.
max_epoch: Get or set the maximum value of the epoch counter.
max_step: Get or set the maximum value of the global step counter.
param_vars: Get the trainable parameter variables.
step: Get the global step counter (starting from 1).
step_data: Get the data of the current step.
summary_writer: Get the summary writer instance.
use_early_stopping: Whether or not early-stopping is adopted.
valid_metric_name: Get the name of the validation metric.
valid_metric_smaller_is_better: Whether or not a smaller value of the validation metric is better.
var_groups: Get the variable groups.
within_epoch: Whether or not an epoch is currently open.
within_step: Whether or not a step is currently open.

Methods Summary
add_summary(summary): Add a summary object, with self.step as global_step.
collect_metrics([metrics]): Add metric values.
get_eta(): Get the estimated time ahead (ETA).
get_progress(): Get the progress of training.
iter_epochs(): Iterate through the epochs.
iter_steps([data_generator]): Iterate through the steps.
make_checkpoint(): Make a checkpoint.
metric_collector(**kwds): Get a StatisticsCollector for a metric.
print_logs(): Print the training logs.
print_training_summary(): Print the training summary.
println(message[, with_tag]): Print a message via print_func.
timeit(**kwds): Open a context for timing.

Attributes Documentation
best_valid_metric
Get the best validation metric value.

epoch
Get the epoch counter (starting from 1).

events
Get the event source.
Returns: The event source.
Return type: EventSource

max_epoch
Get or set the maximum value of the epoch counter.

max_step
Get or set the maximum value of the global step counter.

param_vars
Get the trainable parameter variables.

step
Get the global step counter (starting from 1).

step_data
Get the data of the current step.

summary_writer
Get the summary writer instance.

use_early_stopping
Whether or not early-stopping is adopted.

valid_metric_name
Get the name of the validation metric.

valid_metric_smaller_is_better
Whether or not a smaller value of the validation metric is better.

var_groups
Get the variable groups.

within_epoch
Whether or not an epoch is currently open.

within_step
Whether or not a step is currently open.

Methods Documentation
add_summary(summary)
Add a summary object, with self.step as global_step.
Parameters: summary (tf.summary.Summary or bytes) – TensorFlow summary object, or serialized summary.

collect_metrics(metrics=None, **kwargs)
Add metric values.
This method must be called while at least an epoch loop is active. It adds the metrics to the epoch metrics collector and, if a step loop is also active, to the step metrics collector as well.
If summary_writer is configured, it also writes the metrics to disk as summaries. Furthermore, if valid_metric_name is configured, it also performs early-stopping.
Parameters:
metrics (dict) – Metric values, as a dict mapping metric names to values.
**kwargs – Further metric values given as keyword arguments; these are merged with metrics.
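The early-stopping bookkeeping mentioned above can be pictured as: whenever the validation metric improves (smaller or larger, per valid_metric_smaller_is_better), snapshot the parameters; when the loop exits, restore the best snapshot. The sketch below is a hypothetical illustration of that idea only; EarlyStoppingTracker and the plain-dict "parameters" are assumptions, whereas TrainLoop itself works on TensorFlow variables.

```python
import copy


class EarlyStoppingTracker:
    """Hypothetical sketch of best-parameter memorization for early-stopping."""

    def __init__(self, smaller_is_better=True):
        self.smaller_is_better = smaller_is_better
        self.best_metric = None
        self.best_params = None

    def update(self, metric, params):
        # snapshot ``params`` whenever ``metric`` improves on the best so far
        improved = (
            self.best_metric is None
            or (metric < self.best_metric if self.smaller_is_better
                else metric > self.best_metric)
        )
        if improved:
            self.best_metric = metric
            self.best_params = copy.deepcopy(params)
        return improved


tracker = EarlyStoppingTracker(smaller_is_better=True)
params = {'w': 0.0}
for epoch, valid_loss in enumerate([0.9, 0.5, 0.7, 0.6], start=1):
    params['w'] = float(epoch)  # pretend training changed the weights
    tracker.update(valid_loss, params)

# on exit, restore the best snapshot (epoch 2 here)
params = tracker.best_params
print(tracker.best_metric, params)  # 0.5 {'w': 2.0}
```

This also shows the mismatch described in the Warning above: after restoration the parameters reflect epoch 2, while any counters kept elsewhere would still reflect epoch 4.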

get_eta()
Get the estimated time ahead (ETA).
Returns: The estimated time ahead in seconds, or None if not available.
Return type: float or None
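This documentation does not state how the ETA is computed. One common approach, assumed here, is linear extrapolation: if a fraction `progress` of training took `elapsed` seconds, the remainder is expected to take `elapsed * (1 - progress) / progress` seconds. The helper estimate_eta below is hypothetical.

```python
def estimate_eta(elapsed, progress):
    """Hypothetical ETA estimate: assume the remaining work proceeds at the
    same average speed as the work done so far."""
    if progress is None or progress <= 0:
        return None  # nothing to extrapolate from
    return elapsed * (1.0 - progress) / progress


print(estimate_eta(elapsed=120.0, progress=0.25))  # 360.0
print(estimate_eta(elapsed=5.0, progress=None))    # None
```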

get_progress()
Get the progress of training.
Returns: The progress in range [0, 1], or None if the progress cannot be estimated.
Return type: float or None
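How TrainLoop estimates its progress is not specified here. A reasonable assumption is that it compares the counters against max_epoch and max_step, and returns None when neither limit is configured. Since training stops as soon as either limit is reached, the sketch below takes the larger completed fraction. The helper estimate_progress and its arguments are hypothetical.

```python
def estimate_progress(epochs_done, max_epoch=None, steps_done=0, max_step=None):
    """Hypothetical progress estimate in [0, 1], or None if unbounded.

    Training ends as soon as either limit is reached, so the larger of
    the two completed fractions is used.
    """
    fractions = []
    if max_epoch is not None:
        fractions.append(epochs_done / float(max_epoch))
    if max_step is not None:
        fractions.append(steps_done / float(max_step))
    if not fractions:
        return None  # neither limit is configured: cannot estimate
    return min(1.0, max(fractions))


print(estimate_progress(3, max_epoch=10))                                 # 0.3
print(estimate_progress(3, max_epoch=10, steps_done=500, max_step=1000))  # 0.5
print(estimate_progress(3))                                               # None
```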

iter_epochs()
Iterate through the epochs.
This method can only be called when no other epoch loop is being iterated. Furthermore, after exiting this loop, both the epoch metrics and the step metrics will be cleared.
If max_epoch is configured, the iteration stops once the epoch counter reaches it.
Yields: int – The epoch counter (starting from 1).
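The epoch counting and clean-up behaviour described above can be sketched with a plain generator: count from 1, stop at max_epoch if one is given, and clear the collected metrics in a finally clause once iteration ends. This is a hypothetical illustration; the epoch_metrics dict stands in for the loop's internal metric collectors.

```python
import itertools


def iter_epochs(epoch_metrics, max_epoch=None):
    """Hypothetical sketch: yield epoch counters starting from 1.

    ``epoch_metrics`` is a plain dict standing in for the loop's metric
    collectors; it is cleared once the iteration finishes.
    """
    counter = itertools.count(1)
    if max_epoch is not None:
        counter = itertools.islice(counter, max_epoch)  # stop at max_epoch
    try:
        for epoch in counter:
            yield epoch
    finally:
        epoch_metrics.clear()  # metrics do not outlive the loop


metrics, seen = {}, []
for epoch in iter_epochs(metrics, max_epoch=3):
    seen.append(epoch)
    metrics['loss'] = 1.0 / epoch  # pretend to record a metric
print(seen, metrics)  # [1, 2, 3] {}
```

Without max_epoch the generator is unbounded, so the caller has to break out of the loop itself.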

iter_steps(data_generator=None)
Iterate through the steps.
This method can only be called when no other step loop is being iterated and an epoch loop is active.
Parameters: data_generator – Optional iterable data to be yielded at every step. This is required if max_step is not configured, so as to prevent an infinite step loop.
Yields: int or (int, any) – The global step counter (starting from 1), or the tuple (step counter, batch data) if data_generator is specified.
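The rules above (a global step counter that keeps counting across epochs, (step, batch) tuples when data is supplied, and the requirement for data_generator when max_step is absent) can be sketched as a generator. This is a hypothetical illustration: the real TrainLoop tracks the counter internally, while here the start_step argument stands in for the last global step reached before the current epoch.

```python
def iter_steps(start_step, data_generator=None, max_step=None):
    """Hypothetical sketch of one epoch's step loop.

    Yields the global step counter, or ``(step, batch)`` when
    ``data_generator`` is given.  ``start_step`` is the last global step
    reached before this epoch (0 for the first epoch).
    """
    if data_generator is None and max_step is None:
        # neither data nor a step limit: the loop would never end
        raise ValueError('data_generator is required when max_step is not set')
    step = start_step
    batches = iter(data_generator) if data_generator is not None else None
    while max_step is None or step < max_step:
        if batches is None:
            step += 1
            yield step
        else:
            try:
                batch = next(batches)
            except StopIteration:
                return  # the epoch ends when the data runs out
            step += 1
            yield step, batch


print(list(iter_steps(0, data_generator=['a', 'b'])))  # [(1, 'a'), (2, 'b')]
print(list(iter_steps(2, data_generator=['c', 'd'])))  # [(3, 'c'), (4, 'd')]
print(list(iter_steps(4, max_step=6)))                 # [5, 6]
```

Because this is a generator, the ValueError surfaces when iteration starts rather than at call time.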

make_checkpoint()
Make a checkpoint.
This method must be called within an epoch or a step context. For example:

```python
for epoch in loop.iter_epochs():
    for step, [x] in loop.iter_steps(train_data):
        ...
    # save a checkpoint every 100 epochs
    if epoch % 100 == 0:
        loop.make_checkpoint()
```

metric_collector(**kwds)
Get a StatisticsCollector for a metric.
The mean value of the collected metrics will be added to the summary after exiting the context. Other statistics will be discarded.
Parameters: metric_name (str) – The name of this metric.
Yields: StatisticsCollector – The collector for metric values.
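The "keep only the mean on exit" behaviour can be sketched with contextlib.contextmanager. MeanCollector is a hypothetical stand-in for StatisticsCollector (whose real method names are not shown here), and the summaries dict stands in for the summary writer.

```python
from contextlib import contextmanager


class MeanCollector:
    """Hypothetical stand-in for StatisticsCollector: tracks the mean only."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def collect(self, value):
        self.total += value
        self.count += 1

    @property
    def mean(self):
        return self.total / self.count if self.count else None


@contextmanager
def metric_collector(metric_name, summaries):
    # hand out a fresh collector; on normal exit, record only its mean
    collector = MeanCollector()
    yield collector
    if collector.count:
        summaries[metric_name] = collector.mean


summaries = {}
with metric_collector('valid_loss', summaries) as c:
    for batch_loss in [1.0, 2.0, 4.5]:
        c.collect(batch_loss)
print(summaries)  # {'valid_loss': 2.5}
```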

print_logs()
Print the training logs.
This method prints the collected metrics. If there is an active step loop, it prints the metrics from the step metrics collector; otherwise, if only an epoch loop is active, it prints the metrics from the epoch metrics collector.
Note that it must be called at the end of an epoch or a step, because the metrics of the corresponding loop context are cleared after the logs are printed. Moreover, the epoch or step timer is committed as a metric as soon as this method is called, before the logs are printed.
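The ordering described above (commit the timer as a metric, print, then clear) can be sketched with a hypothetical MiniLogger class. The log-line format and the epoch_time metric name are assumptions for illustration, not TrainLoop's actual output.

```python
import time


class MiniLogger:
    """Hypothetical sketch of commit-timer, print, then clear semantics."""

    def __init__(self):
        self.metrics = {}          # metric name -> list of values
        self.start = time.time()   # when the current epoch began

    def collect(self, **kwargs):
        for name, value in kwargs.items():
            self.metrics.setdefault(name, []).append(value)

    def print_logs(self):
        # 1. commit the timer as a metric *before* printing
        self.collect(epoch_time=time.time() - self.start)
        # 2. print the mean of every collected metric
        parts = ['{}: {:.3g}'.format(name, sum(values) / len(values))
                 for name, values in sorted(self.metrics.items())]
        line = '; '.join(parts)
        print(line)
        # 3. clear the metrics and restart the timer for the next epoch
        self.metrics.clear()
        self.start = time.time()
        return line


logger = MiniLogger()
logger.collect(loss=1.0)
logger.collect(loss=2.0)
line = logger.print_logs()  # prints the epoch time followed by "loss: 1.5"
```

Calling print_logs in the middle of an epoch would therefore discard the metrics gathered so far, which is why it belongs at the end.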

print_training_summary()
Print the training summary.
The training summary includes the following content:
- Execution environment.
- Parameters to be optimized during training.
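
println(message, with_tag=False)
Print a message via print_func.
Parameters:
message (str) – The message to print.
with_tag (bool) – Whether or not to print the message with a tag (default False).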
-