tfsnippet.trainer

class tfsnippet.trainer.BaseTrainer(loop)

Bases: object

Base class for all trainers.

All the trainers provided in tfsnippet.trainer are not designed to take full control of the training process, as is often assumed in other libraries such as Keras. Instead, a trainer only takes responsibility for assembling the different steps of a training process and running the main training loop. It is therefore usually the caller’s responsibility to derive the training operation from a TensorFlow optimizer and pass it to an appropriate trainer.
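
A minimal sketch of this division of responsibility, assuming the loss, the input placeholders and the training data flow have already been built as in the Trainer code example later in this module:

optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
train_op = optimizer.minimize(loss)   # the caller derives the training op

with TrainLoop(tf.trainable_variables(), max_epoch=10) as loop:
    trainer = Trainer(loop, train_op, [input_x, input_y], train_data)
    trainer.run()                     # the trainer runs the main loop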

__init__(loop)

Initialize the internal states of BaseTrainer.

Parameters:loop (TrainLoop) – The training loop object.
_iter_steps()

Subclasses should override this to iterate through steps.

A common implementation of _iter_steps() might be:

def _iter_steps(self):
    return self.loop.iter_steps(training_data)
Yields:int or (int, tuple[np.ndarray]) – The step counter, or the step counter together with the step training data. It will be given directly to _run_step() as the payload argument.

_run_step(session, payload)

Subclasses should override this to run a training step.

Parameters:
  • session – The TensorFlow session.
  • payload – The step payload generated by _iter_steps().
after_epochs

Get the hooks run after epochs.

Returns:The hook list.
Return type:HookList
after_steps

Get the hooks run after steps.

Returns:The hook list.
Return type:HookList
anneal_after(value, epochs=None, steps=None)

Add an annealing hook to run after every few epochs or steps.

Parameters:
  • value (AnnealingDynamicValue or () -> any) – An annealing dynamic value (which has .anneal()), or any callable object.
  • epochs (None or int) – Run the annealing hook once every this number of epochs.
  • steps (None or int) – Run the annealing hook once every this number of steps.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.
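
For instance, to halve a dynamic learning rate every 20 epochs (a minimal sketch, reusing the names from the Trainer code example later in this module):

learning_rate_var = AnnealingDynamicValue(0.001, ratio=0.5)
trainer = Trainer(loop, train_op, [input_x, input_y], train_data,
                  feed_dict={learning_rate: learning_rate_var})
trainer.anneal_after(learning_rate_var, epochs=20)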

anneal_after_epochs(value, freq)

Add an annealing hook to run after every few epochs.

Parameters:
  • value (AnnealingDynamicValue or () -> any) – An annealing dynamic value (which has .anneal()), or any callable object.
  • freq (int) – The frequency for this annealing hook to run.
anneal_after_steps(value, freq)

Add an annealing hook to run after every few steps.

Parameters:
  • value (AnnealingDynamicValue or () -> any) – An annealing dynamic value (which has .anneal()), or any callable object.
  • freq (int) – The frequency for this annealing hook to run.
before_epochs

Get the hooks run before epochs.

Returns:The hook list.
Return type:HookList
before_steps

Get the hooks run before steps.

Returns:The hook list.
Return type:HookList
evaluate_after(evaluator, epochs=None, steps=None)

Add an evaluation hook to run after every few epochs or steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • epochs (None or int) – Run the evaluation hook once every this number of epochs.
  • steps (None or int) – Run the evaluation hook once every this number of steps.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.

evaluate_after_epochs(evaluator, freq)

Add an evaluation hook to run after every few epochs.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – The frequency for this evaluation hook to run.
evaluate_after_steps(evaluator, freq)

Add an evaluation hook to run after every few steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – The frequency for this evaluation hook to run.
hook_lists

Get all the hook lists.

Returns:The tuple (self.before_epochs, self.before_steps, self.after_steps, self.after_epochs).
Return type:tuple[HookList]
log_after(epochs=None, steps=None)

Add a logging hook to run after every few epochs or steps.

Parameters:
  • epochs (None or int) – Run the logging hook once every this number of epochs.
  • steps (None or int) – Run the logging hook once every this number of steps.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.

log_after_epochs(freq)

Add a logging hook to run after every few epochs.

Parameters:freq (int) – The frequency for this logging hook to run.
log_after_steps(freq)

Add a logging hook to run after every few steps.

Parameters:freq (int) – The frequency for this logging hook to run.
loop

Get the training loop object.

Returns:The training loop object.
Return type:TrainLoop
remove_annealing_hooks()

Remove annealing hooks from all lists.

Returns:The number of removed hooks.
Return type:int
remove_by_priority(priority)

Remove hooks having the specified priority from all lists.

Parameters:priority – The priority of the hooks to be removed.
Returns:The number of removed hooks.
Return type:int
remove_evaluation_hooks()

Remove evaluation hooks from all lists.

Returns:The number of removed hooks.
Return type:int
remove_log_hooks()

Remove logging hooks from all lists.

Returns:The number of removed hooks.
Return type:int
remove_validation_hooks()

Remove evaluation hooks from all lists.

Returns:The number of removed hooks.
Return type:int
run()

Run training loop.

validate_after(evaluator, epochs=None, steps=None)

Add an evaluation hook to run after every few epochs or steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • epochs (None or int) – Run validation once every this number of epochs.
  • steps (None or int) – Run validation once every this number of steps.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.

validate_after_epochs(evaluator, freq)

Add an evaluation hook to run after every few epochs.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – The frequency for this evaluation hook to run.
validate_after_steps(evaluator, freq)

Add an evaluation hook to run after every few steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – The frequency for this evaluation hook to run.
class tfsnippet.trainer.DynamicValue

Bases: object

Dynamic values fed into trainers.

It is sometimes necessary to feed a dynamic value into a trainer, e.g., an annealing learning rate. This class provides such a base class for all dynamic values.

get()

Get the current value of this DynamicValue object.
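
A custom dynamic value only needs to implement get(). For instance, a value backed by a user-supplied schedule function might look like this (ScheduledValue and schedule_fn are illustrative names, not part of tfsnippet):

class ScheduledValue(DynamicValue):
    """A hypothetical DynamicValue backed by a user-supplied function."""

    def __init__(self, schedule_fn):
        self.schedule_fn = schedule_fn

    def get(self):
        # delegate to the schedule function each time the value is needed
        return self.schedule_fn()

current_lr = ScheduledValue(lambda: 0.001)
current_lr.get()  # -> 0.001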

class tfsnippet.trainer.SimpleDynamicValue(value)

Bases: tfsnippet.trainer.dynamic_values.DynamicValue

A simple DynamicValue, which stores the value in an internal attribute and can be changed via set().

__init__(value)

Construct a new SimpleDynamicValue.

Parameters:value – Any value to be set. It can even be another instance of DynamicValue.
get()

Get the current value of this DynamicValue object.

set(value)

Set the value of this SimpleDynamicValue instance.

Parameters:value – Any value to be set. It can even be another instance of DynamicValue.
class tfsnippet.trainer.AnnealingDynamicValue(initial_value, ratio, min_value=None)

Bases: tfsnippet.trainer.dynamic_values.SimpleDynamicValue

A DynamicValue whose value is annealed (scaled) each time anneal() is called.

__init__(initial_value, ratio, min_value=None)

Construct a new AnnealingDynamicValue.

Parameters:
  • initial_value – A number, the initial value.
  • ratio – A number, the ratio of annealing at each time.
  • min_value – Optional, a number, the minimum value.
anneal()

Anneal the value.
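
A minimal sketch of the expected behavior, assuming the value is multiplied by ratio on each anneal() and clipped at min_value:

v = AnnealingDynamicValue(8.0, ratio=0.5, min_value=2.0)
v.get()     # -> 8.0
v.anneal()
v.get()     # -> 4.0
v.anneal()
v.get()     # -> 2.0
v.anneal()
v.get()     # -> 2.0, not annealed below ``min_value``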

tfsnippet.trainer.auto_batch_weight(*batch_arrays)

Automatically inspect the metric weight for an evaluation mini-batch.

Parameters:*batch_arrays – Mini-batch arrays. The .size of the first array will be used as the weight.
Returns:The inspected weight, or 1. if any error occurs during inspection.
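
A sketch of the intended behavior:

import numpy as np

batch_x = np.zeros([32, 3])
batch_y = np.zeros([32])
auto_batch_weight(batch_x, batch_y)  # -> 96, the ``.size`` of the first array
auto_batch_weight()                  # -> 1., since inspection fails without arrays
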
class tfsnippet.trainer.Evaluator(loop, metrics, inputs, data_flow, feed_dict=None, time_metric_name='eval_time', batch_weight_func=<function auto_batch_weight>)

Bases: object

Class to compute evaluation metrics.

It is a common practice to compute one or more metrics for evaluation and validation during the training process. This class provides a convenient interface for computing metrics by mini-batches.
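
A typical stand-alone usage might look like this (a sketch, reusing the placeholder and data-flow names from the Trainer code example later in this module):

evaluator = Evaluator(loop, {'valid_loss': loss},
                      [input_x, input_y], valid_data)
evaluator.run()                     # compute the metrics over ``valid_data``
print(evaluator.last_metrics_dict)  # e.g. {'valid_loss': ...}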

__init__(loop, metrics, inputs, data_flow, feed_dict=None, time_metric_name='eval_time', batch_weight_func=<function auto_batch_weight>)

Construct a new Evaluator.

Parameters:
  • loop (TrainLoop) – The training loop object.
  • metrics (Tensor or dict[str, Tensor]) –

    The validation loss metric, or a dict of metrics. All the metrics must be 0-d tensors.

    If only a loss is specified, the default validation loss name loop.valid_metric_name will be used as its name. Otherwise if a dict is specified, the keys will be used as the names of each metric.

  • inputs (list[tf.Tensor]) – The input placeholders. The number of tensors, and the order of tensors, should both match the arrays of each mini-batch data, provided by data_flow.
  • data_flow (DataFlow) – The validation data flow.
  • feed_dict (dict[tf.Tensor, any]) – The fixed feed dict for validation. It will be merged with the mini-batch arrays bound to inputs and with the feed_dict argument of run(). (default None)
  • time_metric_name (None or str) – The metric name for collecting evaluation time usage. Specify None to suppress the time usage metric. (default “eval_time”)
  • batch_weight_func ((*arrays) -> float or None) – Specify how to compute the metric weight for each mini-batch. If None, will use 1. as the metric weight. (default auto_batch_weight())
after_run

Get the hooks run after evaluation.

Returns:The hook list.
Return type:HookList
batch_weight_func

Get the function to compute the metric weight for each mini-batch.

before_run

Get the hooks run before evaluation.

Returns:The hook list.
Return type:HookList
data_flow

Get the validation data flow.

Returns:The validation data flow.
Return type:DataFlow
feed_dict

Get the fixed feed dict.

Returns:The fixed feed dict.
Return type:dict[tf.Tensor, any]
inputs

Get the input placeholders.

Returns:The input placeholders.
Return type:list[tf.Tensor]
last_metrics_dict

Get the metric values from last evaluation.

Returns:The metric values dict.
Return type:dict[str, any]
loop

Get the training loop object.

Returns:The training loop object.
Return type:TrainLoop
metrics

Get the metrics to compute.

Returns:The metrics to compute.
Return type:OrderedDict[str, tf.Tensor]
run(feed_dict=None)

Run evaluation.

Parameters:feed_dict – The extra feed dict to be merged with the already configured dict. (default None)
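
The extra feed dict can supplement or override the configured one for a single run, e.g. (is_training is a hypothetical placeholder, not part of this class):

evaluator.run(feed_dict={is_training: False})
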
time_metric_name

Get the metric name for collecting evaluation time usage.

tfsnippet.trainer.resolve_feed_dict(feed_dict, inplace=False)

Resolve all dynamic values in feed_dict into fixed values.

The supported dynamic value types and the corresponding resolving methods are listed as follows:

  1. DynamicValue: get() will be called.
  2. callable object: Will be called to get the value.
Parameters:
  • feed_dict (dict[tf.Tensor, any]) – The feed dict to be resolved.
  • inplace (bool) – Whether to fill the resolved values into the input feed_dict directly, instead of copying them into a new dict. (default False)
Returns:The resolved feed dict.
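
For example (a minimal sketch; learning_rate and beta stand for arbitrary placeholder tensors):

lr_value = AnnealingDynamicValue(0.001, ratio=0.5)
feed_dict = {
    learning_rate: lr_value,   # a DynamicValue: resolved via ``get()``
    beta: lambda: 0.9,         # a callable: resolved by calling it
}
resolved = resolve_feed_dict(feed_dict)
# resolved == {learning_rate: 0.001, beta: 0.9}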

tfsnippet.trainer.merge_feed_dict(*feed_dicts)

Merge all feed dicts into one.

Parameters:*feed_dicts – List of feed dicts. The later ones will override values specified in the previous ones. If None is specified, it will simply be ignored.
Returns:The merged feed dict.
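
A sketch of the merging behavior (x and learning_rate stand for arbitrary placeholder tensors):

merged = merge_feed_dict(
    {x: batch_x, learning_rate: 0.01},
    None,                     # ``None`` entries are ignored
    {learning_rate: 0.001},   # later dicts override earlier ones
)
# merged == {x: batch_x, learning_rate: 0.001}
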
class tfsnippet.trainer.HookPriority

Bases: object

Pre-defined hook priorities for BaseTrainer and Evaluator.

Smaller values take higher priorities.

ANNEALING = 1500
DEFAULT = 1000
EVALUATION = 500
LOGGING = 10000
VALIDATION = 500
class tfsnippet.trainer.HookEntry(callback, freq, priority, birth)

Bases: object

Configurations of a hook entry in HookList.

__init__(callback, freq, priority, birth)

Construct a new HookEntry.

Parameters:
  • callback (() -> any) – The callable object, as the hook callback.
  • freq (int) – The frequency for this callback to be called.
  • priority (int) – The hook priority. Smaller number has higher priority when the hooks are called.
  • birth (int) – The counter of birth, as an additional key for sorting the hook entries, such that old hooks will be placed in front of newly added hooks, if they have the same priority.
maybe_call()

Decrease the counter, and call the callback if the counter drops below 1. The counter will then be reset to freq.

reset_counter()

Reset the counter to freq, its initial value.

sort_key()

Get the key for sorting this hook entry.

class tfsnippet.trainer.HookList

Bases: object

Class for managing hooks in BaseTrainer and Evaluator.

A hook is a registered callback that the trainers will call at certain times during the training process. Apart from the callback method, each hook has a freq and a priority.

  • The freq controls how often the particular hook should be called, e.g., every 2 epochs.
  • The priority determines the order in which the hooks are called. A smaller number corresponds to a higher priority.
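
Hooks are normally registered through the trainer and evaluator methods above; a direct usage sketch, for illustration only:

hooks = HookList()
hooks.add_hook(lambda: print('every call'))              # default freq=1
hooks.add_hook(lambda: print('every 2nd call'), freq=2)
hooks.add_hook(lambda: print('runs first'),
               priority=HookPriority.EVALUATION)         # 500 < DEFAULT (1000)

hooks.call_hooks()   # runs the due hooks, in priority order
hooks.remove_all()   # -> the number of removed hooks
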
__init__()

Construct a new HookList.

add_hook(callback, freq=1, priority=1000)

Add a hook into the list.

Parameters:
  • callback (() -> any) – The callable object, as the hook callback.
  • freq (int) – The frequency for this callback to be called.
  • priority (int) – The hook priority. Smaller number has higher priority when the hooks are called.
call_hooks()

Call all the registered hooks.

If any of the hooks raises an error, it will stop the calling chain and propagate the error to the upper caller.

remove(callback)

Remove all hooks having the specified callback.

Parameters:callback – The callback of the hooks to be removed.
Returns:The number of removed hooks.
Return type:int
remove_all()

Remove all hooks.

Returns:The number of removed hooks.
Return type:int
remove_by_priority(priority)

Remove all hooks having the specified priority.

Parameters:priority (int) – The priority of the hooks to be removed.
Returns:The number of removed hooks.
Return type:int
remove_if(condition)

Remove all hooks matching the specified condition.

Parameters:condition ((callback, freq, priority) -> bool) – A callable object to tell whether or not a hook should be removed.
Returns:The number of removed hooks.
Return type:int
reset()

Reset the frequency counter of all hooks.

class tfsnippet.trainer.LossTrainer(loop, loss, train_op, inputs, data_flow, feed_dict=None, metric_name='loss')

Bases: tfsnippet.trainer.trainer.Trainer

A subclass of BaseTrainer that optimizes a single loss. This class is deprecated; use Trainer instead.

__init__(loop, loss, train_op, inputs, data_flow, feed_dict=None, metric_name='loss')

Construct a new LossTrainer.

Parameters:
  • loop (TrainLoop) – The training loop object.
  • loss (tf.Tensor) – The training loss.
  • train_op (tf.Operation) – The training operation.
  • inputs (list[tf.Tensor]) – The input placeholders. The number of tensors, and the order of tensors, should both match the arrays of each mini-batch data, provided by data_flow.
  • data_flow (DataFlow) – The training data flow. Each mini-batch must contain one array for each placeholder in inputs.
  • feed_dict – The feed dict for training. It will be merged with the arrays provided by data_flow in each step. (default None)
  • metric_name (str) – The metric name for collecting training loss.
loss

Get the training loss.

metric_name

Get the metric name for collecting training loss.

run(feed_dict=None)

Run training loop.

Parameters:feed_dict – DEPRECATED. The extra feed dict to be merged with the already configured dict. (default None)
class tfsnippet.trainer.Trainer(loop, train_op, inputs, data_flow, feed_dict=None, metrics=None)

Bases: tfsnippet.trainer.base_trainer.BaseTrainer

A subclass of BaseTrainer, executing a training operation per step. This might be the most commonly used Trainer. Code example:

from tfsnippet.scaffold import TrainLoop
from tfsnippet.trainer import (Trainer,
                               Evaluator,
                               AnnealingDynamicValue)

# build the model
input_x = tf.placeholder(...)
input_y = tf.placeholder(...)
learning_rate = tf.placeholder(...)  # learning rate annealing

# prepare the training and validation data
train_data = DataFlow.arrays(
    [train_x, train_y], batch_size=128, shuffle=True,
    skip_incomplete=True
)
valid_data = DataFlow.arrays(
    [valid_x, valid_y], batch_size=512)
...

# derive the training operation
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(loss)

# run the trainer
learning_rate_var = AnnealingDynamicValue(0.001, ratio=0.75)

with TrainLoop(param_vars,
               max_epoch=10,
               early_stopping=True) as loop:
    trainer = Trainer(
        loop, train_op, [input_x, input_y], train_data,
        feed_dict={learning_rate: learning_rate_var},
        metrics={'loss': loss}
    )
    evaluator = Evaluator(
        loop, {'loss': loss}, [input_x, input_y], valid_data)

    # validate after every epoch
    trainer.evaluate_after_epochs(evaluator, freq=1)

    # log after every epoch (and after validation, since
    # ``HookPriority.VALIDATION < HookPriority.LOGGING``)
    trainer.log_after_epochs(freq=1)

    # anneal the learning rate after every 10 epochs
    trainer.anneal_after_epochs(learning_rate_var, freq=10)

    # run the main training loop
    trainer.run()
__init__(loop, train_op, inputs, data_flow, feed_dict=None, metrics=None)
Parameters:
  • loop (TrainLoop) – The training loop object.
  • train_op (tf.Operation) – The training operation.
  • inputs (list[tf.Tensor]) – The input placeholders. The number of tensors, and the order of tensors, should both match the arrays of each mini-batch data, provided by data_flow.
  • data_flow (DataFlow) – The training data flow. Each mini-batch must contain one array for each placeholder in inputs.
  • feed_dict – The feed dict for training. It will be merged with the arrays provided by data_flow in each step. (default None)
  • metrics (dict[str, tf.Tensor]) – Metrics to be computed along with train_op. The keys are the names of metrics.
_iter_steps()

Subclasses should override this to iterate through steps.

A common implementation of _iter_steps() might be:

def _iter_steps(self):
    return self.loop.iter_steps(training_data)
Yields:int or (int, tuple[np.ndarray]) – The step counter, or the step counter together with the step training data. It will be given directly to _run_step() as the payload argument.

_run_step(session, payload)

Subclasses should override this to run a training step.

Parameters:
  • session – The TensorFlow session.
  • payload – The step payload generated by _iter_steps().
data_flow

Get the training data flow.

Returns:The training data flow.
Return type:DataFlow
feed_dict

Get the feed dict for training.

Returns:The feed dict for training.
Return type:dict[tf.Tensor, any]
inputs

Get the input placeholders.

Returns:The input placeholders.
Return type:list[tf.Tensor]
metrics

Get the metrics to be computed along with train_op.

train_op

Get the training operation.

class tfsnippet.trainer.Validator(loop, metrics, inputs, data_flow, feed_dict=None, time_metric_name='valid_time', batch_weight_func=<function auto_batch_weight>)

Bases: tfsnippet.trainer.evaluator.Evaluator

Class to compute validation loss and other metrics.

This is a legacy class that inherits Evaluator. Use Evaluator instead when writing new code.