TrainLoop

class tfsnippet.TrainLoop(param_vars, var_groups=None, show_eta=True, print_func=<built-in function print>, summary_dir=None, summary_writer=None, summary_graph=None, summary_metric_prefix='metrics/', summary_skip_pattern=<_sre.SRE_Pattern object>, summary_commit_freqs=None, metric_formatter=<tfsnippet.scaffold.logs.DefaultMetricFormatter object>, valid_metric_name='valid_loss', initial_valid_metric=None, valid_metric_smaller_is_better=None, early_stopping=False, initial_epoch=0, initial_step=0, max_epoch=None, max_step=None)

Bases: tfsnippet.utils.concepts.DisposableContext

Training loop object.

This class provides a set of convenient methods for writing a training loop. It is useful for maintaining epoch and step counters, logging training metrics, memorizing the best parameters for early-stopping, etc. An example of using TrainLoop:

from tfsnippet.dataflows import DataFlow
from tfsnippet.scaffold import TrainLoop

with TrainLoop(param_vars, max_epoch=10, early_stopping=True) as loop:
    loop.print_training_summary()
    train_flow = DataFlow.arrays([x, y], batch_size, shuffle=True)

    for epoch in loop.iter_epochs():
        for step, (x, y) in loop.iter_steps(train_flow):
            step_loss, _ = session.run(
                [loss, train_op],
                feed_dict={input_x: x, input_y: y}
            )
            loop.collect_metrics(loss=step_loss)
        with loop.timeit('valid_time'):
            valid_loss = session.run(
                loss, feed_dict={input_x: test_x, input_y: test_y})
            loop.collect_metrics(valid_loss=valid_loss)
        loop.print_logs()

Attributes Summary

best_valid_metric Get the best valid metric.
epoch Get the epoch counter (starting from 1).
max_epoch Get or set the max value for epoch counter.
max_step Get or set the max value for global step counter.
param_vars Get the trainable parameter variables.
step Get the global step counter (starting from 1).
summary_metric_prefix Get the prefix for the metrics committed to summary_writer.
summary_writer Get the summary writer instance.
use_early_stopping Whether or not to adopt early-stopping?
valid_metric_name Get the name of the validation metric.
valid_metric_smaller_is_better Whether or not the smaller value is better for validation metric?
var_groups Get the variable groups.

Methods Summary

add_summary(summary) Add a summary object, with self.step as global_step.
collect_metrics([metrics]) Add metric values.
get_progress() Get the progress of training.
iter_epochs() Iterate through the epochs.
iter_steps([data_generator]) Iterate through the steps.
metric_collector(**kwds) Get a StatisticsCollector for metric.
print_logs() Print the training logs.
print_training_summary() Print the training summary.
println(message[, with_tag]) Print message via print_func.
timeit(**kwds) Open a context for timing.

Attributes Documentation

best_valid_metric

Get the best valid metric.

epoch

Get the epoch counter (starting from 1).

max_epoch

Get or set the max value for epoch counter.

max_step

Get or set the max value for global step counter.

param_vars

Get the trainable parameter variables.

step

Get the global step counter (starting from 1).

summary_metric_prefix

Get the prefix for the metrics committed to summary_writer.

summary_writer

Get the summary writer instance.

use_early_stopping

Whether or not to adopt early-stopping?

valid_metric_name

Get the name of the validation metric.

valid_metric_smaller_is_better

Whether or not the smaller value is better for validation metric?

var_groups

Get the variable groups.

Methods Documentation

add_summary(summary)

Add a summary object, with self.step as global_step.

Parameters:summary (tf.summary.Summary or bytes) – TensorFlow summary object, or serialized summary.
collect_metrics(metrics=None, **kwargs)

Add metric values.

This method must be called while at least an epoch loop is active. It will add metrics to the epoch metrics collector, and if there’s an active step loop, it will also add metrics to the step metrics collector.

If summary_writer is configured, it will also write the metrics as summaries onto disk. Furthermore, if valid_metric_name is configured, it will also perform early-stopping.

Parameters:
  • metrics (dict[str, float or np.ndarray]) – Metric values as dict.
  • **kwargs – Metric values, specified as named arguments.
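The validation-metric bookkeeping described above can be sketched in plain Python. This is only an illustration under stated assumptions: it does not use tfsnippet, `BestMetricTracker` and its `update` method are hypothetical names, and the real TrainLoop additionally memorizes the best parameters, which is not modelled here.

```python
class BestMetricTracker:
    """Hypothetical stand-in for the validation-metric bookkeeping."""

    def __init__(self, valid_metric_name='valid_loss', smaller_is_better=True):
        self.valid_metric_name = valid_metric_name
        self.smaller_is_better = smaller_is_better
        self.best_valid_metric = None

    def update(self, metrics=None, **kwargs):
        """Accept metrics as a dict and/or named arguments.

        Returns True if the validation metric improved, False otherwise.
        """
        merged = dict(metrics or {})
        merged.update(kwargs)
        if self.valid_metric_name not in merged:
            return False  # nothing to compare against
        value = float(merged[self.valid_metric_name])
        if self.best_valid_metric is None:
            improved = True
        elif self.smaller_is_better:
            improved = value < self.best_valid_metric
        else:
            improved = value > self.best_valid_metric
        if improved:
            self.best_valid_metric = value
        return improved


tracker = BestMetricTracker()
for valid_loss in [0.9, 0.7, 0.8]:
    tracker.update(valid_loss=valid_loss)
print(tracker.best_valid_metric)  # prints 0.7
```

Both calling styles mirror the two parameter forms listed above: a dict passed as `metrics`, or named arguments.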
get_progress()

Get the progress of training.

Returns:
The progress in range [0, 1], or None if the progress cannot be estimated.
Return type:float or None
iter_epochs()

Iterate through the epochs.

This method can only be called when no other epoch loop is being iterated. Furthermore, after exiting this loop, both the epoch metrics and the step metrics will be cleared.

If max_epoch is configured, the iteration stops once that epoch has been reached.

Yields:int – The epoch counter (starting from 1).
iter_steps(data_generator=None)

Iterate through the steps.

This method can only be called when no other step loop is being iterated and an epoch loop is active.

Parameters:data_generator – Optional iterable data to be yielded at every step. This is required if max_step is not configured, so as to prevent an infinite step loop.
Yields:int or (int, any) – The global step counter (starting from 1), or the tuple of (step counter, batch data) if data_generator is specified.
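The counter behaviour of iter_epochs() and iter_steps() can be sketched without tfsnippet. `MiniLoop` is a hypothetical class that mimics only the counters and the two stopping limits; raising `ValueError` when neither `data_generator` nor `max_step` is given is an assumption of this sketch, not confirmed library behaviour.

```python
class MiniLoop:
    """Hypothetical stand-in that mimics only the epoch and step counters."""

    def __init__(self, max_epoch=None, max_step=None):
        self.max_epoch = max_epoch
        self.max_step = max_step
        self.epoch = 0  # becomes 1 in the first epoch
        self.step = 0   # global step, never reset between epochs

    def _step_limit_reached(self):
        return self.max_step is not None and self.step >= self.max_step

    def iter_epochs(self):
        while ((self.max_epoch is None or self.epoch < self.max_epoch)
               and not self._step_limit_reached()):
            self.epoch += 1
            yield self.epoch

    def iter_steps(self, data_generator=None):
        if data_generator is None:
            if self.max_step is None:
                # Assumption: refuse to start an unbounded step loop.
                raise ValueError(
                    'data_generator is required when max_step is not set')
            while not self._step_limit_reached():
                self.step += 1
                yield self.step
        else:
            for batch in data_generator:
                if self._step_limit_reached():
                    break
                self.step += 1
                yield self.step, batch


loop = MiniLoop(max_epoch=2)
seen = []
for epoch in loop.iter_epochs():
    for step, batch in loop.iter_steps(['a', 'b']):
        seen.append((epoch, step, batch))
print(seen)  # prints [(1, 1, 'a'), (1, 2, 'b'), (2, 3, 'a'), (2, 4, 'b')]
```

Note that the step counter is global: it continues from 3 in the second epoch rather than restarting at 1.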

metric_collector(**kwds)

Get a StatisticsCollector for metric.

The mean value of the collected metrics will be added to summary after exiting the context. Other statistics will be discarded.

Parameters:metric_name (str) – The name of this metric.
Yields:StatisticsCollector – The collector for metric values.
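The "collect inside the context, commit the mean on exit" pattern can be sketched as follows. This is a plain-Python illustration, not the library code: `MeanCollector` is a hypothetical, much reduced stand-in for StatisticsCollector, and the `summary` dict merely plays the role of the place where the mean ends up.

```python
from contextlib import contextmanager


class MeanCollector:
    """Hypothetical stand-in for StatisticsCollector; tracks only the mean."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def collect(self, value):
        self.total += float(value)
        self.count += 1

    @property
    def mean(self):
        return self.total / self.count


summary = {}  # stands in for the destination of the committed metric


@contextmanager
def metric_collector(metric_name):
    collector = MeanCollector()
    yield collector
    # On exit, keep only the mean; other statistics are discarded.
    if collector.count > 0:
        summary[metric_name] = collector.mean


with metric_collector('batch_acc') as acc:
    for value in [0.5, 0.75, 1.0]:
        acc.collect(value)
print(summary)  # prints {'batch_acc': 0.75}
```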
print_logs()

Print the training logs.

This method will print the collected metrics. If there’s an active step loop, it will print metrics from the step metrics collector. Otherwise, if only an epoch loop is active, it will print metrics from the epoch metrics collector.

Note that it must be called at the end of an epoch or a step, because the metrics of the corresponding loop context are cleared after the logs are printed. Moreover, the epoch or step timer is committed as a metric immediately when this method is called, before the logs are printed.

print_training_summary()

Print the training summary.

The training summary includes the following content:

  1. Execution environment.
  2. Parameters to be optimized during training.
println(message, with_tag=False)

Print message via print_func.

Parameters:
  • message (str) – Message to be printed.
  • with_tag (bool) – Whether or not to add the epoch & step tag? (default False)
timeit(**kwds)

Open a context for timing.

Parameters:metric_name (str) – Store the timing result in the metric of this name. Note that metric_name must end with time or timer; otherwise, by default, the time values will not be formatted as human-readable strings.
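A timing context of this kind, together with the naming rule for time metrics, can be sketched in plain Python. Everything here is an assumption made for illustration: the module-level `metrics` dict, the `format_metric` helper, and the output format are hypothetical and are not taken from tfsnippet.

```python
import time
from contextlib import contextmanager

metrics = {}  # hypothetical store for the collected timings


@contextmanager
def timeit(metric_name):
    start = time.perf_counter()
    try:
        yield
    finally:
        # Record the elapsed wall-clock time in seconds under the given name.
        metrics[metric_name] = time.perf_counter() - start


def format_metric(name, value):
    # Assumption: only names ending in 'time' or 'timer' are shown as durations.
    if name.endswith(('time', 'timer')):
        return '{}: {:.3f} sec'.format(name, value)
    return '{}: {:.6g}'.format(name, value)


with timeit('valid_time'):
    sum(range(100000))  # placeholder for the validation pass

print(format_metric('valid_time', metrics['valid_time']))
```

A metric stored under a name such as `valid_elapsed` would fall through to the generic numeric format in this sketch, which mirrors the naming requirement stated above.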