Trainer
-
class tfsnippet.Trainer(loop, train_op, inputs, data_flow, feed_dict=None, metrics=None, summaries=None, ensure_variables_initialized=True)

Bases: tfsnippet.trainer.base_trainer.BaseTrainer

A subclass of BaseTrainer that executes a training operation at every step. This is probably the most commonly used Trainer.

Code example:

    import tensorflow as tf
    import tfsnippet as spt

    # build the model
    input_x = tf.placeholder(...)
    input_y = tf.placeholder(...)

    # prepare the training and validation data
    train_data = spt.DataFlow.arrays(
        [train_x, train_y], batch_size=128, shuffle=True,
        skip_incomplete=True
    )
    valid_data = spt.DataFlow.arrays(
        [valid_x, valid_y], batch_size=512)
    ...

    # the learning rate, annealed by the trainer below
    # (it must be created before the optimizer that uses it)
    learning_rate = spt.AnnealingVariable('learning_rate', 0.001, 0.75)

    # derive the training operation
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.minimize(loss)

    # run the trainer
    with spt.TrainLoop(param_vars,
                       max_epoch=10,
                       early_stopping=True) as loop:
        trainer = spt.Trainer(
            loop, train_op, [input_x, input_y], train_data,
            metrics={'loss': loss}
        )
        evaluator = spt.Evaluator(
            loop, {'loss': loss}, [input_x, input_y], valid_data)

        # validate after every epoch
        trainer.evaluate_after_epochs(evaluator, freq=1)

        # log after every epoch (and after validation, since
        # ``HookPriority.VALIDATION < HookPriority.LOGGING``)
        trainer.log_after_epochs(freq=1)

        # anneal the learning rate after every 10 epochs
        trainer.anneal_after_epochs(learning_rate, freq=10)

        # run the main training loop
        trainer.run()
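The comment about HookPriority in the example relies on hooks that are due at the same moment being run in ascending priority order. A minimal plain-Python sketch of that ordering rule follows. ToyHookList and the numeric priorities VALIDATION = 500 and LOGGING = 1000 are hypothetical (only their relative order mirrors the comment), and tfsnippet's real hook lists may be organized differently.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple

# Hypothetical priorities: only their relative order matters here
# (lower runs first), mirroring HookPriority.VALIDATION < HookPriority.LOGGING.
VALIDATION = 500
LOGGING = 1000

# (priority, registration index, freq, callback)
HookEntry = Tuple[int, int, int, Callable[[], None]]


@dataclass
class ToyHookList:
    """Toy list of "run after every `freq` epochs" hooks, kept sorted by priority."""
    hooks: List[HookEntry] = field(default_factory=list)

    def add(self, callback: Callable[[], None], freq: int, priority: int) -> None:
        # The registration index breaks ties between equal priorities.
        self.hooks.append((priority, len(self.hooks), freq, callback))
        self.hooks.sort(key=lambda entry: (entry[0], entry[1]))

    def call_hooks(self, epoch: int) -> None:
        # Run every hook that is due after this epoch, in priority order.
        for _, _, freq, callback in self.hooks:
            if epoch % freq == 0:
                callback()


calls: List[str] = []
after_epochs = ToyHookList()
# Register logging first and validation second: priority still decides the order.
after_epochs.add(lambda: calls.append('log'), freq=1, priority=LOGGING)
after_epochs.add(lambda: calls.append('validate'), freq=1, priority=VALIDATION)
for epoch in range(1, 3):
    after_epochs.call_hooks(epoch)
print(calls)  # ['validate', 'log', 'validate', 'log']
```

Registering the logging hook first does not change the outcome, because the list is sorted by priority and registration order is used only to break ties.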
See also
tfsnippet.trainer.BaseTrainer

Attributes Summary
data_flow: Get the training data flow.
events: Get the event source object.
feed_dict: Get the feed dict for training.
inputs: Get the input placeholders.
loop: Get the training loop object.
metrics: Get the metrics to be computed along with train_op.
summaries: Get the summaries to be computed along with train_op.
train_op: Get the training operation.

Methods Summary
anneal_after(value[, epochs, steps]): Add an annealing hook to run after every few epochs or steps.
anneal_after_epochs(value, freq): Add an annealing hook to run after every few epochs.
anneal_after_steps(value, freq): Add an annealing hook to run after every few steps.
evaluate_after(evaluator[, epochs, steps]): Add an evaluation hook to run after every few epochs or steps.
evaluate_after_epochs(evaluator, freq): Add an evaluation hook to run after every few epochs.
evaluate_after_steps(evaluator, freq): Add an evaluation hook to run after every few steps.
log_after([epochs, steps]): Add a logging hook to run after every few epochs or steps.
log_after_epochs(freq): Add a logging hook to run after every few epochs.
log_after_steps(freq): Add a logging hook to run after every few steps.
remove_annealing_hooks(): Remove annealing hooks from all lists.
remove_evaluation_hooks(): Remove evaluation hooks from all lists.
remove_log_hooks(): Remove logging hooks from all lists.
remove_validation_hooks(): Remove evaluation hooks from all lists.
run(): Run the training loop.
validate_after(evaluator[, epochs, steps]): Add an evaluation hook to run after every few epochs or steps.
validate_after_epochs(evaluator, freq): Add an evaluation hook to run after every few epochs.
validate_after_steps(evaluator, freq): Add an evaluation hook to run after every few steps.

Attributes Documentation
-
events

Get the event source object.

Returns: The event source object.
Return type: EventSource
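The events attribute exposes an EventSource, which suggests an observer-style design: callbacks are registered under named events and invoked when the trainer fires those events. The following is a generic plain-Python sketch of that pattern, not tfsnippet's EventSource; the class ToyEventSource, its on() and fire() methods, and the 'after_epoch' event key are all hypothetical.

```python
from collections import defaultdict
from typing import Any, Callable, DefaultDict, List


class ToyEventSource:
    """Minimal observer pattern: handlers registered per event key."""

    def __init__(self) -> None:
        self._handlers: DefaultDict[str, List[Callable[..., Any]]] = defaultdict(list)

    def on(self, event_key: str, handler: Callable[..., Any]) -> None:
        # Register `handler` to be called whenever `event_key` is fired.
        self._handlers[event_key].append(handler)

    def fire(self, event_key: str, *args: Any, **kwargs: Any) -> None:
        # Call handlers in registration order; iterate over a copy so a
        # handler may safely register further handlers.
        for handler in list(self._handlers[event_key]):
            handler(*args, **kwargs)


events = ToyEventSource()
seen: List[int] = []
events.on('after_epoch', lambda epoch: seen.append(epoch))
for epoch in range(1, 4):
    events.fire('after_epoch', epoch)
print(seen)  # [1, 2, 3]
```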
-
feed_dict

Get the feed dict for training.

Returns: The feed dict for training.
Return type: dict[tf.Tensor, any]
-
metrics

Get the metrics to be computed along with train_op.
-
summaries

Get the summaries to be computed along with train_op.
-
train_op

Get the training operation.
Methods Documentation
-
anneal_after(value, epochs=None, steps=None)

Add an annealing hook to run after every few epochs or steps.

Parameters:
- value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
- epochs (None or int) – The number of epochs between two runs of the annealing hook.
- steps (None or int) – The number of steps between two runs of the annealing hook.
Raises:
ValueError – If both epochs and steps are specified, or neither is specified.
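anneal_after(), evaluate_after(), validate_after() and log_after() all document the same rule: exactly one of epochs and steps must be given. The check can be sketched as below. resolve_hook_schedule is a hypothetical helper and its error message is illustrative, not tfsnippet's; presumably the real methods go on to delegate to their *_after_epochs or *_after_steps counterparts, but that is an assumption.

```python
from typing import Optional, Tuple


def resolve_hook_schedule(epochs: Optional[int] = None,
                          steps: Optional[int] = None) -> Tuple[str, int]:
    """Hypothetical check: exactly one of `epochs` or `steps` must be given."""
    # True == True (neither given) or False == False (both given) is an error.
    if (epochs is None) == (steps is None):
        raise ValueError('Exactly one of `epochs` and `steps` must be specified.')
    if epochs is not None:
        return ('epochs', epochs)
    return ('steps', steps)


print(resolve_hook_schedule(epochs=10))  # ('epochs', 10)
print(resolve_hook_schedule(steps=500))  # ('steps', 500)
try:
    resolve_hook_schedule(epochs=1, steps=1)
except ValueError as exc:
    print('ValueError:', exc)
```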
-
anneal_after_epochs(value, freq)

Add an annealing hook to run after every few epochs.

Parameters:
- value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
- freq (int) – The frequency, in epochs, at which this annealing hook runs.
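The schedule produced by anneal_after_epochs(value, freq) can be sketched in plain Python, assuming that each .anneal() call multiplies the variable by the ratio passed to AnnealingVariable (0.75 in the class example), optionally clipped at a lower bound. ToyAnnealingVariable, its min_value argument and lr_per_epoch are hypothetical stand-ins, not tfsnippet code.

```python
from typing import List


class ToyAnnealingVariable:
    """Plain-Python stand-in for an annealing variable.

    Assumed semantics: each anneal() call sets
    value = max(value * ratio, min_value).
    """

    def __init__(self, initial_value: float, ratio: float,
                 min_value: float = 0.0) -> None:
        self.value = initial_value
        self.ratio = ratio
        self.min_value = min_value

    def anneal(self) -> None:
        self.value = max(self.value * self.ratio, self.min_value)


def lr_per_epoch(var: ToyAnnealingVariable, max_epoch: int,
                 freq: int) -> List[float]:
    """Value in effect during each epoch when annealing after every `freq` epochs."""
    schedule = []
    for epoch in range(1, max_epoch + 1):
        schedule.append(var.value)  # value used while training this epoch
        if epoch % freq == 0:       # the "after every `freq` epochs" hook fires
            var.anneal()
    return schedule


# Same numbers as the class example, but a longer (hypothetical) 30-epoch run.
schedule = lr_per_epoch(ToyAnnealingVariable(0.001, 0.75), max_epoch=30, freq=10)
print(schedule[0], schedule[10], schedule[20])
```

Under these assumptions, the learning rate during epoch k (counting from 1) is 0.001 * 0.75 ** ((k - 1) // 10). With max_epoch=10 as in the class example, the only anneal fires after the final epoch, so its effect would only be visible in a longer run.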
-
anneal_after_steps(value, freq)

Add an annealing hook to run after every few steps.

Parameters:
- value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
- freq (int) – The frequency, in steps, at which this annealing hook runs.
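Because value may be any callable object, not only an annealing variable, a step-based hook can drive custom decay logic. The sketch below simulates that in plain Python, assuming step counting starts at 1 and the hook fires whenever the step count is a multiple of freq. make_step_decay and simulate_steps are hypothetical helpers, not part of tfsnippet.

```python
from typing import Callable, Dict, List, Tuple


def make_step_decay(state: Dict[str, float], factor: float) -> Callable[[], None]:
    """Return a zero-argument callable, the kind of object `value` may be."""
    def decay() -> None:
        state['lr'] *= factor
    return decay


def simulate_steps(hook: Callable[[], None], freq: int, total_steps: int,
                   state: Dict[str, float]) -> List[Tuple[int, float]]:
    """Record (step, lr) right after each time the hook fires."""
    fired = []
    for step in range(1, total_steps + 1):
        # ... one training step would run here ...
        if step % freq == 0:
            hook()
            fired.append((step, state['lr']))
    return fired


state = {'lr': 0.1}
history = simulate_steps(make_step_decay(state, 0.5), freq=100,
                         total_steps=350, state=state)
print(history)  # [(100, 0.05), (200, 0.025), (300, 0.0125)]
```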
-
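evaluate_after(evaluator, epochs=None, steps=None)

Add an evaluation hook to run after every few epochs or steps.

Parameters:
- evaluator – The evaluator to run, e.g. an Evaluator as in the class example above.
- epochs (None or int) – The number of epochs between two runs of the evaluation hook.
- steps (None or int) – The number of steps between two runs of the evaluation hook.
Raises:
ValueError – If both epochs and steps are specified, or neither is specified.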
-
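evaluate_after_epochs(evaluator, freq)

Add an evaluation hook to run after every few epochs.

Parameters:
- evaluator – The evaluator to run, e.g. an Evaluator as in the class example above.
- freq (int) – The frequency, in epochs, at which this evaluation hook runs.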
-
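evaluate_after_steps(evaluator, freq)

Add an evaluation hook to run after every few steps.

Parameters:
- evaluator – The evaluator to run, e.g. an Evaluator as in the class example above.
- freq (int) – The frequency, in steps, at which this evaluation hook runs.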
-
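log_after(epochs=None, steps=None)

Add a logging hook to run after every few epochs or steps.

Parameters:
- epochs (None or int) – The number of epochs between two runs of the logging hook.
- steps (None or int) – The number of steps between two runs of the logging hook.
Raises:
ValueError – If both epochs and steps are specified, or neither is specified.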
-
log_after_epochs(freq)

Add a logging hook to run after every few epochs.

Parameters:
- freq (int) – The frequency, in epochs, at which this logging hook runs.
-
log_after_steps(freq)

Add a logging hook to run after every few steps.

Parameters:
- freq (int) – The frequency, in steps, at which this logging hook runs.
-
remove_annealing_hooks()

Remove annealing hooks from all lists.

Returns: The number of removed hooks.
Return type: int
-
remove_evaluation_hooks()

Remove evaluation hooks from all lists.

Returns: The number of removed hooks.
Return type: int
-
remove_log_hooks()

Remove logging hooks from all lists.

Returns: The number of removed hooks.
Return type: int
-
remove_validation_hooks()

Remove evaluation hooks from all lists.

Returns: The number of removed hooks.
Return type: int
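Each remove_* method scans all hook lists and reports how many hooks it removed. A plain-Python sketch of that contract follows, assuming one list for epoch-based hooks and one for step-based hooks, with each hook tagged by its kind. ToyHookRegistry, remove_by_kind and the tags 'anneal', 'evaluate' and 'log' are hypothetical.

```python
from typing import Callable, List, Tuple

Hook = Tuple[str, Callable[[], None]]  # (kind tag, callback)


class ToyHookRegistry:
    """Two hook lists whose entries are tagged with a hypothetical kind."""

    def __init__(self) -> None:
        self.after_epochs: List[Hook] = []
        self.after_steps: List[Hook] = []

    def remove_by_kind(self, kind: str) -> int:
        """Remove every hook of `kind` from all lists; return how many were removed."""
        removed = 0
        for hook_list in (self.after_epochs, self.after_steps):
            kept = [hook for hook in hook_list if hook[0] != kind]
            removed += len(hook_list) - len(kept)
            hook_list[:] = kept  # modify in place so other references stay valid
        return removed


def noop() -> None:
    pass


registry = ToyHookRegistry()
registry.after_epochs += [('evaluate', noop), ('log', noop), ('anneal', noop)]
registry.after_steps += [('log', noop)]
first = registry.remove_by_kind('log')   # one per list
second = registry.remove_by_kind('log')  # nothing left to remove
print(first, second)  # 2 0
```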
-
run()

Run the training loop.
-
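validate_after(evaluator, epochs=None, steps=None)

Add an evaluation hook to run after every few epochs or steps.

Parameters:
- evaluator – The evaluator to run, e.g. an Evaluator as in the class example above.
- epochs (None or int) – The number of epochs between two runs of the evaluation hook.
- steps (None or int) – The number of steps between two runs of the evaluation hook.
Raises:
ValueError – If both epochs and steps are specified, or neither is specified.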
-
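validate_after_epochs(evaluator, freq)

Add an evaluation hook to run after every few epochs.

Parameters:
- evaluator – The evaluator to run, e.g. an Evaluator as in the class example above.
- freq (int) – The frequency, in epochs, at which this evaluation hook runs.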
-