TrainLoop

class tfsnippet.TrainLoop(param_vars, var_groups=None, show_eta=True, print_func=print, max_epoch=None, max_step=None, metric_formatter=DefaultMetricFormatter(), checkpoint_dir=None, checkpoint_epoch_freq=None, checkpoint_max_to_keep=None, checkpoint_save_objects=None, restore_checkpoint=True, summary_dir=None, summary_writer=None, summary_graph=None, summary_metric_prefix='metrics/', summary_skip_pattern=&lt;compiled pattern&gt;, summary_commit_freqs=None, valid_metric_name='valid_loss', valid_metric_smaller_is_better=None, early_stopping=False)

Bases: tfsnippet.utils.concepts.DisposableContext

Training loop object.

This class provides a set of convenient methods for writing a training loop. It is useful for maintaining epoch and step counters, logging training metrics, memorizing the best parameters for early-stopping, etc. An example of using the TrainLoop:

import tfsnippet as spt

with spt.TrainLoop(param_vars,
                   max_epoch=10,
                   early_stopping=True) as loop:
    loop.print_training_summary()
    train_flow = spt.DataFlow.arrays([x, y], batch_size, shuffle=True)

    for epoch in loop.iter_epochs():
        for step, (x, y) in loop.iter_steps(train_flow):
            step_loss, _ = session.run(
                [loss, train_op],
                feed_dict={input_x: x, input_y: y}
            )
            loop.collect_metrics(loss=step_loss)
        with loop.timeit('valid_time'):
            valid_loss = session.run(
                loss, feed_dict={input_x: test_x, input_y: test_y})
            loop.collect_metrics(valid_loss=valid_loss)
        loop.print_logs()

The event schedule of a TrainLoop can be briefly described as:

# the main training loop
events.fire(EventKeys.ENTER_LOOP, self)

for epoch in self.iter_epochs():
    events.fire(EventKeys.BEFORE_EPOCH, self)

    for step in self.iter_steps(...):
        events.fire(EventKeys.BEFORE_STEP, self)

        ...  # execute the step

        events.reverse_fire(EventKeys.AFTER_STEP, self)

    events.reverse_fire(EventKeys.AFTER_EPOCH, self)

events.fire(EventKeys.EXIT_LOOP, self)

# when metrics are fed into the loop by :meth:`collect_metrics`
def collect_metrics(self, metrics_dict=None, **kwargs):
    metrics_dict = merge(metrics_dict, kwargs)
    events.fire(EventKeys.METRICS_COLLECTED, self, metrics_dict)

# when summaries are fed into the loop by :meth:`add_summary`
def add_summary(self, summary):
    events.fire(EventKeys.SUMMARY_ADDED, self, summary)

# when metric statistics have been printed as log
def print_logs(self):
    ...
    events.fire(EventKeys.METRIC_STATS_PRINTED, self, metric_stats)
    events.fire(EventKeys.TIME_METRIC_STATS_PRINTED, self,
                time_metric_stats)
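
The ordering contract of fire versus reverse_fire can be illustrated with a minimal standalone sketch (the class and method names below are illustrative simplifications, not the tfsnippet internals): BEFORE_* handlers run in registration order, while AFTER_* handlers run in reverse registration order, so setup and teardown nest like context managers.

```python
# Minimal sketch of fire() vs reverse_fire() handler ordering.
# Not the tfsnippet EventSource implementation; illustrative only.

class EventSource:
    def __init__(self):
        self._handlers = []

    def on(self, handler):
        self._handlers.append(handler)

    def fire(self, *args):
        for h in self._handlers:            # registration order
            h(*args)

    def reverse_fire(self, *args):
        for h in reversed(self._handlers):  # reverse registration order
            h(*args)

calls = []
events = EventSource()
events.on(lambda tag: calls.append('A:' + tag))
events.on(lambda tag: calls.append('B:' + tag))

events.fire('before')          # A first, then B
events.reverse_fire('after')   # B first, then A
print(calls)  # ['A:before', 'B:before', 'B:after', 'A:after']
```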

Warning

If you use early-stopping along with checkpoints, one case is very dangerous: a training loop has already finished successfully, and the early-stopping variables have been restored to the best validation step. If you then recover from the latest checkpoint and continue training, the param_vars (covered by early-stopping) hold the values from the best validation step, while the other variables and the internal states of TrainLoop are recovered from the last step. The resulting state mismatch makes the behaviour after this recovery unpredictable.

Attributes Summary

best_valid_metric Get the best valid metric.
epoch Get the epoch counter (starting from 1).
events Get the event source.
max_epoch Get or set the max value for epoch counter.
max_step Get or set the max value for global step counter.
param_vars Get the trainable parameter variables.
step Get the global step counter (starting from 1).
step_data Get the data of current step.
summary_writer Get the summary writer instance.
use_early_stopping Whether or not early-stopping is adopted.
valid_metric_name Get the name of the validation metric.
valid_metric_smaller_is_better Whether or not a smaller value is better for the validation metric.
var_groups Get the variable groups.
within_epoch Whether or not an epoch is currently open.
within_step Whether or not a step is currently open.

Methods Summary

add_summary(summary) Add a summary object, with self.step as global_step.
collect_metrics([metrics]) Add metric values.
get_eta() Get the estimated time ahead (ETA).
get_progress() Get the progress of training.
iter_epochs() Iterate through the epochs.
iter_steps([data_generator]) Iterate through the steps.
make_checkpoint() Make a checkpoint.
metric_collector(**kwds) Get a StatisticsCollector for metric.
print_logs() Print the training logs.
print_training_summary() Print the training summary.
println(message[, with_tag]) Print message via print_function.
timeit(**kwds) Open a context for timing.

Attributes Documentation

best_valid_metric

Get the best valid metric.

epoch

Get the epoch counter (starting from 1).

events

Get the event source.

Returns: The event source.
Return type: EventSource
max_epoch

Get or set the max value for epoch counter.

max_step

Get or set the max value for global step counter.

param_vars

Get the trainable parameter variables.

step

Get the global step counter (starting from 1).

step_data

Get the data of current step.

summary_writer

Get the summary writer instance.

use_early_stopping

Whether or not early-stopping is adopted.

valid_metric_name

Get the name of the validation metric.

valid_metric_smaller_is_better

Whether or not a smaller value is better for the validation metric.

var_groups

Get the variable groups.

within_epoch

Whether or not an epoch is currently open.

within_step

Whether or not a step is currently open.

Methods Documentation

add_summary(summary)

Add a summary object, with self.step as global_step.

Parameters: summary (tf.summary.Summary or bytes) – TensorFlow summary object, or serialized summary.
collect_metrics(metrics=None, **kwargs)

Add metric values.

This method must be called within an active epoch loop. It will add metrics to the epoch metrics collector, and if there is also an active step loop, the metrics are added to the step metrics collector as well.

If summary_writer is configured, it will also write the metrics as summaries onto disk. Furthermore, if valid_metric_name is configured, it will also perform early-stopping.

Parameters:
  • metrics (dict[str, float or np.ndarray]) – Metric values as dict.
  • **kwargs – Metric values, specified as named arguments.
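
The early-stopping behaviour mentioned above can be sketched as a standalone bookkeeping routine (illustrative only, not tfsnippet code): when the configured validation metric arrives, remember the best value seen so far according to whether smaller is better.

```python
# Standalone sketch of early-stopping bookkeeping: track the best validation
# metric according to `smaller_is_better`.  In tfsnippet, an improvement would
# additionally trigger memorizing the current param_vars.

def make_tracker(smaller_is_better=True):
    state = {'best': None}

    def collect(value):
        best = state['best']
        improved = (
            best is None or
            (value < best if smaller_is_better else value > best)
        )
        if improved:
            state['best'] = value   # here tfsnippet would also save param_vars
        return improved

    return collect, state

collect, state = make_tracker(smaller_is_better=True)
assert collect(0.9) is True
assert collect(1.2) is False   # worse value: best is unchanged
assert collect(0.5) is True
print(state['best'])  # 0.5
```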
get_eta()

Get the estimated time ahead (ETA).

Returns: The estimated time ahead in seconds, or None if not available.
Return type: float or None
get_progress()

Get the progress of training.

Returns: The progress in range [0, 1], or None if the progress cannot be estimated.
Return type: float or None
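
The relationship between get_progress() and get_eta() can be sketched in a few lines of standalone Python (illustrative of the contract, not of the tfsnippet internals): once the fraction of work done is known, the remaining time can be extrapolated from the elapsed time.

```python
# Standalone sketch of progress / ETA estimation.  Progress needs a known
# bound (max_step here); ETA extrapolates remaining time from elapsed time.

def estimate_progress(step, max_step):
    if max_step is None:
        return None                    # cannot estimate without a bound
    return min(step / max_step, 1.0)

def estimate_eta(progress, elapsed_seconds):
    if progress is None or progress <= 0.0:
        return None
    return elapsed_seconds * (1.0 - progress) / progress

progress = estimate_progress(step=25, max_step=100)   # 0.25
eta = estimate_eta(progress, elapsed_seconds=30.0)    # 30 * 0.75 / 0.25 = 90
print(progress, eta)  # 0.25 90.0
```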
iter_epochs()

Iterate through the epochs.

This method can only be called when no other epoch loop is being iterated. Furthermore, after exiting this loop, both the epoch metrics and the step metrics will be cleared.

If max_epoch is configured, it will stop at it.

Yields: int – The epoch counter (starting from 1).
iter_steps(data_generator=None)

Iterate through the steps.

This method can only be called when no other step loop is being iterated and an epoch loop is active.

Parameters:

data_generator – Optional iterable data to be yielded at every step. This is required if max_step is not configured, so as to prevent an infinite step loop.

Yields: int or (int, any) – The global step counter (starting from 1), or the tuple of (step counter, batch data) if data_generator is specified.
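
The two yield shapes can be mimicked with a small standalone generator (illustrative only, not the tfsnippet implementation): a bare step counter when no data is supplied, or (step, batch) tuples when it is.

```python
# Standalone sketch of the iter_steps() yield shapes.  Without data, a
# max_step bound is required to avoid an infinite loop; with data, the
# generator yields (step, batch) pairs.

def iter_steps(data_generator=None, max_step=None):
    step = 1
    if data_generator is None:
        if max_step is None:
            raise ValueError('max_step is required without a data_generator')
        while step <= max_step:
            yield step
            step += 1
    else:
        for batch in data_generator:
            if max_step is not None and step > max_step:
                break
            yield step, batch
            step += 1

print(list(iter_steps(max_step=3)))    # [1, 2, 3]
print(list(iter_steps(['a', 'b'])))    # [(1, 'a'), (2, 'b')]
```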

make_checkpoint()

Make a checkpoint.

This method must be called within an epoch or a step context. For example:

for epoch in loop.iter_epochs():
    for [x] in loop.iter_steps(train_data):
        ...

    if epoch % 100 == 0:
        loop.make_checkpoint()
metric_collector(**kwds)

Get a StatisticsCollector for metric.

The mean value of the collected metrics will be added to summary after exiting the context. Other statistics will be discarded.

Parameters: metric_name (str) – The name of this metric.
Yields: StatisticsCollector – The collector for metric values.
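
The mean-only contract can be sketched with a standalone context manager (the names MeanCollector and metric_collector below are illustrative, not the tfsnippet StatisticsCollector API): values collected inside the context are reduced to their mean on exit, and nothing else survives.

```python
from contextlib import contextmanager

# Standalone sketch of the metric_collector() contract: only the mean of the
# collected values is kept after the context exits.

class MeanCollector:
    def __init__(self):
        self._sum = 0.0
        self._count = 0

    def collect(self, value):
        self._sum += value
        self._count += 1

    @property
    def mean(self):
        return self._sum / self._count if self._count else None

@contextmanager
def metric_collector(results, metric_name):
    collector = MeanCollector()
    yield collector
    results[metric_name] = collector.mean   # only the mean survives

results = {}
with metric_collector(results, 'valid_loss') as c:
    c.collect(0.4)
    c.collect(0.6)
print(results)  # {'valid_loss': 0.5}
```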
print_logs()

Print the training logs.

This method will print the collected metrics. If there's an active step loop, it will print metrics from the step metrics collector. Otherwise, if there's only an epoch loop, it will print metrics from the epoch metrics collector.

Note it must be called at the end of an epoch or a step, because the metrics of the corresponding loop context will be cleared after the logs are printed. Moreover, when this method is called, the epoch or step timer is committed as a metric immediately, before the logs are printed.

print_training_summary()

Print the training summary.

The training summary includes the following content:

  1. Execution environment.
  2. Parameters to be optimized during training.
println(message, with_tag=False)

Print message via print_function.

Parameters:
  • message (str) – Message to be printed.
  • with_tag (bool) – Whether or not to add the epoch & step tag. (default False)
timeit(**kwds)

Open a context for timing.

Parameters: metric_name (str) – Store the timing result in the metric of this name. Note that metric_name must end with 'time' or 'timer'; otherwise, by default, the time values will not be formatted as human-readable strings.
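
A timeit()-style context can be sketched as a standalone context manager (illustrative only; tfsnippet's version commits the result through its metrics machinery and applies the time/timer formatting rule noted above, while this sketch just stores raw seconds):

```python
import time
from contextlib import contextmanager

# Standalone sketch of a timeit()-style context: record the wall-clock
# duration of the block under the given metric name.

@contextmanager
def timeit(metrics, metric_name):
    start = time.time()
    try:
        yield
    finally:
        metrics[metric_name] = time.time() - start

metrics = {}
with timeit(metrics, 'valid_time'):
    time.sleep(0.05)
print(round(metrics['valid_time'], 3))  # roughly 0.05
```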