BaseTrainer

class tfsnippet.BaseTrainer(loop, ensure_variables_initialized=True)

Bases: object

Base class for all trainers.

The trainers provided in tfsnippet.trainer are not designed to take full control of training, as is often assumed in other libraries such as Keras. Instead, a trainer is only responsible for assembling the different steps of a training process and running the main training loop. It is therefore usually the caller’s responsibility to derive the training operation from a TensorFlow optimizer and pass it to an appropriate trainer.

The event schedule of a BaseTrainer can be briefly described as:

events.fire(EventKeys.BEFORE_EXECUTION, self)

for epoch in epochs:
    events.fire(EventKeys.BEFORE_EPOCH, self)

    for step in steps:
        events.fire(EventKeys.BEFORE_STEP, self)

        ...  # actually train for a step

        events.fire(EventKeys.STEP_EVALUATION, self)
        events.fire(EventKeys.STEP_ANNEALING, self)
        events.fire(EventKeys.STEP_LOGGING, self)
        events.reverse_fire(EventKeys.AFTER_STEP, self)

    events.fire(EventKeys.EPOCH_EVALUATION, self)
    events.fire(EventKeys.EPOCH_ANNEALING, self)
    events.fire(EventKeys.EPOCH_LOGGING, self)
    events.reverse_fire(EventKeys.AFTER_EPOCH, self)

events.reverse_fire(EventKeys.AFTER_EXECUTION, self)

Calling trainer.events.on(EventKeys.AFTER_EPOCH, lambda trainer: …) registers an after-epoch event handler. Handlers for the other events can be registered in the same way.
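The schedule above can be reproduced with a small stand-in for the event source, which makes the handler ordering concrete. This is a minimal sketch, not tfsnippet's EventSource: `ToyEventSource` and `run_schedule` are hypothetical names, event keys are plain strings instead of EventKeys members, only a subset of the events is fired, and `reverse_fire` is assumed to call handlers in the reverse of their registration order.

```python
from collections import defaultdict


class ToyEventSource(object):
    """Hypothetical stand-in for an event source (not the tfsnippet class)."""

    def __init__(self):
        self._handlers = defaultdict(list)

    def on(self, key, handler):
        # Register `handler` to be called whenever event `key` is fired.
        self._handlers[key].append(handler)

    def fire(self, key, *args):
        # Call handlers in registration order.
        for handler in list(self._handlers[key]):
            handler(*args)

    def reverse_fire(self, key, *args):
        # Assumption: call handlers in reverse registration order.
        for handler in reversed(list(self._handlers[key])):
            handler(*args)


def run_schedule(events, trainer, max_epoch, steps_per_epoch):
    """Fire a subset of the events in the order given by the schedule above."""
    events.fire('before_execution', trainer)
    for _ in range(max_epoch):
        events.fire('before_epoch', trainer)
        for _ in range(steps_per_epoch):
            events.fire('before_step', trainer)
            # ... a real trainer would run one training step here
            events.reverse_fire('after_step', trainer)
        events.reverse_fire('after_epoch', trainer)
    events.reverse_fire('after_execution', trainer)


calls = []
events = ToyEventSource()
events.on('before_epoch', lambda trainer: calls.append('before_epoch'))
events.on('after_epoch', lambda trainer: calls.append('after_epoch: first registered'))
events.on('after_epoch', lambda trainer: calls.append('after_epoch: second registered'))
run_schedule(events, trainer=None, max_epoch=1, steps_per_epoch=2)
print(calls)
# ['before_epoch', 'after_epoch: second registered', 'after_epoch: first registered']
```

Under that assumption, the handler registered last for an AFTER_* event runs first; one plausible motivation is that paired before/after handlers then unwind in nested order, much like stacked context managers.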

To make things even simpler, we provide several methods to register callbacks that will run every few epochs/steps, e.g.:

trainer.evaluate_after_epochs(
    lambda: print('after epoch callback'), 10)  # run every 10 epochs
trainer.log_after_steps(1000)  # call `loop.print_logs` every 1000 steps
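The every-few-epochs/steps callbacks can be pictured as a list of (callback, frequency) pairs checked against a counter. The sketch below assumes, without confirmation from this page, that counters start at 1 and that a hook with frequency freq runs whenever the counter is a multiple of freq; `ToyHookList` is a hypothetical name, not part of tfsnippet.

```python
class ToyHookList(object):
    """Hypothetical list of (callback, freq) pairs; not a tfsnippet class."""

    def __init__(self):
        self._hooks = []

    def add_hook(self, callback, freq):
        if freq < 1:
            raise ValueError('`freq` must be >= 1, got {!r}'.format(freq))
        self._hooks.append((callback, freq))

    def call_hooks(self, counter):
        # Assumption: `counter` is 1-based, so freq=10 fires at 10, 20, 30, ...
        for callback, freq in self._hooks:
            if counter % freq == 0:
                callback()


after_epoch_hooks = ToyHookList()
evaluated_at = []
# Zero-argument callback, like the `lambda: print(...)` example above; it reads
# the loop variable `epoch` at call time.
after_epoch_hooks.add_hook(lambda: evaluated_at.append(epoch), freq=10)

for epoch in range(1, 31):
    # ... train for one epoch, then run the after-epoch hooks
    after_epoch_hooks.call_hooks(epoch)

print(evaluated_at)  # [10, 20, 30]
```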

Attributes Summary

events  Get the event source object.
loop    Get the training loop object.

Methods Summary

anneal_after(value[, epochs, steps])        Add an annealing hook to run after every few epochs or steps.
anneal_after_epochs(value, freq)            Add an annealing hook to run after every few epochs.
anneal_after_steps(value, freq)             Add an annealing hook to run after every few steps.
evaluate_after(evaluator[, epochs, steps])  Add an evaluation hook to run after every few epochs or steps.
evaluate_after_epochs(evaluator, freq)      Add an evaluation hook to run after every few epochs.
evaluate_after_steps(evaluator, freq)       Add an evaluation hook to run after every few steps.
log_after([epochs, steps])                  Add a logging hook to run after every few epochs or steps.
log_after_epochs(freq)                      Add a logging hook to run after every few epochs.
log_after_steps(freq)                       Add a logging hook to run after every few steps.
remove_annealing_hooks()                    Remove annealing hooks from all lists.
remove_evaluation_hooks()                   Remove evaluation hooks from all lists.
remove_log_hooks()                          Remove logging hooks from all lists.
remove_validation_hooks()                   Remove evaluation hooks from all lists.
run()                                       Run the training loop.
validate_after(evaluator[, epochs, steps])  Add an evaluation hook to run after every few epochs or steps.
validate_after_epochs(evaluator, freq)      Add an evaluation hook to run after every few epochs.
validate_after_steps(evaluator, freq)       Add an evaluation hook to run after every few steps.

Attributes Documentation

events

Get the event source object.

Returns: The event source object.
Return type: EventSource

loop

Get the training loop object.

Returns: The training loop object.
Return type: TrainLoop

Methods Documentation

anneal_after(value, epochs=None, steps=None)

Add an annealing hook to run after every few epochs or steps.

Parameters:
  • value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
  • epochs (None or int) – Interval, in epochs, at which to run the annealing hook.
  • steps (None or int) – Interval, in steps, at which to run the annealing hook.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.
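The rules for value, epochs and steps can be illustrated in plain Python. The sketch below is hypothetical: `ToyAnnealingVariable` only mimics an object with an .anneal() method and assumes multiplicative decay with a lower bound (the real AnnealingVariable may differ), and `resolve_frequency` mirrors the documented ValueError rule rather than reproducing tfsnippet's code.

```python
class ToyAnnealingVariable(object):
    """Hypothetical pure-Python analogue of an annealing variable.

    Assumes multiplicative decay: each call to `anneal()` multiplies the
    value by `ratio`, never letting it drop below `min_value`.
    """

    def __init__(self, initial_value, ratio, min_value=None):
        self.value = initial_value
        self.ratio = ratio
        self.min_value = min_value

    def anneal(self):
        self.value *= self.ratio
        if self.min_value is not None:
            self.value = max(self.value, self.min_value)


def resolve_frequency(epochs=None, steps=None):
    """Hypothetical helper enforcing the documented epochs/steps rule."""
    if (epochs is None) == (steps is None):  # both given, or neither given
        raise ValueError('Exactly one of `epochs` and `steps` must be '
                         'specified.')
    if epochs is not None:
        return 'epochs', epochs
    return 'steps', steps


lr = ToyAnnealingVariable(initial_value=1.0, ratio=0.5, min_value=0.2)
unit, freq = resolve_frequency(epochs=2)
for epoch in range(1, 7):
    if epoch % freq == 0:  # anneal after epochs 2, 4 and 6
        lr.anneal()
print(unit, lr.value)  # epochs 0.2  (1.0 -> 0.5 -> 0.25 -> 0.125, clipped to 0.2)
```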

anneal_after_epochs(value, freq)

Add an annealing hook to run after every few epochs.

Parameters:
  • value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
  • freq (int) – Interval, in epochs, at which to run this annealing hook.

anneal_after_steps(value, freq)

Add an annealing hook to run after every few steps.

Parameters:
  • value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
  • freq (int) – Interval, in steps, at which to run this annealing hook.

evaluate_after(evaluator, epochs=None, steps=None)

Add an evaluation hook to run after every few epochs or steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • epochs (None or int) – Interval, in epochs, at which to run the evaluation hook.
  • steps (None or int) – Interval, in steps, at which to run the evaluation hook.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.
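Since evaluator may be either an object with .run() or a plain callable, a trainer has to turn both into a zero-argument hook. A minimal sketch of that normalization, using the hypothetical names `as_hook` and `ToyEvaluator` (neither is part of tfsnippet):

```python
def as_hook(evaluator):
    """Hypothetical helper: turn an evaluator-like object into a zero-arg hook.

    Objects exposing `.run()` are wrapped via their bound `run` method;
    any other callable is used as-is.
    """
    if hasattr(evaluator, 'run') and callable(evaluator.run):
        return evaluator.run
    if callable(evaluator):
        return evaluator
    raise TypeError('{!r} is neither an evaluator nor callable.'.format(evaluator))


class ToyEvaluator(object):
    """Hypothetical evaluator with a `.run()` method (not the tfsnippet class)."""

    def __init__(self):
        self.num_runs = 0

    def run(self):
        self.num_runs += 1


evaluator = ToyEvaluator()
log = []
hooks = [as_hook(evaluator), as_hook(lambda: log.append('plain callable'))]
for hook in hooks:
    hook()
print(evaluator.num_runs, log)  # 1 ['plain callable']
```

Checking for .run() before falling back to callable(...) means an object that is both callable and has a run method is driven through run(); whether tfsnippet makes the same choice is an assumption.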

evaluate_after_epochs(evaluator, freq)

Add an evaluation hook to run after every few epochs.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – Interval, in epochs, at which to run this evaluation hook.

evaluate_after_steps(evaluator, freq)

Add an evaluation hook to run after every few steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – Interval, in steps, at which to run this evaluation hook.

log_after(epochs=None, steps=None)

Add a logging hook to run after every few epochs or steps.

Parameters:
  • epochs (None or int) – Interval, in epochs, at which to run the logging hook.
  • steps (None or int) – Interval, in steps, at which to run the logging hook.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.

log_after_epochs(freq)

Add a logging hook to run after every few epochs.

Parameters: freq (int) – Interval, in epochs, at which to run this logging hook.

log_after_steps(freq)

Add a logging hook to run after every few steps.

Parameters: freq (int) – Interval, in steps, at which to run this logging hook.

remove_annealing_hooks()

Remove annealing hooks from all lists.

Returns: The number of removed hooks.
Return type: int

remove_evaluation_hooks()

Remove evaluation hooks from all lists.

Returns: The number of removed hooks.
Return type: int

remove_log_hooks()

Remove logging hooks from all lists.

Returns: The number of removed hooks.
Return type: int

remove_validation_hooks()

Remove evaluation hooks from all lists.

Returns: The number of removed hooks.
Return type: int

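Each remove_*_hooks method removes hooks of one kind "from all lists" and returns how many were removed. A sketch of that bookkeeping, assuming (not confirmed by this page) that the trainer keeps one hook list per event and tags every entry with its kind; `ToyHookRegistry` is a hypothetical name:

```python
class ToyHookRegistry(object):
    """Hypothetical hook storage: one list per event, entries tagged by kind."""

    def __init__(self):
        # Assumption: separate lists for after-epoch and after-step hooks.
        self.lists = {'after_epoch': [], 'after_step': []}

    def add(self, event, kind, callback, freq):
        self.lists[event].append((kind, callback, freq))

    def remove_by_kind(self, kind):
        """Remove every hook of `kind` from all lists and return the count."""
        removed = 0
        for hooks in self.lists.values():
            kept = [hook for hook in hooks if hook[0] != kind]
            removed += len(hooks) - len(kept)
            hooks[:] = kept  # update the list in place
        return removed


def noop():
    pass


registry = ToyHookRegistry()
registry.add('after_epoch', 'evaluation', noop, freq=1)
registry.add('after_step', 'evaluation', noop, freq=100)
registry.add('after_step', 'logging', noop, freq=100)

removed_first = registry.remove_by_kind('evaluation')
removed_again = registry.remove_by_kind('evaluation')
print(removed_first, removed_again, len(registry.lists['after_step']))  # 2 0 1
```

Returning the count lets a caller tell whether any hooks of that kind had been registered at all, since a second call returns 0.
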
run()

Run the training loop.

validate_after(evaluator, epochs=None, steps=None)

Add an evaluation hook to run after every few epochs or steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • epochs (None or int) – Interval, in epochs, at which to run the evaluation hook.
  • steps (None or int) – Interval, in steps, at which to run the evaluation hook.
Raises:

ValueError – If both epochs and steps are specified, or neither is specified.

validate_after_epochs(evaluator, freq)

Add an evaluation hook to run after every few epochs.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – Interval, in epochs, at which to run this evaluation hook.

validate_after_steps(evaluator, freq)

Add an evaluation hook to run after every few steps.

Parameters:
  • evaluator (Evaluator or () -> any) – An evaluator object (which has .run()), or any callable object.
  • freq (int) – Interval, in steps, at which to run this evaluation hook.