tfsnippet.scaffold
class tfsnippet.scaffold.EarlyStopping(param_vars, initial_metric=None, checkpoint_dir=None, smaller_is_better=True, restore_on_error=False, cleanup=True, name=None)

    Bases: tfsnippet.utils.concepts.DisposableContext
Early-stopping context object.
This class provides an object for memorizing the parameters for the best metric, in an early-stopping context. An example of using this context:
    with EarlyStopping(param_vars) as es:
        ...
        es.update(loss, global_step)
        ...
Where es.update(loss, global_step) should cause the parameters to be saved on disk if loss is better than the current best metric. One may also get the current best metric via es.best_metric.

Notes

If no loss is given via es.update, then the variables would keep their latest values when closing the early-stopping object.
__init__(param_vars, initial_metric=None, checkpoint_dir=None, smaller_is_better=True, restore_on_error=False, cleanup=True, name=None)

    Construct the EarlyStopping.

    Parameters:
        - param_vars (list[tf.Variable] or dict[str, tf.Variable]) – List or dict of variables to be memorized. If a dict is specified, the keys of the dict would be used as the serialization keys via VariableSaver.
        - initial_metric (float or tf.Tensor or tf.Variable) – The initial best metric (for recovering from a previous session).
        - checkpoint_dir (str) – The directory where the checkpoint files are saved. If not specified, a temporary directory will be used.
        - smaller_is_better (bool) – Whether or not smaller metric values are better. (default True)
        - restore_on_error (bool) – Whether or not to restore the memorized parameters even on error. (default False)
        - cleanup (bool) – Whether or not to clean up the checkpoint directory on exit. This argument is ignored if checkpoint_dir is None, in which case the temporary directory will always be deleted on exit. (default True)
        - name (str) – Name scope of all TensorFlow operations. (default "early_stopping")
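For illustration, here is a minimal sketch of wiring EarlyStopping into a session-based training loop. The model variables (w, b), the loss and train_op tensors, and max_step are hypothetical placeholders assumed to be defined elsewhere:

    import tensorflow as tf
    from tfsnippet.scaffold import EarlyStopping

    # a dict fixes the serialization keys used by VariableSaver
    param_vars = {'w': w, 'b': b}

    with tf.Session() as session, \
            EarlyStopping(param_vars, smaller_is_better=True) as es:
        for step in range(1, max_step + 1):
            loss_value, _ = session.run([loss, train_op])
            # saves param_vars to disk whenever loss_value improves
            es.update(loss_value, global_step=step)
    # after exit, param_vars hold the values memorized for the best metric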
best_metric
    Get the current best metric.
ever_updated
    Check whether or not the update method has ever been called.
tfsnippet.scaffold.EarlyStoppingContext
    alias of tfsnippet.scaffold.early_stopping_.EarlyStopping
tfsnippet.scaffold.early_stopping(*args, **kwargs)
class tfsnippet.scaffold.MetricFormatter

    Bases: object
Base class for a training metrics formatter.
A training metric formatter determines the order of metrics, and the way to display the values of these metrics, in MetricLogger.
format_metric(name, value)
    Format the value of the specified metric.

    Parameters:
        - name – Name of the metric.
        - value – Value of the metric.

    Returns: Human-readable string representation of the metric value.
    Return type: str
class tfsnippet.scaffold.DefaultMetricFormatter

    Bases: tfsnippet.scaffold.logs.MetricFormatter
Default training metric formatter.
This class sorts the metrics as follows:
- The metrics are first divided into groups according to the suffixes of their names, as follows:
    - names ending with "time" or "timer" should come first;
    - other metrics should come second;
    - names ending with "loss" or "cost" should come third;
    - names ending with "acc", "accuracy", "nll", "lb" or "lower_bound" should come fourth.
- The metrics are then sorted according to their names, within each group.
The values of the metrics would be formatted into 6-digit real numbers, except for metrics with "time" or "timer" as suffixes in their names, which would be formatted using humanize_duration().

METRIC_ORDERS = ((-1, <compiled regex>), (998, <compiled regex>), (999, <compiled regex>))
    The priority/pattern pairs implementing the grouping rules above.
format_metric(name, value)
    Format the value of the specified metric.

    Parameters:
        - name – Name of the metric.
        - value – Value of the metric.

    Returns: Human-readable string representation of the metric value.
    Return type: str
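Where custom display logic is needed, the formatter can be subclassed. Below is a minimal sketch (the class name and the percentage rule are hypothetical) that keeps the default metric ordering and only overrides format_metric:

    from tfsnippet.scaffold import DefaultMetricFormatter

    class PercentAccuracyFormatter(DefaultMetricFormatter):
        # hypothetical subclass: render accuracy-like metrics as percentages,
        # and delegate everything else to the default formatting
        def format_metric(self, name, value):
            if name.endswith('acc') or name.endswith('accuracy'):
                return '{:.2f}%'.format(value * 100.0)
            return super(PercentAccuracyFormatter, self).format_metric(name, value)

Such a formatter can then be passed as the formatter argument of MetricLogger, documented below.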
class tfsnippet.scaffold.MetricLogger(summary_writer=None, summary_metric_prefix='', summary_skip_pattern=None, summary_commit_freqs=None, formatter=None)

    Bases: object
Logger for the training metrics.
This class provides convenient methods for logging training metrics, and for writing metrics onto disk via the TensorFlow summary writer. The statistics of the metrics could be formatted into human readable strings via format_logs(). An example of using this logger is:
    logger = MetricLogger(tf.summary.FileWriter(log_dir))
    global_step = 1

    for epoch in range(1, max_epoch + 1):
        for batch in DataFlow.arrays(...):
            loss_val, _ = session.run([loss, train_op], ...)
            logger.collect_metrics({'loss': loss_val}, global_step)
            global_step += 1

        valid_loss = session.run([loss], ...)
        logger.collect_metrics({'valid_loss': valid_loss}, global_step)
        print('Epoch {}, step {}: {}'.format(
            epoch, global_step, logger.format_logs()))
        logger.clear()
__init__(summary_writer=None, summary_metric_prefix='', summary_skip_pattern=None, summary_commit_freqs=None, formatter=None)

    Construct the MetricLogger.

    Parameters:
        - summary_writer – TensorFlow summary writer.
        - summary_metric_prefix (str) – The prefix for the metrics committed to summary_writer. This will not affect the summaries added via add_summary(). (default "")
        - summary_skip_pattern (str or regex) – Metrics matching this pattern will be excluded from summary_writer. (default None)
        - summary_commit_freqs (dict[str, int] or None) – If specified, a metric will be committed to summary_writer no more frequently than summary_commit_freqs[metric]. (default None)
        - formatter (MetricFormatter) – Metric formatter for this logger. If not specified, an instance of DefaultMetricFormatter will be used.
clear()
    Clear all the metric statistics.
collect_metrics(metrics, global_step=None)
    Collect the statistics of metrics.

    Parameters:
        - metrics (dict[str, float or np.ndarray or DynamicValue]) – Dict from metric names to their values. For format_logs(), there is no difference between calling collect_metrics() only once with an array of metric values, and calling collect_metrics() multiple times with one value each time. However, for the TensorFlow summary writer, only the mean of the metric values would be recorded when calling collect_metrics() with an array.
        - global_step (int or tf.Variable or tf.Tensor) – The global step counter. (optional)
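To make the array-versus-scalar distinction concrete, here is a small hypothetical sketch (logger and global_step as in the example above):

    import numpy as np

    # for format_logs(), one call with an array of per-batch values ...
    logger.collect_metrics({'loss': np.array([0.9, 0.8, 0.7])}, global_step)

    # ... collects the same statistics as several calls with scalars;
    # the summary writer, however, only records the mean of the array
    logger.clear()
    for v in [0.9, 0.8, 0.7]:
        logger.collect_metrics({'loss': v}, global_step)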
tfsnippet.scaffold.summarize_variables(variables, title='Variables Summary', other_variables_title='Other Variables', groups=None)
    Get a formatted summary about the variables.

    Parameters:
        - variables (list[tf.Variable] or dict[str, tf.Variable]) – List or dict of variables to be summarized.
        - title (str) – Title of this summary.
        - other_variables_title (str) – Title of the "Other Variables" group.
        - groups (None or list[str]) – List of separated variable groups, each summarized in a table. (default None)

    Returns: Formatted summary about the variables.
    Return type: str
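For instance, a hypothetical call that splits the summary into per-scope tables (the scope prefixes 'encoder/' and 'decoder/' are illustrative):

    import tensorflow as tf
    from tfsnippet.scaffold import summarize_variables

    print(summarize_variables(
        tf.trainable_variables(),
        groups=['encoder/', 'decoder/'],  # hypothetical variable scope prefixes
    ))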
class tfsnippet.scaffold.TrainLoop(param_vars, var_groups=None, show_eta=True, print_func=print, summary_dir=None, summary_writer=None, summary_graph=None, summary_metric_prefix='metrics/', summary_skip_pattern=re.compile('.*(time|timer)$'), summary_commit_freqs=None, metric_formatter=DefaultMetricFormatter(), valid_metric_name='valid_loss', initial_valid_metric=None, valid_metric_smaller_is_better=None, early_stopping=False, initial_epoch=0, initial_step=0, max_epoch=None, max_step=None)

    Bases: tfsnippet.utils.concepts.DisposableContext
Training loop object.
This class provides a set of convenient methods for writing a training loop. It is useful for maintaining epoch and step counters, logging training metrics, memorizing the best parameters for early-stopping, etc. An example of using the TrainLoop:

    from tfsnippet.dataflow import DataFlow
    from tfsnippet.scaffold import TrainLoop

    with TrainLoop(param_vars, max_epoch=10, early_stopping=True) as loop:
        loop.print_training_summary()
        train_flow = DataFlow.arrays([x, y], batch_size, shuffle=True)

        for epoch in loop.iter_epochs():
            for step, (x, y) in loop.iter_steps(train_flow):
                step_loss, _ = session.run(
                    [loss, train_op],
                    feed_dict={input_x: x, input_y: y}
                )
                loop.collect_metrics(loss=step_loss)
            with loop.timeit('valid_time'):
                valid_loss = session.run(
                    loss, feed_dict={input_x: test_x, input_y: test_y})
                loop.collect_metrics(valid_loss=valid_loss)
            loop.print_logs()
__init__(param_vars, var_groups=None, show_eta=True, print_func=print, summary_dir=None, summary_writer=None, summary_graph=None, summary_metric_prefix='metrics/', summary_skip_pattern=re.compile('.*(time|timer)$'), summary_commit_freqs=None, metric_formatter=DefaultMetricFormatter(), valid_metric_name='valid_loss', initial_valid_metric=None, valid_metric_smaller_is_better=None, early_stopping=False, initial_epoch=0, initial_step=0, max_epoch=None, max_step=None)

    Construct the TrainLoop.

    Parameters:
        - param_vars (list[tf.Variable] or dict[str, tf.Variable]) – List or dict of variables, optimized during training.
        - var_groups (None or list[str]) – Variable groups, the prefixes of variable scopes. A hint for printing the variables summary. (default None)
        - show_eta (bool) – Whether or not to show the ETA. (default True)
        - print_func ((str) -> None) – Function for printing log messages (print by default). An alternative for this argument is getLogger(__name__).info, such that the log messages will be printed via logging facilities; see the sketch after this parameter list.
        - summary_dir (str) – Directory for writing TensorFlow summaries. Ignored if summary_writer is specified.
        - summary_writer – TensorFlow summary writer for writing metrics.
        - summary_metric_prefix (str) – The prefix for the metrics committed to summary_writer. This will not affect the summaries added via add_summary(). (default "metrics/")
        - summary_graph – If specified, log the graph via summary_writer.
        - summary_skip_pattern (str or regex) – Metrics matching this pattern will be excluded from summary_writer. (default ".*(time|timer)$")
        - summary_commit_freqs (dict[str, int] or None) – If specified, a metric will be committed to summary_writer no more frequently than summary_commit_freqs[metric]. (default None)
        - metric_formatter (MetricFormatter) – The training metrics formatter.
        - valid_metric_name (str) – Name of the validation metric.
        - initial_valid_metric (float or tf.Tensor or tf.Variable) – Initial value of the validation metric for early-stopping.
        - valid_metric_smaller_is_better (bool) – Whether or not the smaller value is better for the validation metric. If not specified, it will be inferred according to valid_metric_name: metric names with acc or accuracy as suffix imply False, while other names imply True.
        - early_stopping (bool) – Whether or not to do early-stopping. (default False) If True, early-stopping will be applied on param_vars, according to the validation metric.
        - initial_epoch (int or tf.Tensor or tf.Variable) – The initial epoch (default 0). Should be one less than the actual first epoch.
        - initial_step (int or tf.Tensor or tf.Variable) – The initial step (default 0). Should be one less than the actual first step.
        - max_epoch (None or int or tf.Tensor or tf.Variable) – The maximum epoch to run. If None, will run for infinite epochs. If 1, the epoch counter will be discarded in the output logs. (default None)
        - max_step (None or int or tf.Tensor or tf.Variable) – The maximum step to run. If None, will run for infinite steps. Note this limit applies to the total step counter, rather than the epoch-wise step counter. (default None)
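For example, a hypothetical configuration that routes the loop's log messages through the standard logging module, as suggested for print_func above:

    import logging
    from tfsnippet.scaffold import TrainLoop

    logging.basicConfig(level=logging.INFO)

    with TrainLoop(param_vars,
                   max_epoch=10,
                   print_func=logging.getLogger(__name__).info) as loop:
        ...  # training code as in the class example above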
add_summary(summary)
    Add a summary object, with self.step as global_step.

    Parameters: summary (tf.summary.Summary or bytes) – TensorFlow summary object, or serialized summary.
best_valid_metric
    Get the best validation metric.
collect_metrics(metrics=None, **kwargs)
    Add metric values.

    This method must be called when there's at least an active epoch loop. It will add metrics to the epoch metrics collector, and if there's an active step loop, it will also add metrics to the step metrics collector.

    If summary_writer is configured, it will also write the metrics as summaries onto disk. Furthermore, if valid_metric_name is configured, it will also perform early-stopping.

    Parameters: metrics – Metric values, given as a dict; values may also be passed as named arguments (as in loop.collect_metrics(loss=step_loss)).
epoch
    Get the epoch counter (starting from 1).
get_progress()
    Get the progress of training.

    Returns: The progress in range [0, 1], or None if the progress cannot be estimated.
    Return type: float or None
iter_epochs()
    Iterate through the epochs.

    This method can only be called when no other epoch loop is being iterated. Furthermore, after exiting this loop, both the epoch metrics as well as the step metrics will be cleared.

    If max_epoch is configured, it will stop at it.

    Yields: int – The epoch counter (starting from 1).
iter_steps(data_generator=None)
    Iterate through the steps.

    This method can only be called when no other step loop is being iterated, and an epoch loop is active.

    Parameters: data_generator – Optional iterable data to be yielded at every step. This is required if max_step is not configured, so as to prevent an infinite step loop.

    Yields: int or (int, any) – The global step counter (starting from 1), or the tuple of (step counter, batch data) if data_generator is specified.
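As a hypothetical counterpart to the data-driven loop in the class example, iter_steps() may also be called without a data_generator, provided max_step bounds the loop:

    with TrainLoop(param_vars, max_step=1000) as loop:
        for epoch in loop.iter_epochs():
            for step in loop.iter_steps():
                ...  # run one training step; the loop stops at max_step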
max_epoch
    Get or set the max value for the epoch counter.

max_step
    Get or set the max value for the global step counter.
metric_collector(metric_name)
    Get a StatisticsCollector for a metric.

    The mean value of the collected metrics will be added to the summary after exiting the context. Other statistics will be discarded.

    Parameters: metric_name (str) – The name of this metric.

    Yields: StatisticsCollector – The collector for metric values.
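A minimal hypothetical sketch of this context (assuming batch_losses is an iterable of floats, and that collect is the value-collecting method of StatisticsCollector from tfsnippet.utils):

    # inside an active epoch loop of `loop`
    with loop.metric_collector('valid_loss') as mc:
        for v in batch_losses:
            mc.collect(v)
    # on exit, the mean of the collected values is added to the summary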
param_vars
    Get the trainable parameter variables.
print_logs()
    Print the training logs.

    This method will print the collected metrics. If there's an active step loop, it will print metrics from the step metrics collector. Otherwise, if there's only an epoch loop, it will print metrics from the epoch metrics accumulator.

    Note it must be called at the end of an epoch or a step. This is because the metrics of the corresponding loop context will be cleared after the logs are printed. Moreover, the epoch or step timer will be committed as a metric immediately when this method is called, before printing the logs.
print_training_summary()
    Print the training summary.

    The training summary includes the following content:
        - Execution environment.
        - Parameters to be optimized during training.
println(message, with_tag=False)
    Print a message via print_func.

    Parameters:
        - message (str) – Message to print.
        - with_tag (bool) – Whether or not to add the epoch and step tag. (default False)
step
    Get the global step counter (starting from 1).

summary_metric_prefix
    Get the prefix for the metrics committed to summary_writer.

summary_writer
    Get the summary writer instance.
timeit(metric_name)
    Open a context for timing.

    Parameters: metric_name (str) – Store the timing result in a metric of this name. Note that metric_name must end with "time" or "timer", otherwise the time values will not be formatted as human readable strings by default.
use_early_stopping
    Whether or not to adopt early-stopping.

valid_metric_name
    Get the name of the validation metric.

valid_metric_smaller_is_better
    Whether or not the smaller value is better for the validation metric.

var_groups
    Get the variable groups.
tfsnippet.scaffold.TrainLoopContext
    alias of tfsnippet.scaffold.train_loop_.TrainLoop
tfsnippet.scaffold.train_loop(*args, **kwargs)