Trainer
class tfsnippet.Trainer(loop, train_op, inputs, data_flow, feed_dict=None, metrics=None, summaries=None, ensure_variables_initialized=True)

Bases: tfsnippet.trainer.base_trainer.BaseTrainer
A subclass of BaseTrainer that executes one training operation per step. This is probably the most commonly used Trainer. Code example:

    import tensorflow as tf
    import tfsnippet as spt

    # build the model
    input_x = tf.placeholder(...)
    input_y = tf.placeholder(...)
    # the annealing learning rate (initial value 0.001, ratio 0.75);
    # created before the optimizer so that the optimizer actually uses it
    learning_rate = spt.AnnealingVariable('learning_rate', 0.001, 0.75)

    # prepare the training and validation data
    train_data = spt.DataFlow.arrays(
        [train_x, train_y], batch_size=128, shuffle=True,
        skip_incomplete=True
    )
    valid_data = spt.DataFlow.arrays(
        [valid_x, valid_y], batch_size=512)
    ...

    # derive the training operation
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.minimize(loss)

    # run the trainer
    with spt.TrainLoop(param_vars, max_epoch=10, early_stopping=True) as loop:
        trainer = spt.Trainer(
            loop, train_op, [input_x, input_y], train_data,
            metrics={'loss': loss}
        )
        evaluator = spt.Evaluator(
            loop, {'loss': loss}, [input_x, input_y], valid_data)

        # validate after every epoch
        trainer.evaluate_after_epochs(evaluator, freq=1)

        # log after every epoch (and after validation, since
        # ``HookPriority.VALIDATION < HookPriority.LOGGING``)
        trainer.log_after_epochs(freq=1)

        # anneal the learning rate after every 10 epochs
        trainer.anneal_after_epochs(learning_rate, freq=10)

        # run the main training loop
        trainer.run()
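The comment about HookPriority in the example relies on hooks in the same list running in priority order, so validation runs before logging even when both fire after the same epoch. The following is a minimal pure-Python sketch of that idea, not tfsnippet's implementation: ToyHookPriority and ToyHookList are hypothetical names, the priority values are made up, and it assumes that a smaller priority value runs earlier, which is what HookPriority.VALIDATION < HookPriority.LOGGING implies in the example.

```python
from enum import IntEnum
from typing import Callable, List, Tuple


class ToyHookPriority(IntEnum):
    # hypothetical values; only their relative order matters here
    VALIDATION = 1
    LOGGING = 2


class ToyHookList:
    """Hypothetical hook list that calls hooks in ascending priority order."""

    def __init__(self) -> None:
        self._hooks: List[Tuple[int, int, Callable[[], None]]] = []

    def add_hook(self, callback: Callable[[], None], priority: int) -> None:
        # the insertion index breaks ties, so equal priorities keep their order
        self._hooks.append((priority, len(self._hooks), callback))
        self._hooks.sort(key=lambda item: (item[0], item[1]))

    def call_hooks(self) -> None:
        for _, _, callback in self._hooks:
            callback()


calls: List[str] = []
after_epochs = ToyHookList()
# registered in the "wrong" order on purpose
after_epochs.add_hook(lambda: calls.append("log"), ToyHookPriority.LOGGING)
after_epochs.add_hook(lambda: calls.append("validate"), ToyHookPriority.VALIDATION)
after_epochs.call_hooks()
print(calls)  # ['validate', 'log']
```

Sorting on insertion as well as priority keeps the order deterministic when two hooks share a priority.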
See also: tfsnippet.trainer.BaseTrainer
Attributes Summary

data_flow   Get the training data flow.
events      Get the event source object.
feed_dict   Get the feed dict for training.
inputs      Get the input placeholders.
loop        Get the training loop object.
metrics     Get the metrics to be computed along with train_op.
summaries   Get the summaries to be computed along with train_op.
train_op    Get the training operation.

Methods Summary
anneal_after(value[, epochs, steps])        Add an annealing hook to run after every few epochs or steps.
anneal_after_epochs(value, freq)            Add an annealing hook to run after every few epochs.
anneal_after_steps(value, freq)             Add an annealing hook to run after every few steps.
evaluate_after(evaluator[, epochs, steps])  Add an evaluation hook to run after every few epochs or steps.
evaluate_after_epochs(evaluator, freq)      Add an evaluation hook to run after every few epochs.
evaluate_after_steps(evaluator, freq)       Add an evaluation hook to run after every few steps.
log_after([epochs, steps])                  Add a logging hook to run after every few epochs or steps.
log_after_epochs(freq)                      Add a logging hook to run after every few epochs.
log_after_steps(freq)                       Add a logging hook to run after every few steps.
remove_annealing_hooks()                    Remove annealing hooks from all lists.
remove_evaluation_hooks()                   Remove evaluation hooks from all lists.
remove_log_hooks()                          Remove logging hooks from all lists.
remove_validation_hooks()                   Remove evaluation hooks from all lists.
run()                                       Run the training loop.
validate_after(evaluator[, epochs, steps])  Add an evaluation hook to run after every few epochs or steps.
validate_after_epochs(evaluator, freq)      Add an evaluation hook to run after every few epochs.
validate_after_steps(evaluator, freq)       Add an evaluation hook to run after every few steps.

Attributes Documentation
events
    Get the event source object.
    Returns: The event source object.
    Return type: EventSource
feed_dict
    Get the feed dict for training.
    Returns: The feed dict for training.
    Return type: dict[tf.Tensor, any]
metrics
    Get the metrics to be computed along with train_op.
summaries
    Get the summaries to be computed along with train_op.
train_op
    Get the training operation.
Methods Documentation
anneal_after(value, epochs=None, steps=None)
    Add an annealing hook to run after every few epochs or steps.
    Parameters:
        - value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
        - epochs (None or int) – The number of epochs between two runs of the annealing hook.
        - steps (None or int) – The number of steps between two runs of the annealing hook.
    Raises: ValueError – If both epochs and steps are specified, or neither is specified.
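To make the epochs/steps rule concrete, here is a hedged pure-Python sketch of how a method like anneal_after could check its arguments and then dispatch to the epoch-based or step-based variant. ToyTrainer is hypothetical and only records where a hook would be registered; the error message is made up and is not tfsnippet's.

```python
from typing import Callable, List, Optional, Tuple


class ToyTrainer:
    """Hypothetical stand-in that records which hook list a call lands in."""

    def __init__(self) -> None:
        self.registered: List[Tuple[str, int]] = []

    def anneal_after_epochs(self, value: Callable[[], object], freq: int) -> None:
        self.registered.append(("epochs", freq))

    def anneal_after_steps(self, value: Callable[[], object], freq: int) -> None:
        self.registered.append(("steps", freq))

    def anneal_after(self, value: Callable[[], object],
                     epochs: Optional[int] = None,
                     steps: Optional[int] = None) -> None:
        # exactly one of `epochs` and `steps` must be given
        if (epochs is None) == (steps is None):
            raise ValueError("specify exactly one of `epochs` and `steps`")
        if epochs is not None:
            self.anneal_after_epochs(value, epochs)
        else:
            self.anneal_after_steps(value, steps)


trainer = ToyTrainer()
trainer.anneal_after(lambda: None, epochs=10)
trainer.anneal_after(lambda: None, steps=500)
print(trainer.registered)  # [('epochs', 10), ('steps', 500)]

# both "neither" and "both" are rejected
errors = []
for kwargs in ({}, {"epochs": 1, "steps": 1}):
    try:
        trainer.anneal_after(lambda: None, **kwargs)
    except ValueError:
        errors.append(kwargs)
print(len(errors))  # 2
```

Comparing (epochs is None) with (steps is None) catches both invalid cases with a single test.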
anneal_after_epochs(value, freq)
    Add an annealing hook to run after every few epochs.
    Parameters:
        - value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
        - freq (int) – The frequency for this annealing hook to run, i.e. it runs after every freq epochs.
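As a worked example of anneal_after_epochs(learning_rate, freq=10) with an initial value of 0.001 and a ratio of 0.75 (the numbers used in the class example), the sketch below assumes that each anneal() multiplies the value by the ratio (optionally clipped at a minimum) and that epochs are counted from 1, so the hook fires after epochs 10, 20 and 30. ToyAnnealingVariable is a hypothetical stand-in, not spt.AnnealingVariable.

```python
from typing import List, Optional


class ToyAnnealingVariable:
    """Hypothetical pure-Python model of an annealing value.

    Assumption: each anneal() multiplies the value by `ratio`, and the
    result is clipped from below at `min_value` when one is given.
    """

    def __init__(self, initial_value: float, ratio: float,
                 min_value: Optional[float] = None) -> None:
        self.value = initial_value
        self.ratio = ratio
        self.min_value = min_value

    def anneal(self) -> None:
        self.value *= self.ratio
        if self.min_value is not None:
            self.value = max(self.value, self.min_value)


lr = ToyAnnealingVariable(0.001, 0.75)
freq = 10
history: List[float] = []          # value in effect after each epoch's hooks
for epoch in range(1, 31):         # assumption: epochs are counted from 1
    # ... one epoch of training would run here ...
    if epoch % freq == 0:          # fires after epochs 10, 20 and 30
        lr.anneal()
    history.append(lr.value)

print(lr.value)  # approximately 0.001 * 0.75 ** 3 = 0.000421875
```

Under these assumptions, 30 epochs with freq=10 anneal the value three times.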
anneal_after_steps(value, freq)
    Add an annealing hook to run after every few steps.
    Parameters:
        - value (AnnealingVariable or () -> any) – An annealing variable (which has .anneal()), or any callable object.
        - freq (int) – The frequency for this annealing hook to run, i.e. it runs after every freq steps.
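Because value may be any callable object, the schedule does not have to be multiplicative. The hypothetical LinearDecay below is a zero-argument callable that subtracts a fixed amount per call and never drops below a floor. The loop stands in for a step-based hook with freq=1000 and assumes steps are counted from 1. It is a sketch, not tfsnippet code.

```python
class LinearDecay:
    """Hypothetical zero-argument callable usable as an annealing `value`.

    Each call lowers `value` by `delta`, but never below `floor`.
    """

    def __init__(self, start: float, delta: float, floor: float) -> None:
        self.value = start
        self.delta = delta
        self.floor = floor
        self.calls = 0

    def __call__(self) -> None:
        self.calls += 1
        self.value = max(self.value - self.delta, self.floor)


decay = LinearDecay(start=0.01, delta=0.002, floor=0.001)
freq = 1000
for step in range(1, 5001):        # assumption: steps are counted from 1
    # ... one training step would run here ...
    if step % freq == 0:           # fires after steps 1000, 2000, ..., 5000
        decay()

print(decay.calls, decay.value)    # 5 0.001
```

The fifth call would bring the value to zero, so the floor takes over.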
evaluate_after(evaluator, epochs=None, steps=None)
    Add an evaluation hook to run after every few epochs or steps.
    Parameters:
    Raises: ValueError – If both epochs and steps are specified, or neither is specified.
evaluate_after_epochs(evaluator, freq)
    Add an evaluation hook to run after every few epochs.
    Parameters:
evaluate_after_steps(evaluator, freq)
    Add an evaluation hook to run after every few steps.
    Parameters:
log_after(epochs=None, steps=None)
    Add a logging hook to run after every few epochs or steps.
    Parameters:
    Raises: ValueError – If both epochs and steps are specified, or neither is specified.
log_after_epochs(freq)
    Add a logging hook to run after every few epochs.
    Parameters: freq (int) – The frequency for this logging hook to run, i.e. it runs after every freq epochs.
log_after_steps(freq)
    Add a logging hook to run after every few steps.
    Parameters: freq (int) – The frequency for this logging hook to run, i.e. it runs after every freq steps.
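For step-based hooks, it matters whether the step counter resets at each epoch. The sketch below assumes it does not: there is a single global counter starting at 1. This is a hypothetical model, not a statement about tfsnippet's TrainLoop. Under that assumption, 3 epochs of 150 steps with freq=100 fire the logging hook 4 times, and the firings do not line up with epoch boundaries. logging_steps is an illustrative helper, not a library function.

```python
from typing import List, Tuple


def logging_steps(num_epochs: int, steps_per_epoch: int,
                  freq: int) -> List[Tuple[int, int]]:
    """Return (epoch, global_step) pairs at which a step-based hook fires.

    Assumption (hypothetical model): the step counter starts at 1 and keeps
    counting across epoch boundaries instead of resetting every epoch.
    """
    fired = []
    step = 0
    for epoch in range(1, num_epochs + 1):
        for _ in range(steps_per_epoch):
            step += 1
            # ... one training step would run here ...
            if step % freq == 0:
                fired.append((epoch, step))
    return fired


print(logging_steps(num_epochs=3, steps_per_epoch=150, freq=100))
# [(1, 100), (2, 200), (2, 300), (3, 400)]
```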
remove_annealing_hooks()
    Remove annealing hooks from all lists.
    Returns: The number of removed hooks.
    Return type: int
remove_evaluation_hooks()
    Remove evaluation hooks from all lists.
    Returns: The number of removed hooks.
    Return type: int
remove_log_hooks()
    Remove logging hooks from all lists.
    Returns: The number of removed hooks.
    Return type: int
remove_validation_hooks()
    Remove evaluation hooks from all lists.
    Returns: The number of removed hooks.
    Return type: int
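Each remove_*_hooks method returns the number of hooks it removed across all hook lists. A minimal pure-Python model of that bookkeeping follows. ToyHookRegistry, the list names ("after_epochs", "after_steps"), and the kind tags are hypothetical.

```python
from typing import Callable, Dict, List, Tuple

Hook = Tuple[str, Callable[[], None]]   # (kind, callback)


class ToyHookRegistry:
    """Hypothetical registry keeping hooks in several lists, each tagged by kind."""

    def __init__(self) -> None:
        self.lists: Dict[str, List[Hook]] = {"after_epochs": [], "after_steps": []}

    def add(self, where: str, kind: str, callback: Callable[[], None]) -> None:
        self.lists[where].append((kind, callback))

    def remove(self, kind: str) -> int:
        """Remove hooks of `kind` from all lists and return how many were removed."""
        removed = 0
        for name in list(self.lists):
            hooks = self.lists[name]
            kept = [hook for hook in hooks if hook[0] != kind]
            removed += len(hooks) - len(kept)
            self.lists[name] = kept
        return removed


registry = ToyHookRegistry()
registry.add("after_epochs", "evaluation", lambda: None)
registry.add("after_epochs", "logging", lambda: None)
registry.add("after_steps", "logging", lambda: None)

first = registry.remove("logging")     # removes one hook from each list
second = registry.remove("logging")    # nothing left to remove
print(first, second)                   # 2 0
```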
run()
    Run the training loop.
validate_after(evaluator, epochs=None, steps=None)
    Add an evaluation hook to run after every few epochs or steps.
    Parameters:
    Raises: ValueError – If both epochs and steps are specified, or neither is specified.
validate_after_epochs(evaluator, freq)
    Add an evaluation hook to run after every few epochs.
    Parameters: