tfsnippet.trainer
-
class tfsnippet.trainer.BaseTrainer(loop)
Bases: object
Base class for all trainers.
The trainers provided in tfsnippet.trainer are not designed to take full control of training, as is often assumed in other libraries such as Keras. Instead, a trainer only assembles the different steps of a training process and runs the main training loop. It is therefore usually the caller’s responsibility to derive the training operation from a TensorFlow optimizer and pass it to a suitable trainer.
See also
-
__init__(loop)
Initialize the internal states of BaseTrainer.
Parameters: loop (TrainLoop) – The training loop object.
-
_iter_steps()
Subclasses should override this to iterate through steps.
A common implementation of _iter_steps() might be:

    def _iter_steps(self):
        return self.loop.iter_steps(training_data)

Yields: int or (int, tuple[np.ndarray]) – The step counter, or the step counter together with the step training data. Each yielded item is passed directly to _run_step() as the payload argument.
-
_run_step(session, payload)
Subclasses should override this to run a training step.
Parameters:
- session – The TensorFlow session.
- payload – The step payload generated by _iter_steps().
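The division of labour between _iter_steps() and _run_step() can be illustrated with a small, self-contained Python sketch. It does not import tfsnippet: ToyBaseTrainer, its simplified run() and the plain list of batches standing in for the training loop are all hypothetical, and the real run() also deals with the TrainLoop and the hook lists, which are omitted here.

```python
class ToyBaseTrainer(object):
    """Hypothetical stand-in for the BaseTrainer step contract (not tfsnippet code)."""

    def __init__(self, loop):
        self.loop = loop

    def _iter_steps(self):
        raise NotImplementedError()

    def _run_step(self, session, payload):
        raise NotImplementedError()

    def run(self, session=None):
        # Simplified: every payload yielded by _iter_steps() goes to _run_step().
        for payload in self._iter_steps():
            self._run_step(session, payload)


class ToyTrainer(ToyBaseTrainer):
    def __init__(self, loop, batches):
        super(ToyTrainer, self).__init__(loop)
        self.batches = batches
        self.seen = []

    def _iter_steps(self):
        # Yield (step counter, step data) pairs; the counter starts at 1 here.
        for step, batch in enumerate(self.batches, start=1):
            yield step, batch

    def _run_step(self, session, payload):
        step, batch = payload
        # A real trainer would run its train_op in `session`, feeding `batch`.
        self.seen.append((step, sum(batch)))


trainer = ToyTrainer(loop=None, batches=[(1, 2), (3, 4)])
trainer.run()
print(trainer.seen)  # [(1, 3), (2, 7)]
```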
-
anneal_after(value, epochs=None, steps=None)
Add an annealing hook to run after every few epochs or steps.
Parameters:
- value (AnnealingDynamicValue or () -> any) – An annealing dynamic value (which has .anneal()), or any callable object.
- epochs (None or int) – Run annealing after every this many epochs.
- steps (None or int) – Run annealing after every this many steps.
Raises: ValueError – If both epochs and steps are specified, or neither is specified.
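The epochs/steps rule above is an exclusive choice, and the value is either called directly or via its .anneal() method. The Python sketch below expresses these two rules; make_anneal_hook, its return tuple and HalvingValue are hypothetical, the list names only echo the after_epochs and after_steps lists returned by hook_lists, and checking callable() before falling back to .anneal() is an assumption.

```python
def make_anneal_hook(value, epochs=None, steps=None):
    """Hypothetical helper mirroring the documented rules of anneal_after()."""
    # Exactly one of `epochs` and `steps` must be given.
    if (epochs is None) == (steps is None):
        raise ValueError('Exactly one of `epochs` and `steps` should be '
                         'specified.')
    # Any callable is used as-is; otherwise fall back to its .anneal() method.
    callback = value if callable(value) else value.anneal
    if epochs is not None:
        return 'after_epochs', callback, epochs
    return 'after_steps', callback, steps


class HalvingValue(object):
    """Hypothetical object exposing .anneal(), like AnnealingDynamicValue."""

    def __init__(self, value):
        self.value = value

    def anneal(self):
        self.value *= 0.5


lr = HalvingValue(1.0)
where, callback, freq = make_anneal_hook(lr, epochs=10)
callback()
print(where, freq, lr.value)  # after_epochs 10 0.5
```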
-
anneal_after_epochs(value, freq)
Add an annealing hook to run after every few epochs.
Parameters:
- value (AnnealingDynamicValue or () -> any) – An annealing dynamic value (which has .anneal()), or any callable object.
- freq (int) – The frequency for this annealing hook to run.
-
anneal_after_steps(value, freq)
Add an annealing hook to run after every few steps.
Parameters:
- value (AnnealingDynamicValue or () -> any) – An annealing dynamic value (which has .anneal()), or any callable object.
- freq (int) – The frequency for this annealing hook to run.
-
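evaluate_after(evaluator, epochs=None, steps=None)
Add an evaluation hook to run after every few epochs or steps.
Parameters:
- evaluator (Evaluator) – The evaluator to run.
- epochs (None or int) – Run evaluation after every this many epochs.
- steps (None or int) – Run evaluation after every this many steps.
Raises: ValueError – If both epochs and steps are specified, or neither is specified.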
-
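evaluate_after_epochs(evaluator, freq)
Add an evaluation hook to run after every few epochs.
Parameters:
- evaluator (Evaluator) – The evaluator to run.
- freq (int) – The frequency for this evaluation hook to run.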
-
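evaluate_after_steps(evaluator, freq)
Add an evaluation hook to run after every few steps.
Parameters:
- evaluator (Evaluator) – The evaluator to run.
- freq (int) – The frequency for this evaluation hook to run.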
-
hook_lists
Get all the hook lists.
Returns: The tuple (self.before_epochs, self.before_steps, self.after_steps, self.after_epochs).
Return type: tuple[HookList]
-
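log_after(epochs=None, steps=None)
Add a logging hook to run after every few epochs or steps.
Parameters:
- epochs (None or int) – Run logging after every this many epochs.
- steps (None or int) – Run logging after every this many steps.
Raises: ValueError – If both epochs and steps are specified, or neither is specified.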
-
log_after_epochs(freq)
Add a logging hook to run after every few epochs.
Parameters: freq (int) – The frequency for this logging hook to run.
-
log_after_steps(freq)
Add a logging hook to run after every few steps.
Parameters: freq (int) – The frequency for this logging hook to run.
-
remove_annealing_hooks()
Remove annealing hooks from all lists.
Returns: The number of removed hooks.
Return type: int
-
remove_by_priority(priority)
Remove hooks having the specified priority from all lists.
Parameters: priority – The priority of the hooks to be removed.
Returns: The number of removed hooks.
Return type: int
-
remove_evaluation_hooks()
Remove evaluation hooks from all lists.
Returns: The number of removed hooks.
Return type: int
-
remove_log_hooks()
Remove logging hooks from all lists.
Returns: The number of removed hooks.
Return type: int
-
remove_validation_hooks()
Remove evaluation hooks from all lists.
Returns: The number of removed hooks.
Return type: int
-
run()
Run the training loop.
-
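validate_after(evaluator, epochs=None, steps=None)
Add an evaluation hook to run after every few epochs or steps.
Parameters:
- evaluator (Evaluator) – The evaluator to run.
- epochs (None or int) – Run evaluation after every this many epochs.
- steps (None or int) – Run evaluation after every this many steps.
Raises: ValueError – If both epochs and steps are specified, or neither is specified.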
-
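validate_after_epochs(evaluator, freq)
Add an evaluation hook to run after every few epochs.
Parameters:
- evaluator (Evaluator) – The evaluator to run.
- freq (int) – The frequency for this evaluation hook to run.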
-
-
class tfsnippet.trainer.DynamicValue
Bases: object
Dynamic values fed into trainers.
It is sometimes necessary to feed a dynamic value into a trainer, e.g., an annealing learning rate. This class is the base class for all such dynamic values.
-
get()
Get the current value of this DynamicValue object.
-
-
class tfsnippet.trainer.SimpleDynamicValue(value)
Bases: tfsnippet.trainer.dynamic_values.DynamicValue
A simple DynamicValue, which stores the value in an internal attribute; the value can be changed by set().
-
__init__(value)
Construct a new SimpleDynamicValue.
Parameters: value – Any value to be set. It can even be another instance of DynamicValue.
-
get()
Get the current value of this DynamicValue object.
-
set(value)
Set the value of this SimpleDynamicValue instance.
Parameters: value – Any value to be set. It can even be another instance of DynamicValue.
-
-
class tfsnippet.trainer.AnnealingDynamicValue(initial_value, ratio, min_value=None)
Bases: tfsnippet.trainer.dynamic_values.SimpleDynamicValue
A DynamicValue whose value is annealed (scaled) each time anneal() is called.
-
__init__(initial_value, ratio, min_value=None)
Construct a new AnnealingDynamicValue.
Parameters:
- initial_value – A number, the initial value.
- ratio – A number, the ratio by which the value is scaled each time it is annealed.
- min_value – Optional, a number, the minimum value.
-
anneal()
Anneal the value, i.e., scale it by ratio without letting it drop below min_value (if specified).
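A pure-Python sketch of the three dynamic-value classes, written from the descriptions above rather than copied from tfsnippet. Two behaviours are assumptions: that SimpleDynamicValue.get() resolves a nested DynamicValue recursively, and that anneal() multiplies the current value by ratio and then clamps it at min_value.

```python
class DynamicValue(object):
    """Base class: subclasses return their current value from get()."""

    def get(self):
        raise NotImplementedError()


class SimpleDynamicValue(DynamicValue):
    def __init__(self, value):
        self.set(value)

    def get(self):
        # Assumption: a nested DynamicValue is resolved recursively.
        if isinstance(self._value, DynamicValue):
            return self._value.get()
        return self._value

    def set(self, value):
        self._value = value


class AnnealingDynamicValue(SimpleDynamicValue):
    def __init__(self, initial_value, ratio, min_value=None):
        super(AnnealingDynamicValue, self).__init__(initial_value)
        self._ratio = ratio
        self._min_value = min_value

    def anneal(self):
        # Scale the current value by `ratio`, but never go below `min_value`.
        value = self.get() * self._ratio
        if self._min_value is not None:
            value = max(value, self._min_value)
        self.set(value)


lr = AnnealingDynamicValue(1.0, ratio=0.5, min_value=0.3)
lr.anneal()
print(lr.get())  # 0.5
lr.anneal()
print(lr.get())  # 0.3 (0.25 is clamped to min_value)
print(SimpleDynamicValue(lr).get())  # 0.3, resolved through the nested value
```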
-
-
tfsnippet.trainer.auto_batch_weight(*batch_arrays)
Automatically inspect the metric weight for an evaluation mini-batch.
Parameters: *batch_arrays – Mini-batch arrays. The .size of the first array will be used as the weight.
Returns: The inspected weight, or 1. if any error occurs during inspection.
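The rule above fits in a few lines of Python. This is a sketch of the documented behaviour, not tfsnippet's source; SimpleNamespace merely stands in for an array exposing .size, so the example runs without NumPy.

```python
from types import SimpleNamespace


def auto_batch_weight(*batch_arrays):
    """Use the .size of the first mini-batch array as the metric weight."""
    try:
        return float(batch_arrays[0].size)
    except Exception:
        # No arrays, or the first one has no usable .size: fall back to 1.
        return 1.


print(auto_batch_weight(SimpleNamespace(size=128)))  # 128.0
print(auto_batch_weight([1, 2, 3]))                  # 1.0 (a list has no .size)
print(auto_batch_weight())                           # 1.0
```

For a NumPy array, .size counts all elements, so a batch of shape (128, 784) gets the weight 100352.0 rather than 128.0. As long as every mini-batch shares the same trailing dimensions, the weights stay proportional to the batch sizes.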
-
class tfsnippet.trainer.Evaluator(loop, metrics, inputs, data_flow, feed_dict=None, time_metric_name='eval_time', batch_weight_func=auto_batch_weight)
Bases: object
Class to compute evaluation metrics.
It is common practice to compute one or more metrics for evaluation and validation during training. This class provides a convenient interface for computing metrics by mini-batches.
-
__init__(loop, metrics, inputs, data_flow, feed_dict=None, time_metric_name='eval_time', batch_weight_func=auto_batch_weight)
Construct a new Evaluator.
Parameters:
- loop (TrainLoop) – The training loop object.
- metrics (Tensor or dict[str, Tensor]) – The validation loss metric, or a dict of metrics. All the metrics must be 0-d tensors. If only a loss is specified, the default validation loss name loop.valid_metric_name will be used as its name. Otherwise, if a dict is specified, its keys will be used as the names of the metrics.
- inputs (list[tf.Tensor]) – The input placeholders. The number and order of the tensors should both match the arrays of each mini-batch provided by data_flow.
- data_flow (DataFlow) – The validation data flow.
- feed_dict (dict[tf.Tensor, any]) – The fixed feed dict for validation. It will be merged with the feeds for inputs and with the argument of run(feed_dict). (default None)
- time_metric_name (None or str) – The metric name for collecting evaluation time usage. Specify None to suppress the time usage metric. (default "eval_time")
- batch_weight_func ((*arrays) -> float or None) – Specify how to compute the metric weight for each mini-batch. If None, 1. will be used as the metric weight. (default auto_batch_weight())
-
batch_weight_func
Get the function to compute the metric weight for each mini-batch.
-
last_metrics_dict
Get the metric values from the last evaluation.
Returns: The metric values dict.
Return type: dict[str, any]
-
metrics
Get the metrics to compute.
Returns: The metrics to compute.
Return type: OrderedDict[str, tf.Tensor]
-
run(feed_dict=None)
Run evaluation.
Parameters: feed_dict – The extra feed dict to be merged with the already configured dict. (default None)
-
time_metric_name
Get the metric name for collecting evaluation time usage.
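The documentation does not spell out how per-batch results are combined, but a natural reading of batch_weight_func is a weighted average over mini-batches. The plain-Python sketch below shows that reading; weighted_metrics, batch_len and the list-of-dicts input format are hypothetical, and no TensorFlow session is involved.

```python
def weighted_metrics(batch_metrics, batch_arrays, batch_weight_func=None):
    """Combine per-batch metric dicts into weighted averages (a sketch).

    batch_metrics: list of dict[str, float], one dict per mini-batch.
    batch_arrays: list of tuples of arrays, one tuple per mini-batch.
    """
    totals = {}
    total_weight = 0.
    for metrics, arrays in zip(batch_metrics, batch_arrays):
        # With batch_weight_func=None every mini-batch weighs 1.
        if batch_weight_func is None:
            weight = 1.
        else:
            weight = batch_weight_func(*arrays)
        total_weight += weight
        for name, value in metrics.items():
            totals[name] = totals.get(name, 0.) + weight * value
    return {name: total / total_weight for name, total in totals.items()}


def batch_len(*arrays):
    # Hypothetical weight function: the number of rows in the first array.
    return float(len(arrays[0]))


metrics = [{'loss': 1.0}, {'loss': 5.0}]
arrays = [([0, 0, 0],), ([0],)]  # a 3-sample batch, then a 1-sample batch
print(weighted_metrics(metrics, arrays, batch_len))  # {'loss': 2.0}
print(weighted_metrics(metrics, arrays))             # {'loss': 3.0}
```

Weighting matters mainly for the last, possibly smaller, mini-batch: without it, a one-sample batch counts as much as a full one.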
-
-
tfsnippet.trainer.resolve_feed_dict(feed_dict, inplace=False)
Resolve all dynamic values in feed_dict into fixed values.
The supported dynamic value types and the corresponding resolving methods are as follows:
- DynamicValue: get() will be called.
- callable object: Will be called to get the value.
Parameters:
- feed_dict – The feed dict to be resolved.
- inplace (bool) – Whether or not to resolve the values in feed_dict in place. (default False)
Returns: The resolved feed dict.
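The two resolving rules can be sketched in plain Python. This is not tfsnippet's implementation: DynamicValue here is a minimal stand-in, Fixed is a hypothetical subclass, string keys replace tf.Tensor keys, and working on a copy unless inplace=True is inferred from the parameter name.

```python
class DynamicValue(object):
    """Minimal stand-in for tfsnippet.trainer.DynamicValue."""

    def get(self):
        raise NotImplementedError()


class Fixed(DynamicValue):
    """Hypothetical DynamicValue subclass that always returns one value."""

    def __init__(self, value):
        self.value = value

    def get(self):
        return self.value


def resolve_feed_dict(feed_dict, inplace=False):
    # Work on a copy unless the caller asked for in-place resolution.
    if not inplace:
        feed_dict = dict(feed_dict)
    for key, value in list(feed_dict.items()):
        if isinstance(value, DynamicValue):
            feed_dict[key] = value.get()
        elif callable(value):
            feed_dict[key] = value()
    return feed_dict


original = {'lr': Fixed(0.01), 'keep_prob': lambda: 0.5, 'x': [1, 2]}
resolved = resolve_feed_dict(original)
print(resolved)  # {'lr': 0.01, 'keep_prob': 0.5, 'x': [1, 2]}
print(isinstance(original['lr'], Fixed))  # True: the input was left untouched
```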
-
tfsnippet.trainer.merge_feed_dict(*feed_dicts)
Merge all feed dicts into one.
Parameters: *feed_dicts – List of feed dicts. Later ones override values specified in earlier ones. A None is simply ignored.
Returns: The merged feed dict.
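The merge order described above (left to right, later dicts win, None skipped) can be sketched as follows; this mirrors the documented behaviour and is not necessarily tfsnippet's exact code.

```python
def merge_feed_dict(*feed_dicts):
    """Merge feed dicts left to right; later dicts win and None is skipped."""
    merged = {}
    for feed_dict in feed_dicts:
        if feed_dict is not None:
            merged.update(feed_dict)
    return merged


print(merge_feed_dict({'a': 1, 'b': 2}, None, {'b': 3}))  # {'a': 1, 'b': 3}
```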
-
class tfsnippet.trainer.HookPriority
Bases: object
Pre-defined hook priorities for BaseTrainer and Evaluator. Smaller values have higher priority, i.e., hooks with smaller values are called first.
-
ANNEALING = 1500
-
DEFAULT = 1000
-
EVALUATION = 500
-
LOGGING = 10000
-
VALIDATION = 500
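Within one hook list, sorting by these values gives the calling order. The short sketch below redeclares the constants locally instead of importing tfsnippet, and shows why an evaluation hook registered for the same epoch runs before a logging hook.

```python
class HookPriority(object):
    # Values copied from the constants listed above.
    EVALUATION = 500
    VALIDATION = 500
    DEFAULT = 1000
    ANNEALING = 1500
    LOGGING = 10000


names = ['LOGGING', 'ANNEALING', 'DEFAULT', 'EVALUATION']
# Smaller value = higher priority = called earlier.
call_order = sorted(names, key=lambda name: getattr(HookPriority, name))
print(call_order)  # ['EVALUATION', 'DEFAULT', 'ANNEALING', 'LOGGING']
```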
-
-
class tfsnippet.trainer.HookEntry(callback, freq, priority, birth)
Bases: object
Configurations of a hook entry in HookList.
-
__init__(callback, freq, priority, birth)
Construct a new HookEntry.
Parameters:
- callback (() -> any) – The callable object, as the hook callback.
- freq (int) – The frequency for this callback to be called.
- priority (int) – The hook priority. A smaller number has higher priority when the hooks are called.
- birth (int) – The birth counter, used as an additional key for sorting the hook entries, such that older hooks are placed in front of newly added hooks with the same priority.
-
maybe_call()
Decrease the counter, and call the callback if the counter is less than 1. The counter is then reset to freq.
-
reset_counter()
Reset the counter to freq, its initial value.
-
sort_key()
Get the key for sorting this hook entry.
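The countdown described for maybe_call() and reset_counter() can be sketched as below. The (priority, birth) sort key follows from the parameter descriptions above; the class itself is an illustration, not tfsnippet's source.

```python
class HookEntry(object):
    """Sketch of a hook entry with a countdown counter."""

    def __init__(self, callback, freq, priority, birth):
        self.callback = callback
        self.freq = freq
        self.priority = priority
        self.birth = birth
        self.counter = freq  # the counter starts at freq

    def reset_counter(self):
        self.counter = self.freq

    def maybe_call(self):
        # Count down; fire when the counter drops below 1, then start over.
        self.counter -= 1
        if self.counter < 1:
            self.callback()
            self.reset_counter()

    def sort_key(self):
        # Priority first; for equal priorities, older (smaller birth) first.
        return (self.priority, self.birth)


calls = []
entry = HookEntry(lambda: calls.append('hit'), freq=3, priority=1000, birth=0)
for _ in range(7):
    entry.maybe_call()
print(len(calls))  # 2 (fired on the 3rd and 6th call)
```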
-
-
class tfsnippet.trainer.HookList
Bases: object
Class for managing hooks in BaseTrainer and Evaluator.
A hook is a registered callback that the trainers call at certain times during the training process. Apart from the callback itself, each hook has a freq and a priority.
- The freq controls how often the particular hook should be called, e.g., every 2 epochs.
- The priority determines the order of calling the hooks. A smaller number corresponds to a higher priority.
-
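add_hook(callback, freq=1, priority=1000)
Add a hook into the list.
Parameters:
- callback (() -> any) – The hook callback.
- freq (int) – The frequency for this hook to be called. (default 1)
- priority (int) – The hook priority. A smaller number has higher priority. (default 1000)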
-
call_hooks()
Call all the registered hooks.
If any of the hooks raises an error, the calling chain stops and the error is propagated to the caller.
-
remove(callback)
Remove all hooks having the specified callback.
Parameters: callback – The callback of the hooks to be removed.
Returns: The number of removed hooks.
Return type: int
-
remove_by_priority(priority)
Remove all hooks having the specified priority.
Parameters: priority (int) – The priority of the hooks to be removed.
Returns: The number of removed hooks.
Return type: int
-
remove_if(condition)
Remove all hooks matching the specified condition.
Parameters: condition ((callback, freq, priority) -> bool) – A callable object to tell whether or not a hook should be removed.
Returns: The number of removed hooks.
Return type: int
-
reset()
Reset the frequency counters of all hooks.
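The HookList behaviour described above (ordering by priority, then by age; per-hook frequency; removal by callback, priority or condition) can be sketched in plain Python. _Entry and the whole implementation are illustrative assumptions, not tfsnippet's code.

```python
class _Entry(object):
    """One registered hook: callback, freq, priority, birth and a countdown."""

    def __init__(self, callback, freq, priority, birth):
        self.callback = callback
        self.freq = freq
        self.priority = priority
        self.birth = birth
        self.counter = freq


class HookList(object):
    """Sketch of a hook list ordered by (priority, birth)."""

    def __init__(self):
        self._entries = []
        self._birth = 0

    def add_hook(self, callback, freq=1, priority=1000):
        self._entries.append(_Entry(callback, freq, priority, self._birth))
        self._birth += 1
        # Smaller priority first; equal priorities keep insertion order.
        self._entries.sort(key=lambda e: (e.priority, e.birth))

    def call_hooks(self):
        # No try/except: an error in any callback stops the loop and propagates.
        for entry in self._entries:
            entry.counter -= 1
            if entry.counter < 1:
                entry.callback()
                entry.counter = entry.freq

    def remove_if(self, condition):
        kept = [e for e in self._entries
                if not condition(e.callback, e.freq, e.priority)]
        removed = len(self._entries) - len(kept)
        self._entries = kept
        return removed

    def remove(self, callback):
        return self.remove_if(lambda c, f, p: c == callback)

    def remove_by_priority(self, priority):
        return self.remove_if(lambda c, f, p: p == priority)

    def reset(self):
        for entry in self._entries:
            entry.counter = entry.freq


log = []
hooks = HookList()
hooks.add_hook(lambda: log.append('a'))                # priority 1000
hooks.add_hook(lambda: log.append('b'), priority=500)  # called first
hooks.add_hook(lambda: log.append('c'), freq=2)        # every 2nd call
hooks.call_hooks()
hooks.call_hooks()
print(log)  # ['b', 'a', 'b', 'a', 'c']
print(hooks.remove_by_priority(1000))  # 2
```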
-
class tfsnippet.trainer.LossTrainer(loop, loss, train_op, inputs, data_flow, feed_dict=None, metric_name='loss')
Bases: tfsnippet.trainer.trainer.Trainer
A subclass of BaseTrainer which optimizes a single loss. This class is deprecated; use Trainer instead.
-
__init__(loop, loss, train_op, inputs, data_flow, feed_dict=None, metric_name='loss')
Construct a new LossTrainer.
Parameters:
- loop (TrainLoop) – The training loop object.
- loss (tf.Tensor) – The training loss.
- train_op (tf.Operation) – The training operation.
- inputs (list[tf.Tensor]) – The input placeholders. The number and order of the tensors should both match the arrays of each mini-batch provided by data_flow.
- data_flow (DataFlow) – The training data flow. Each mini-batch must contain one array for each placeholder in inputs.
- feed_dict – The feed dict for training. It will be merged with the arrays provided by data_flow in each step. (default None)
- metric_name (str) – The metric name for collecting training loss. (default "loss")
-
loss
Get the training loss.
-
metric_name
Get the metric name for collecting training loss.
-
-
class tfsnippet.trainer.Trainer(loop, train_op, inputs, data_flow, feed_dict=None, metrics=None)
Bases: tfsnippet.trainer.base_trainer.BaseTrainer
A subclass of BaseTrainer, executing a training operation per step. This might be the most commonly used Trainer. Code example:

    import tensorflow as tf
    from tfsnippet.dataflow import DataFlow
    from tfsnippet.scaffold import TrainLoop
    from tfsnippet.trainer import (Trainer, Evaluator,
                                   AnnealingDynamicValue)

    # build the model
    input_x = tf.placeholder(...)
    input_y = tf.placeholder(...)
    learning_rate = tf.placeholder(...)  # learning rate annealing

    # prepare the training and validation data
    # (train_x, train_y, valid_x and valid_y are arrays loaded elsewhere)
    train_data = DataFlow.arrays(
        [train_x, train_y], batch_size=128, shuffle=True,
        skip_incomplete=True
    )
    valid_data = DataFlow.arrays(
        [valid_x, valid_y], batch_size=512)
    ...  # the elided model code is assumed to define `loss` and `param_vars`

    # derive the training operation
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.minimize(loss)

    # run the trainer
    learning_rate_var = AnnealingDynamicValue(0.001, ratio=0.75)

    with TrainLoop(param_vars, max_epoch=10, early_stopping=True) as loop:
        trainer = Trainer(
            loop, train_op, [input_x, input_y], train_data,
            feed_dict={learning_rate: learning_rate_var},
            metrics={'loss': loss}
        )
        evaluator = Evaluator(
            loop, {'loss': loss}, [input_x, input_y], valid_data)

        # validate after every epoch
        trainer.evaluate_after_epochs(evaluator, freq=1)

        # log after every epoch (and after validation, since
        # ``HookPriority.EVALUATION < HookPriority.LOGGING``)
        trainer.log_after_epochs(freq=1)

        # anneal the learning rate after every 10 epochs
        trainer.anneal_after_epochs(learning_rate_var, freq=10)

        # run the main training loop
        trainer.run()
-
__init__(loop, train_op, inputs, data_flow, feed_dict=None, metrics=None)
Parameters:
- loop (TrainLoop) – The training loop object.
- train_op (tf.Operation) – The training operation.
- inputs (list[tf.Tensor]) – The input placeholders. The number and order of the tensors should both match the arrays of each mini-batch provided by data_flow.
- data_flow (DataFlow) – The training data flow. Each mini-batch must contain one array for each placeholder in inputs.
- feed_dict – The feed dict for training. It will be merged with the arrays provided by data_flow in each step. (default None)
- metrics (dict[str, tf.Tensor]) – Metrics to be computed along with train_op. The keys are the names of the metrics.
-
_iter_steps()
Subclasses should override this to iterate through steps.
A common implementation of _iter_steps() might be:

    def _iter_steps(self):
        return self.loop.iter_steps(training_data)

Yields: int or (int, tuple[np.ndarray]) – The step counter, or the step counter together with the step training data. Each yielded item is passed directly to _run_step() as the payload argument.
-
_run_step(session, payload)
Subclasses should override this to run a training step.
Parameters:
- session – The TensorFlow session.
- payload – The step payload generated by _iter_steps().
-
feed_dict
Get the feed dict for training.
Returns: The feed dict for training.
Return type: dict[tf.Tensor, any]
-
metrics
Get the metrics to be computed along with train_op.
-
train_op
Get the training operation.
-
-
class tfsnippet.trainer.Validator(loop, metrics, inputs, data_flow, feed_dict=None, time_metric_name='valid_time', batch_weight_func=auto_batch_weight)
Bases: tfsnippet.trainer.evaluator.Evaluator
Class to compute validation loss and other metrics.
This is a legacy class, which inherits from Evaluator. Use Evaluator instead if you’re writing new code.