Trainer

class tfsnippet.Trainer(loop, train_op, inputs, data_flow, feed_dict=None, metrics=None, summaries=None, ensure_variables_initialized=True)

Bases: tfsnippet.trainer.base_trainer.BaseTrainer

A subclass of BaseTrainer that executes the training operation once per step. It is probably the most commonly used Trainer. Code example:

import tensorflow as tf
import tfsnippet as spt

# build the model
input_x = tf.placeholder(...)
input_y = tf.placeholder(...)

# prepare the training and validation data
train_data = spt.DataFlow.arrays(
    [train_x, train_y], batch_size=128, shuffle=True,
    skip_incomplete=True
)
valid_data = spt.DataFlow.arrays(
    [valid_x, valid_y], batch_size=512)
...  # the rest of the model, including ``loss`` and ``param_vars``, is omitted

# create the learning rate as an annealing variable *before* building
# the optimizer, so that annealing it actually changes the learning rate
learning_rate = spt.AnnealingVariable('learning_rate', 0.001, 0.75)

# derive the training operation
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(loss)

# run the trainer
with spt.TrainLoop(param_vars,
                   max_epoch=10,
                   early_stopping=True) as loop:
    trainer = spt.Trainer(
        loop, train_op, [input_x, input_y], train_data,
        metrics={'loss': loss}
    )
    )
    evaluator = spt.Evaluator(
        loop, {'loss': loss}, [input_x, input_y], valid_data)

    # validate after every epoch
    trainer.evaluate_after_epochs(evaluator, freq=1)

    # log after every epoch (and after validation, since
    # ``HookPriority.VALIDATION < HookPriority.LOGGING``)
    trainer.log_after_epochs(freq=1)

    # anneal the learning rate after every 10 epochs
    trainer.anneal_after_epochs(learning_rate, freq=10)

    # run the main training loop
    trainer.run()
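
The comments above rely on two scheduling rules: a hook registered with freq=k fires after every k-th epoch, and among hooks that are due at the same time, the one with the smaller priority value runs first. The pure-Python sketch below only illustrates those two rules. It does not use tfsnippet, and `ToyHookList`, the priority numbers and the position of annealing between validation and logging are hypothetical.

```python
from typing import Callable, List, Tuple

# Hypothetical priority values: only the ordering VALIDATION < LOGGING comes
# from the comment in the example; where ANNEALING sits is an assumption.
VALIDATION = 10
ANNEALING = 15
LOGGING = 20


class ToyHookList:
    """A toy list of after-epoch hooks, each with a frequency and a priority."""

    def __init__(self) -> None:
        self._hooks: List[Tuple[int, int, Callable[[], None]]] = []

    def add(self, callback: Callable[[], None], freq: int, priority: int) -> None:
        self._hooks.append((priority, freq, callback))
        # list.sort is stable, so hooks of equal priority keep insertion order.
        self._hooks.sort(key=lambda hook: hook[0])

    def call_after_epoch(self, epoch: int) -> None:
        # Fire each hook whose frequency divides the 1-based epoch index,
        # smallest priority value first.
        for _, freq, callback in self._hooks:
            if epoch % freq == 0:
                callback()


events: List[str] = []
hooks = ToyHookList()
# Registered out of order on purpose; the priorities decide the run order.
hooks.add(lambda: events.append("log"), freq=1, priority=LOGGING)
hooks.add(lambda: events.append("validate"), freq=1, priority=VALIDATION)
hooks.add(lambda: events.append("anneal"), freq=10, priority=ANNEALING)

for epoch in range(1, 21):  # a 20-epoch toy run
    hooks.call_after_epoch(epoch)

print(events[:2])     # ['validate', 'log']
print(events[18:21])  # ['validate', 'anneal', 'log']
```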

See also

tfsnippet.trainer.BaseTrainer

Attributes Summary

data_flow Get the training data flow.
events Get the event source object.
feed_dict Get the feed dict for training.
inputs Get the input placeholders.
loop Get the training loop object.
metrics Get the metrics to be computed along with train_op.
summaries Get the summaries to be computed along with train_op.
train_op Get the training operation.

Methods Summary

anneal_after(value[, epochs, steps]) Add an annealing hook to run after every few epochs or steps.
anneal_after_epochs(value, freq) Add an annealing hook to run after every few epochs.
anneal_after_steps(value, freq) Add an annealing hook to run after every few steps.
evaluate_after(evaluator[, epochs, steps]) Add an evaluation hook to run after every few epochs or steps.
evaluate_after_epochs(evaluator, freq) Add an evaluation hook to run after every few epochs.
evaluate_after_steps(evaluator, freq) Add an evaluation hook to run after every few steps.
log_after([epochs, steps]) Add a logging hook to run after every few epochs or steps.
log_after_epochs(freq) Add a logging hook to run after every few epochs.
log_after_steps(freq) Add a logging hook to run after every few steps.
remove_annealing_hooks() Remove annealing hooks from all lists.
remove_evaluation_hooks() Remove evaluation hooks from all lists.
remove_log_hooks() Remove logging hooks from all lists.
remove_validation_hooks() Remove evaluation hooks from all lists.
run() Run training loop.
validate_after(evaluator[, epochs, steps]) Add an evaluation hook to run after every few epochs or steps.
validate_after_epochs(evaluator, freq) Add an evaluation hook to run after every few epochs.
validate_after_steps(evaluator, freq) Add an evaluation hook to run after every few steps.

Attributes Documentation

data_flow

Get the training data flow.

Returns: The training data flow.
Return type: DataFlow

events

Get the event source object.

Returns: The event source object.
Return type: EventSource

feed_dict

Get the feed dict for training.

Returns: The feed dict for training.
Return type: dict[tf.Tensor, any]

inputs

Get the input placeholders.

Returns: The input placeholders.
Return type: list[tf.Tensor]

loop

Get the training loop object.

Returns: The training loop object.
Return type: TrainLoop

metrics

Get the metrics to be computed along with train_op.

summaries

Get the summaries to be computed along with train_op.

train_op

Get the training operation.

Methods Documentation

anneal_after(value, epochs=None, steps=None)

Add an annealing hook to run after every few epochs or steps.

Parameters:
  • value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
  • epochs (None or int) – The number of epochs between two runs of the annealing hook.
  • steps (None or int) – The number of steps between two runs of the annealing hook.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.
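
The Raises clause means exactly one of epochs and steps must be given. A minimal sketch of that check, assuming the method simply forwards to the matching epoch- or step-based variant; `dispatch_after`, `by_epochs` and `by_steps` are hypothetical helpers, not part of tfsnippet.

```python
from typing import Callable, Optional


def dispatch_after(on_epochs: Callable[[int], str],
                   on_steps: Callable[[int], str],
                   epochs: Optional[int] = None,
                   steps: Optional[int] = None) -> str:
    """Hypothetical helper: forward to the epoch- or step-based variant."""
    # Exactly one of the two frequencies must be given (the documented ValueError).
    if (epochs is None) == (steps is None):
        raise ValueError('Exactly one of `epochs` and `steps` must be specified.')
    if epochs is not None:
        return on_epochs(epochs)
    return on_steps(steps)


def by_epochs(freq: int) -> str:
    return f"after every {freq} epochs"


def by_steps(freq: int) -> str:
    return f"after every {freq} steps"


print(dispatch_after(by_epochs, by_steps, epochs=10))  # after every 10 epochs

# Neither, or both, of the frequencies is rejected.
for bad in ({}, {"epochs": 1, "steps": 100}):
    try:
        dispatch_after(by_epochs, by_steps, **bad)
    except ValueError as exc:
        print("ValueError:", exc)
```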

anneal_after_epochs(value, freq)

Add an annealing hook to run after every few epochs.

Parameters:
  • value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
  • freq (int) – The number of epochs between two runs of this annealing hook.

anneal_after_steps(value, freq)

Add an annealing hook to run after every few steps.

Parameters:
  • value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
  • freq (int) – The number of steps between two runs of this annealing hook.
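
An annealing variable such as spt.AnnealingVariable('learning_rate', 0.001, 0.75) from the class example is annealed by calling its .anneal() method. Assuming, as its arguments suggest, that each call multiplies the current value by the ratio (optionally clipped at a minimum), the schedule produced by annealing every 10 epochs can be traced with the sketch below. `ToyAnnealing` is a hypothetical pure-Python stand-in, not the tfsnippet class.

```python
from typing import List, Optional


class ToyAnnealing:
    """Hypothetical stand-in for an annealing variable exposing .anneal()."""

    def __init__(self, initial_value: float, ratio: float,
                 min_value: Optional[float] = None) -> None:
        self.value = initial_value
        self.ratio = ratio
        self.min_value = min_value

    def anneal(self) -> None:
        # Assumed rule: multiply by the ratio, never dropping below min_value.
        new_value = self.value * self.ratio
        if self.min_value is not None:
            new_value = max(new_value, self.min_value)
        self.value = new_value


lr = ToyAnnealing(0.001, 0.75)
history: List[float] = []
for epoch in range(1, 31):      # a 30-epoch toy run
    # ... one epoch of training would run here ...
    if epoch % 10 == 0:         # an annealing hook with freq=10
        lr.anneal()
    history.append(lr.value)    # value in effect after this epoch

# 0.001 for epochs 1-9, then 0.001 * 0.75 ** k after the k-th annealing
print(history[8], history[9], history[29])
```

Since value may also be any callable, a bound method such as lr.anneal in this sketch would serve equally well as the hook.
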
evaluate_after(evaluator, epochs=None, steps=None)

Add an evaluation hook to run after every few epochs or steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • epochs (None or int) – The number of epochs between two runs of the evaluation hook.
  • steps (None or int) – The number of steps between two runs of the evaluation hook.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.

evaluate_after_epochs(evaluator, freq)

Add an evaluation hook to run after every few epochs.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – The number of epochs between two runs of this evaluation hook.

evaluate_after_steps(evaluator, freq)

Add an evaluation hook to run after every few steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – The number of steps between two runs of this evaluation hook.
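
The evaluator argument may be an object with a .run() method or any other callable. A sketch of how such an argument could be normalized into a plain hook function; `as_hook` and `ToyEvaluator` are hypothetical, and checking for a callable run attribute is an assumption of this sketch (the real class may test the type instead).

```python
from typing import Any, Callable, List


def as_hook(evaluator: Any) -> Callable[[], Any]:
    """Turn an evaluator-like object or a plain callable into a hook."""
    run = getattr(evaluator, "run", None)
    if callable(run):
        return run          # evaluator object: the hook calls its .run()
    if callable(evaluator):
        return evaluator    # any other callable is used as-is
    raise TypeError(f"{evaluator!r} is neither an evaluator nor callable")


class ToyEvaluator:
    """Hypothetical evaluator that records each time it is run."""

    def __init__(self) -> None:
        self.calls: List[int] = []

    def run(self) -> None:
        self.calls.append(len(self.calls) + 1)


ev = ToyEvaluator()
hook = as_hook(ev)
hook()
hook()
print(ev.calls)  # [1, 2]
```
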
log_after(epochs=None, steps=None)

Add a logging hook to run after every few epochs or steps.

Parameters:
  • epochs (None or int) – The number of epochs between two runs of the logging hook.
  • steps (None or int) – The number of steps between two runs of the logging hook.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.

log_after_epochs(freq)

Add a logging hook to run after every few epochs.

Parameters: freq (int) – The number of epochs between two runs of this logging hook.

log_after_steps(freq)

Add a logging hook to run after every few steps.

Parameters: freq (int) – The number of steps between two runs of this logging hook.

remove_annealing_hooks()

Remove annealing hooks from all lists.

Returns: The number of removed hooks.
Return type: int

remove_evaluation_hooks()

Remove evaluation hooks from all lists.

Returns: The number of removed hooks.
Return type: int

remove_log_hooks()

Remove logging hooks from all lists.

Returns: The number of removed hooks.
Return type: int

remove_validation_hooks()

Remove evaluation hooks from all lists.

Returns: The number of removed hooks.
Return type: int
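
The remove_*_hooks() methods take one kind of hook out of every hook list (for instance the lists run after epochs and after steps) and report how many were removed. A sketch of that bookkeeping, assuming each hook carries a tag identifying its kind; the list names, the tags and `remove_kind` are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical hook lists keyed by when they run; each hook is (kind, name).
HookLists = Dict[str, List[Tuple[str, str]]]


def remove_kind(hook_lists: HookLists, kind: str) -> int:
    """Remove every hook of `kind` from all lists; return how many were removed."""
    removed = 0
    for hooks in hook_lists.values():
        kept = [hook for hook in hooks if hook[0] != kind]
        removed += len(hooks) - len(kept)
        hooks[:] = kept  # update each list in place
    return removed


lists: HookLists = {
    "after_epochs": [("evaluation", "validate"), ("logging", "log"),
                     ("annealing", "lr")],
    "after_steps": [("logging", "log_steps"), ("annealing", "beta")],
}
print(remove_kind(lists, "annealing"))  # 2
print(remove_kind(lists, "annealing"))  # 0, nothing left to remove
print(lists["after_steps"])             # [('logging', 'log_steps')]
```
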
run()

Run the training loop.
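
Going by the attribute descriptions, run() trains on data_flow through the inputs placeholders, together with any extra feed_dict entries. Assuming each mini-batch's arrays are fed to inputs in order (the class example passes [input_x, input_y] for batches of [train_x, train_y]), the per-step feed could be assembled as in the framework-free sketch below. Strings stand in for placeholders; `build_step_feed`, the keep_prob entry and the rule that batch data overrides fixed entries are hypothetical.

```python
from typing import Any, Dict, List, Optional, Sequence


def build_step_feed(inputs: Sequence[str], batch: Sequence[Any],
                    feed_dict: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: fixed feed entries plus one mini-batch."""
    if len(inputs) != len(batch):
        raise ValueError(f"{len(inputs)} inputs but {len(batch)} batch arrays")
    merged: Dict[str, Any] = dict(feed_dict or {})
    # Bind the i-th batch array to the i-th input; batch data wins on clashes.
    merged.update(zip(inputs, batch))
    return merged


# Two toy mini-batches of (x, y), standing in for a DataFlow.
batches = [([1, 2], [0, 1]), ([3, 4], [1, 0])]
feeds: List[Dict[str, Any]] = []
for batch_x, batch_y in batches:
    # One training step would run train_op (and metrics) with this feed.
    feeds.append(build_step_feed(["input_x", "input_y"], (batch_x, batch_y),
                                 {"keep_prob": 0.5}))

print(feeds[1])  # {'keep_prob': 0.5, 'input_x': [3, 4], 'input_y': [1, 0]}
```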

validate_after(evaluator, epochs=None, steps=None)

Add an evaluation hook to run after every few epochs or steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • epochs (None or int) – The number of epochs between two runs of the evaluation hook.
  • steps (None or int) – The number of steps between two runs of the evaluation hook.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.

validate_after_epochs(evaluator, freq)

Add an evaluation hook to run after every few epochs.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – The number of epochs between two runs of this evaluation hook.

validate_after_steps(evaluator, freq)

Add an evaluation hook to run after every few steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – The number of steps between two runs of this evaluation hook.