tfsnippet.scaffold

class tfsnippet.scaffold.EarlyStopping(param_vars, initial_metric=None, checkpoint_dir=None, smaller_is_better=True, restore_on_error=False, cleanup=True, name=None)

Bases: tfsnippet.utils.concepts.DisposableContext

Early-stopping context object.

This class provides an object that memorizes the parameters at the best metric, within an early-stopping context. An example of using this context:

with EarlyStopping(param_vars) as es:
    ...
    es.update(loss, global_step)
    ...

Calling es.update(loss, global_step) saves the parameters to disk if loss is better than the current best metric. The current best metric can be obtained via es.best_metric.

Notes

If no loss is ever given via es.update, the variables will keep their latest values when the early-stopping context is closed.

__init__(param_vars, initial_metric=None, checkpoint_dir=None, smaller_is_better=True, restore_on_error=False, cleanup=True, name=None)

Construct the EarlyStopping.

Parameters:
  • param_vars (list[tf.Variable] or dict[str, tf.Variable]) – List or dict of variables to be memorized. If a dict is specified, the keys of the dict would be used as the serialization keys via VariableSaver.
  • initial_metric (float or tf.Tensor or tf.Variable) – The initial best metric (for recovering from previous session).
  • checkpoint_dir (str) – The directory where to save the checkpoint files. If not specified, will use a temporary directory.
  • smaller_is_better (bool) – Whether or not it is better to have smaller metric values? (default True)
  • restore_on_error (bool) – Whether or not to restore the memorized parameters even on error? (default False)
  • cleanup (bool) – Whether or not to cleanup the checkpoint directory on exit? This argument will be ignored if checkpoint_dir is None, where the temporary directory will always be deleted on exit.
  • name (str) – Name scope of all TensorFlow operations. (default “early_stopping”).
best_metric

Get the current best metric.

ever_updated

Check whether or not update method has ever been called.

update(metric, global_step=None)

Update the best metric.

Parameters:
  • metric (float) – New metric value.
  • global_step (int) – Optional global step counter.
Returns:

Whether or not the best metric has been updated?

Return type:

bool
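The semantics of update() can be illustrated with a standalone sketch in plain Python (this is an illustration of the documented behavior, not the library's implementation; BestTracker is a hypothetical name, and saving to disk is elided):

```python
class BestTracker:
    """Standalone sketch of the EarlyStopping.update() semantics."""

    def __init__(self, initial_metric=None, smaller_is_better=True):
        self.best_metric = initial_metric
        self.smaller_is_better = smaller_is_better
        self.ever_updated = False

    def update(self, metric):
        # Returns True iff the new metric is better than the current best;
        # the real class would also save the parameters to disk here.
        self.ever_updated = True
        is_better = (
            self.best_metric is None or
            (metric < self.best_metric if self.smaller_is_better
             else metric > self.best_metric)
        )
        if is_better:
            self.best_metric = metric
        return is_better
```

For example, with smaller_is_better=True, update(0.5) followed by update(0.7) returns True then False, and best_metric stays at 0.5.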

tfsnippet.scaffold.EarlyStoppingContext

alias of tfsnippet.scaffold.early_stopping_.EarlyStopping

tfsnippet.scaffold.early_stopping(*args, **kwargs)
class tfsnippet.scaffold.MetricFormatter

Bases: object

Base class for a training metrics formatter.

A training metric formatter determines the order of metrics, and the way to display the values of these metrics, in MetricLogger.

format_metric(name, value)

Format the value of specified metric.

Parameters:
  • name – Name of the metric.
  • value – Value of the metric.
Returns:

Human readable string representation of the metric value.

Return type:

str

sort_metrics(names)

Sort the names of metrics.

Parameters:names – Iterable metric names.
Returns:Sorted metric names.
Return type:list[str]
class tfsnippet.scaffold.DefaultMetricFormatter

Bases: tfsnippet.scaffold.logs.MetricFormatter

Default training metric formatter.

This class sorts the metrics as follows:

  1. The metrics are first divided into groups according to the suffixes of their names, as follows:
    1. Names ending with “time” or “timer” come first;
    2. Other metrics come second;
    3. Names ending with “loss” or “cost” come third;
    4. Names ending with “acc”, “accuracy”, “nll”, “lb” or “lower_bound” come fourth.
  2. The metrics are then sorted according to their names, within each group.

The values of the metrics would be formatted into 6-digit real numbers, except for metrics whose names end with “time” or “timer”, which would be formatted using humanize_duration().
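The grouping rules above can be sketched in plain Python (a standalone illustration, not the library's implementation; the regex patterns are reconstructed from the description, and the priority values mirror those shown in METRIC_ORDERS):

```python
import re

# Assumed patterns, reconstructed from the ordering rules above.
GROUP_PATTERNS = [
    (-1, re.compile(r'.*(time|timer)$')),                        # first
    (998, re.compile(r'.*(loss|cost)$')),                        # third
    (999, re.compile(r'.*(acc|accuracy|nll|lb|lower_bound)$')),  # fourth
]

def sort_metrics(names):
    def key(name):
        for priority, pattern in GROUP_PATTERNS:
            if pattern.match(name):
                return (priority, name)
        return (0, name)  # unmatched metrics come second
    # Sort by (group priority, name), i.e. by name within each group.
    return sorted(names, key=key)
```

For example, sort_metrics(['valid_loss', 'train_time', 'acc', 'lr']) yields ['train_time', 'lr', 'valid_loss', 'acc'].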

METRIC_ORDERS = ((-1, <compiled regex>), (998, <compiled regex>), (999, <compiled regex>))
format_metric(name, value)

Format the value of specified metric.

Parameters:
  • name – Name of the metric.
  • value – Value of the metric.
Returns:

Human readable string representation of the metric value.

Return type:

str

sort_metrics(names)

Sort the names of metrics.

Parameters:names – Iterable metric names.
Returns:Sorted metric names.
Return type:list[str]
class tfsnippet.scaffold.MetricLogger(summary_writer=None, summary_metric_prefix='', summary_skip_pattern=None, summary_commit_freqs=None, formatter=None)

Bases: object

Logger for the training metrics.

This class provides convenient methods for logging training metrics, and for writing metrics onto disk via TensorFlow summary writer. The statistics of the metrics could be formatted into human readable strings via format_logs().

An example of using this logger is:

logger = MetricLogger(tf.summary.FileWriter(log_dir))
global_step = 1

for epoch in range(1, max_epoch+1):
    for batch in DataFlow.arrays(...):
        loss_val, _ = session.run([loss, train_op], ...)
        logger.collect_metrics({'loss': loss_val}, global_step)
        global_step += 1

    valid_loss_val = session.run(loss, ...)
    logger.collect_metrics({'valid_loss': valid_loss_val}, global_step)
    print('Epoch {}, step {}: {}'.format(
        epoch, global_step, logger.format_logs()))
    logger.clear()
__init__(summary_writer=None, summary_metric_prefix='', summary_skip_pattern=None, summary_commit_freqs=None, formatter=None)

Construct the MetricLogger.

Parameters:
  • summary_writer – TensorFlow summary writer.
  • summary_metric_prefix (str) – The prefix for the metrics committed to summary_writer. This will not affect the summaries added via add_summary(). (default “”)
  • summary_skip_pattern (str or regex) – Metrics matching this pattern will be excluded from summary_writer. (default None)
  • summary_commit_freqs (dict[str, int] or None) – If specified, a metric will be committed to summary_writer no more frequently than once every summary_commit_freqs[metric] collections. (default None)
  • formatter (MetricFormatter) – Metric formatter for this logger. If not specified, will use an instance of DefaultMetricFormatter.
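The throttling effect of summary_commit_freqs can be sketched as follows (a standalone illustration under the assumed semantics that a metric is committed once every N collections; CommitThrottle is a hypothetical name):

```python
class CommitThrottle:
    """Sketch of summary_commit_freqs throttling: a metric is committed
    only once every freqs[metric] times it is collected."""

    def __init__(self, commit_freqs):
        self.freqs = commit_freqs or {}
        self.counts = {}

    def should_commit(self, metric):
        n = self.counts.get(metric, 0)
        self.counts[metric] = n + 1
        freq = self.freqs.get(metric, 1)  # metrics without a freq always commit
        return n % freq == 0
```

With commit_freqs={'loss': 3}, six consecutive collections of 'loss' would commit on the 1st and 4th only, while any other metric commits every time.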
clear()

Clear all the metric statistics.

collect_metrics(metrics, global_step=None)

Collect the statistics of metrics.

Parameters:
  • metrics (dict[str, float or np.ndarray or DynamicValue]) – Dict from metric names to their values. For format_logs(), there is no difference between calling collect_metrics() once with an array of metric values, and calling it multiple times with one value each time. However, the TensorFlow summary writer records only the mean of the values when collect_metrics() is called with an array.
  • global_step (int or tf.Variable or tf.Tensor) – The global step counter. (optional)
format_logs()

Format the metric statistics as human readable strings.

Returns:The formatted metric statistics.
Return type:str
tfsnippet.scaffold.summarize_variables(variables, title='Variables Summary', other_variables_title='Other Variables', groups=None)

Get a formatted summary about the variables.

Parameters:
  • variables (list[tf.Variable] or dict[str, tf.Variable]) – List or dict of variables to be summarized.
  • title (str) – Title of this summary.
  • other_variables_title (str) – Title of the “Other Variables”.
  • groups (None or list[str]) – List of separated variable groups, each summarized in a table. (default None)
Returns:

Formatted summary about the variables.

Return type:

str
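The kind of grouping such a summary performs can be sketched in plain Python, using (name, shape) pairs in place of tf.Variable objects (a standalone illustration of the idea; summarize_shapes is a hypothetical name, and the actual table layout produced by the library may differ):

```python
def summarize_shapes(var_shapes, groups=None, other_title='Other Variables'):
    """Sketch: group variables by scope prefix and total their sizes.

    var_shapes: dict mapping variable name -> shape (tuple of ints).
    groups: optional list of scope prefixes, each summarized separately;
            unmatched variables fall under other_title.
    """
    def size(shape):
        n = 1
        for d in shape:
            n *= d
        return n

    grouped = {}
    for name, shape in var_shapes.items():
        for g in (groups or []):
            if name.startswith(g + '/'):
                grouped.setdefault(g, []).append((name, size(shape)))
                break
        else:
            grouped.setdefault(other_title, []).append((name, size(shape)))

    lines = []
    for title, items in grouped.items():
        total = sum(s for _, s in items)
        lines.append('{} ({:,} in total)'.format(title, total))
        for name, s in sorted(items):
            lines.append('  {}: {:,}'.format(name, s))
    return '\n'.join(lines)
```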

class tfsnippet.scaffold.TrainLoop(param_vars, var_groups=None, show_eta=True, print_func=print, summary_dir=None, summary_writer=None, summary_graph=None, summary_metric_prefix='metrics/', summary_skip_pattern=re.compile(r'.*(time|timer)$'), summary_commit_freqs=None, metric_formatter=DefaultMetricFormatter(), valid_metric_name='valid_loss', initial_valid_metric=None, valid_metric_smaller_is_better=None, early_stopping=False, initial_epoch=0, initial_step=0, max_epoch=None, max_step=None)

Bases: tfsnippet.utils.concepts.DisposableContext

Training loop object.

This class provides a set of convenient methods for writing training loop. It is useful for maintaining epoch and step counters, logging training metrics, memorizing best parameters for early-stopping, etc. An example of using the TrainLoop:

from tfsnippet.dataflow import DataFlow
from tfsnippet.scaffold import TrainLoop

with TrainLoop(param_vars, max_epoch=10, early_stopping=True) as loop:
    loop.print_training_summary()
    train_flow = DataFlow.arrays([x, y], batch_size, shuffle=True)

    for epoch in loop.iter_epochs():
        for step, (x, y) in loop.iter_steps(train_flow):
            step_loss, _ = session.run(
                [loss, train_op],
                feed_dict={input_x: x, input_y: y}
            )
            loop.collect_metrics(loss=step_loss)
        with loop.timeit('valid_time'):
            valid_loss = session.run(
                loss, feed_dict={input_x: test_x, input_y: test_y})
            loop.collect_metrics(valid_loss=valid_loss)
        loop.print_logs()
__init__(param_vars, var_groups=None, show_eta=True, print_func=print, summary_dir=None, summary_writer=None, summary_graph=None, summary_metric_prefix='metrics/', summary_skip_pattern=re.compile(r'.*(time|timer)$'), summary_commit_freqs=None, metric_formatter=DefaultMetricFormatter(), valid_metric_name='valid_loss', initial_valid_metric=None, valid_metric_smaller_is_better=None, early_stopping=False, initial_epoch=0, initial_step=0, max_epoch=None, max_step=None)

Construct the TrainLoop.

Parameters:
  • param_vars (list[tf.Variable] or dict[str, tf.Variable]) – List or dict of variables, optimized during training.
  • var_groups (None or list[str]) – Variable groups, the prefixes of variable scopes. A hint for printing the variables summary. (default None)
  • show_eta (bool) – Whether or not to show ETA? (default True)
  • print_func ((str) -> None) – Function for printing log messages (calling print by default). An alternative of this argument may be getLogger(__name__).info, such that the log messages will be printed via logging facilities.
  • summary_dir (str) – Directory for writing TensorFlow summaries. Ignored if summary_writer is specified.
  • summary_writer – TensorFlow summary writer for writing metrics.
  • summary_metric_prefix (str) – The prefix for the metrics committed to summary_writer. This will not affect the summaries added via add_summary(). (default “metrics/”)
  • summary_graph – If specified, log the graph via summary_writer.
  • summary_skip_pattern (str or regex) – Metrics matching this pattern will be excluded from summary_writer. (default “.*(time|timer)$”)
  • summary_commit_freqs (dict[str, int] or None) – If specified, a metric will be committed to summary_writer no more frequently than once every summary_commit_freqs[metric] collections. (default None)
  • metric_formatter (MetricFormatter) – The training metrics formatter.
  • valid_metric_name (str) – Name of the validation metric.
  • initial_valid_metric (float or tf.Tensor or tf.Variable) – Initial value of the validation metric for early-stopping.
  • valid_metric_smaller_is_better (bool) – Whether or not the smaller value is better for validation metric? If not specified, it will be inferred according to valid_metric_name: metric names with acc or accuracy as suffix imply False (larger is better), while other names imply True.
  • early_stopping (bool) – Whether or not to do early-stopping? (default False) If True, early-stopping will be applied on param_vars, according to the validation metric.
  • initial_epoch (int or tf.Tensor or tf.Variable) – The initial epoch (default 0). Should be one less than the actual first epoch.
  • initial_step (int or tf.Tensor or tf.Variable) – The initial step (default 0). Should be one less than the actual first step.
  • max_epoch (None or int or tf.Tensor or tf.Variable) – The maximum epoch to run. If None, will run for infinite epochs. If 1, the epoch counter will be discarded in the output logs. (default None)
  • max_step (None or int or tf.Tensor or tf.Variable) – The maximum step to run. If None, will run for infinite steps. Note this limit applies for the total step counter, rather than the epoch-wise step counter. (default None)
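The inference rule for valid_metric_smaller_is_better can be sketched as (a standalone illustration, assuming accuracy-like metric names are larger-is-better and everything else, e.g. losses, is smaller-is-better):

```python
def infer_smaller_is_better(valid_metric_name):
    # Accuracy-like metrics: larger is better -> smaller_is_better=False.
    # Any other metric (e.g. 'valid_loss'): smaller is better -> True.
    return not (valid_metric_name.endswith('acc')
                or valid_metric_name.endswith('accuracy'))
```

For example, 'valid_loss' infers True, while 'valid_acc' infers False.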
add_summary(summary)

Add a summary object, with self.step as global_step.

Parameters:summary (tf.summary.Summary or bytes) – TensorFlow summary object, or serialized summary.
best_valid_metric

Get the best valid metric.

collect_metrics(metrics=None, **kwargs)

Add metric values.

This method must be called when there is an active epoch loop. It will add metrics to the epoch metrics collector, and, if there is an active step loop, also to the step metrics collector.

If summary_writer is configured, it will also write the metrics as summaries onto disk. Furthermore, if valid_metric_name is configured, it will also perform early-stopping.

Parameters:
  • metrics (dict[str, float or np.ndarray]) – Metric values as dict.
  • **kwargs – Metric values, specified as named arguments.
epoch

Get the epoch counter (starting from 1).

get_progress()

Get the progress of training.

Returns:The progress in range [0, 1], or None if the progress cannot be estimated.
Return type:float or None
iter_epochs()

Iterate through the epochs.

This method can only be called when no other epoch loop is being iterated. Furthermore, after exiting this loop, both the epoch metrics and the step metrics will be cleared.

If max_epoch is configured, it will stop at it.

Yields:int – The epoch counter (starting from 1).
iter_steps(data_generator=None)

Iterate through the steps.

This method can only be called when no other step loop is being iterated, and an epoch loop is active.

Parameters:data_generator – Optional iterable data to be yielded at every step. This is required if max_step is not configured, so as to prevent an infinite step loop.
Yields:int or (int, any) – The global step counter (starting from 1), or the tuple of (step counter, batch data) if data_generator is specified.
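The yielding behavior can be sketched as a plain generator (a standalone illustration of the documented protocol, not the library code):

```python
def iter_steps(data_generator=None, initial_step=0, max_step=None):
    """Sketch of the iter_steps() yielding protocol."""
    step = initial_step
    if data_generator is None:
        # Without data, max_step must bound the loop to avoid
        # iterating forever.
        assert max_step is not None, 'max_step required without data'
        while step < max_step:
            step += 1
            yield step  # bare step counter
    else:
        for batch in data_generator:
            if max_step is not None and step >= max_step:
                break
            step += 1
            yield step, batch  # (step counter, batch data)
```

For example, iter_steps(max_step=3) yields 1, 2, 3, while iter_steps(['a', 'b']) yields (1, 'a'), (2, 'b').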

max_epoch

Get or set the max value for epoch counter.

max_step

Get or set the max value for global step counter.

metric_collector(**kwds)

Get a StatisticsCollector for metric.

The mean value of the collected metrics will be added to summary after exiting the context. Other statistics will be discarded.

Parameters:metric_name (str) – The name of this metric.
Yields:StatisticsCollector – The collector for metric values.
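The keep-only-the-mean behavior described above can be sketched with a standalone context manager (plain Python, not the library code; MeanCollector is a hypothetical stand-in for StatisticsCollector):

```python
from contextlib import contextmanager

class MeanCollector:
    """Sketch of a statistics collector that keeps a running mean."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def collect(self, value):
        self.total += value
        self.count += 1

    @property
    def mean(self):
        return self.total / self.count if self.count else None

@contextmanager
def metric_collector(metrics, metric_name):
    # On exit, only the mean of the collected values is kept under
    # metric_name; other statistics are discarded.
    c = MeanCollector()
    yield c
    metrics[metric_name] = c.mean
```

For example, collecting 1.0 and 3.0 inside the context stores 2.0 as the metric value on exit.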
param_vars

Get the trainable parameter variables.

print_logs()

Print the training logs.

This method will print the collected metrics. If there’s an active step loop, it will print metrics from the step metrics collector. Otherwise if there’s only an epoch loop, it will print metrics from the epoch metrics accumulator.

Note it must be called at the end of an epoch or a step. This is because the metrics of corresponding loop context will be cleared after the logs are printed. Moreover, the epoch or step timer will be committed as metric immediately when this method is called, before printing the logs.

print_training_summary()

Print the training summary.

The training summary includes the following content:

  1. Execution environment.
  2. Parameters to be optimized during training.
println(message, with_tag=False)

Print a message via print_func.

Parameters:
  • message (str) – Message to be printed.
  • with_tag (bool) – Whether or not to add the epoch & step tag? (default False)
step

Get the global step counter (starting from 1).

summary_metric_prefix

Get the prefix for the metrics committed to summary_writer.

summary_writer

Get the summary writer instance.

timeit(**kwds)

Open a context for timing.

Parameters:metric_name (str) – Store the timing result in metric of this name. Note that metric_name must end with time or timer, otherwise by default the time values will not be formatted as human readable strings.
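The timing semantics can be sketched as a standalone context manager (an illustration under the assumption that the elapsed wall-clock time of the with-block is stored as a metric on exit; this is not the library's implementation):

```python
import time
from contextlib import contextmanager

@contextmanager
def timeit(metrics, metric_name):
    """Sketch of timeit(): measure the elapsed wall-clock time of the
    with-block and store it in metrics under metric_name."""
    start = time.time()
    try:
        yield
    finally:
        # Recorded even if the block raises, so partial timings survive.
        metrics[metric_name] = time.time() - start
```

Note the metric_name constraint described above: a name like 'valid_time' gets human-readable duration formatting, while other names would not.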
use_early_stopping

Whether or not to adopt early-stopping?

valid_metric_name

Get the name of the validation metric.

valid_metric_smaller_is_better

Whether or not the smaller value is better for validation metric?

var_groups

Get the variable groups.

tfsnippet.scaffold.TrainLoopContext

alias of tfsnippet.scaffold.train_loop_.TrainLoop

tfsnippet.scaffold.train_loop(*args, **kwargs)