VariationalTrainingObjectives

class tfsnippet.VariationalTrainingObjectives(vi)

Bases: object

Factory for variational training objectives.

Methods Summary

iwae([name]) Get the SGVB training objective for the importance weighted objective.
reinforce([variance_reduction, baseline, …]) Get the REINFORCE training objective.
rws_wake([name]) Get the wake-phase Reweighted Wake-Sleep (RWS) training objective.
sgvb([name]) Get the SGVB training objective.
vimco([name]) Get the VIMCO training objective.

Methods Documentation

iwae(name=None)

Get the SGVB training objective for the importance weighted objective.

Parameters: name (str) – TensorFlow name scope of the graph nodes. (default “iwae”)
Returns: The per-data SGVB training objective for the importance weighted objective.
Return type: tf.Tensor

See also

tfsnippet.variational.iwae_estimator()
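The importance weighted bound behind this objective is the log-mean-exp, over the K samples, of the log importance weights log p(x, z_k) - log q(z_k | x). The NumPy sketch below computes that value only (no gradients) and is an illustration, not TFSnippet's implementation. The helper names log_mean_exp and iwae_objective are hypothetical, samples are assumed to lie along axis 0, and returning the negated bound (so it can be minimized, as sgvb() does with the ELBO) is an assumption.

```python
import numpy as np


def log_mean_exp(x, axis=0):
    """Numerically stable log(mean(exp(x))) along `axis`."""
    x = np.asarray(x, dtype=float)
    x_max = np.max(x, axis=axis, keepdims=True)
    out = x_max + np.log(np.mean(np.exp(x - x_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def iwae_objective(log_joint, latent_log_prob, axis=0):
    """Hypothetical helper: negated importance weighted bound, per data point.

    log_joint:       log p(x, z_k), with the samples z_k along `axis`
    latent_log_prob: log q(z_k | x), same shape as log_joint
    """
    log_w = np.asarray(log_joint, dtype=float) - np.asarray(latent_log_prob, dtype=float)
    return -log_mean_exp(log_w, axis=axis)


# Two samples with importance weights 1 and 3: the bound is log(mean([1, 3])) = log 2.
obj = iwae_objective(np.log([1.0, 3.0]), np.zeros(2))

# By Jensen's inequality log-mean-exp >= mean, so on the same samples the
# IWAE objective never exceeds the plain sample-average (SGVB-style) objective.
rng = np.random.default_rng(0)
lj, lq = rng.normal(size=(5, 3)), rng.normal(size=(5, 3))
iwae_vals = iwae_objective(lj, lq)
sgvb_vals = -np.mean(lj - lq, axis=0)
```

With a single sample (K = 1) the two objectives coincide; larger K gives a tighter bound.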

reinforce(variance_reduction=True, baseline=None, decay=0.8, name=None)

Get the REINFORCE training objective.

Parameters:
  • variance_reduction (bool) – Whether to use variance reduction.
  • baseline (tf.Tensor) – A trainable estimate of the scale of the ELBO value.
  • decay (float) – The moving average decay for variance normalization.
  • name (str) – TensorFlow name scope of the graph nodes. (default “reinforce”)
Returns: The per-data REINFORCE training objective.
Return type: tf.Tensor

See also

zhusuan.variational.EvidenceLowerBoundObjective.reinforce()
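The variance-reduction idea behind REINFORCE (the score-function estimator) can be illustrated on a toy problem, independently of TFSnippet or ZhuSuan. The sketch below estimates d/dmu E[f(z)] for z ~ N(mu, 1) as the mean of (f(z) - b) * d/dmu log q(z), where the baseline b is an exponential moving average of the learning signal updated with a decay factor, loosely mirroring the variance_reduction and decay parameters above. The toy signal f(z) = z**2, the Gaussian q, and all function names are assumptions for illustration; the library's estimator works on the ELBO learning signal and may additionally use the trainable baseline and other normalization.

```python
import numpy as np


def toy_signal(z):
    """Toy learning signal f(z) = z**2, so E[f] = mu**2 + 1 under N(mu, 1)."""
    return z ** 2


def reinforce_gradient(f, mu, n_samples, rng, baseline=0.0):
    """Score-function estimate of d/dmu E_{z ~ N(mu, 1)}[f(z)].

    A baseline that does not depend on the current samples keeps the
    estimate unbiased, because E[d/dmu log q(z)] = 0.
    Returns the gradient estimate and the batch mean of f(z).
    """
    z = rng.normal(loc=mu, scale=1.0, size=n_samples)
    score = z - mu                     # d/dmu log N(z; mu, 1)
    signal = f(z)
    grad = np.mean((signal - baseline) * score)
    return grad, np.mean(signal)


def gradient_stats(mu=1.5, steps=400, n_samples=1000, decay=0.8,
                   variance_reduction=True, seed=0):
    """Mean and variance of per-step REINFORCE estimates."""
    rng = np.random.default_rng(seed)
    moving_mean = 0.0
    grads = []
    for _ in range(steps):
        b = moving_mean if variance_reduction else 0.0
        grad, batch_mean = reinforce_gradient(toy_signal, mu, n_samples, rng, baseline=b)
        grads.append(grad)
        # exponential moving average of the learning signal, controlled by `decay`
        moving_mean = decay * moving_mean + (1.0 - decay) * batch_mean
    return np.mean(grads), np.var(grads)


# The exact gradient of mu**2 + 1 with respect to mu is 2 * mu = 3.0.
mean_vr, var_vr = gradient_stats(variance_reduction=True)
mean_plain, var_plain = gradient_stats(variance_reduction=False)
```

Both runs average to roughly the same gradient; the moving-average baseline mainly lowers the spread of the individual estimates.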

rws_wake(name=None)

Get the wake-phase Reweighted Wake-Sleep (RWS) training objective.

Parameters: name (str) – TensorFlow name scope of the graph nodes. (default “rws_wake”)
Returns: The per-data wake-phase RWS training objective.
Return type: tf.Tensor

See also

zhusuan.variational.InclusiveKLObjective.rws()
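In the wake phase of Reweighted Wake-Sleep, the variational distribution q is updated by weighting each sample's log q(z_k | x) with its self-normalized importance weight, i.e. the softmax over k of log p(x, z_k) - log q(z_k | x). The NumPy sketch below computes those weights and the surrogate value -sum_k w_k log q(z_k | x). It illustrates the technique and is not necessarily the exact quantity returned by rws_wake() or ZhuSuan's rws(); the helper name and the axis-0 sample layout are assumptions, and in an autodiff framework the weights would be held constant (stop-gradient) when differentiating.

```python
import numpy as np


def rws_wake_surrogate(log_joint, latent_log_prob):
    """Hypothetical helper: wake-phase surrogate for q, samples along axis 0.

    Returns (surrogate, normalized_weights). Minimizing the surrogate with the
    weights treated as constants raises log q(z_k | x) in proportion to w_k.
    """
    log_joint = np.asarray(log_joint, dtype=float)
    latent_log_prob = np.asarray(latent_log_prob, dtype=float)
    log_w = log_joint - latent_log_prob
    # softmax over the sample axis, shifted by the max for numerical stability
    w = np.exp(log_w - np.max(log_w, axis=0, keepdims=True))
    w /= np.sum(w, axis=0, keepdims=True)
    surrogate = -np.sum(w * latent_log_prob, axis=0)
    return surrogate, w


# Two samples with log weights 1 and 3: the second sample dominates the update.
surrogate, weights = rws_wake_surrogate(np.zeros(2), np.array([-1.0, -3.0]))
```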

sgvb(name=None)

Get the SGVB training objective.

Parameters: name (str) – TensorFlow name scope of the graph nodes. (default “sgvb”)
Returns: The per-data SGVB training objective. It is the negative of the ELBO and should be minimized directly.
Return type: tf.Tensor

See also

tfsnippet.variational.sgvb_estimator()
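For a toy model the SGVB objective can be checked against a closed form. In the NumPy sketch below (an illustration, not TFSnippet's implementation), samples are drawn from q = N(mu, std**2) through the reparameterization z = mu + std * eps, the "model" is only a standard normal prior with no likelihood term, and the objective is the sample average of log q(z) - log p(z). Without a likelihood the negative ELBO reduces to KL(q || p), which equals 0.5 * (std**2 + mu**2 - 1) - log(std). The helper names and the toy model are assumptions; in TensorFlow the same reparameterized samples are what let gradients flow into mu and std.

```python
import numpy as np


def normal_log_prob(x, mean, std):
    """Log density of N(mean, std**2) evaluated at x."""
    return -0.5 * np.log(2.0 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2


def sgvb_objective(log_joint, latent_log_prob, axis=0):
    """Hypothetical helper: negative ELBO estimate, averaged over samples on `axis`."""
    return -np.mean(np.asarray(log_joint) - np.asarray(latent_log_prob), axis=axis)


rng = np.random.default_rng(0)
mu, std = 0.5, 0.8
eps = rng.standard_normal(100_000)
z = mu + std * eps                        # reparameterized samples from q = N(mu, std**2)
log_joint = normal_log_prob(z, 0.0, 1.0)  # toy model: standard normal prior, no likelihood
latent_log_prob = normal_log_prob(z, mu, std)

estimate = sgvb_objective(log_joint, latent_log_prob)
# With no likelihood term, the negative ELBO is exactly KL(q || p).
exact = 0.5 * (std ** 2 + mu ** 2 - 1.0) - np.log(std)
```

With mu = 0.5 and std = 0.8 the exact value is about 0.168, and the Monte Carlo estimate should land close to it.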

vimco(name=None)

Get the VIMCO training objective.

Parameters: name (str) – TensorFlow name scope of the graph nodes. (default “vimco”)
Returns: The per-data VIMCO training objective.
Return type: tf.Tensor

See also

zhusuan.variational.ImportanceWeightedObjective.vimco()
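VIMCO reduces the variance of the score-function term of the K-sample importance weighted bound with a leave-one-out baseline: for sample k, the log weight log w_k is replaced by the mean of the other K - 1 log weights (the log of their geometric mean), and the local learning signal is the bound minus this modified bound. The NumPy sketch below computes only these per-sample learning signals. It illustrates the technique rather than ZhuSuan's or TFSnippet's implementation; the helper names and the axis-0 sample layout are assumptions, and at least two samples are required.

```python
import numpy as np


def log_mean_exp(x, axis=0):
    """Numerically stable log(mean(exp(x))) along `axis`."""
    x_max = np.max(x, axis=axis, keepdims=True)
    out = x_max + np.log(np.mean(np.exp(x - x_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def vimco_learning_signals(log_w):
    """Hypothetical helper: VIMCO per-sample learning signals, samples along axis 0."""
    log_w = np.asarray(log_w, dtype=float)
    k = log_w.shape[0]
    if k < 2:
        raise ValueError("VIMCO needs at least two samples along axis 0")
    bound = log_mean_exp(log_w, axis=0)
    # mean of the other K - 1 log weights, computed for every k at once
    loo_mean = (np.sum(log_w, axis=0, keepdims=True) - log_w) / (k - 1)
    signals = np.empty_like(log_w)
    for i in range(k):
        replaced = log_w.copy()
        replaced[i] = loo_mean[i]         # swap sample i for its leave-one-out baseline
        signals[i] = bound - log_mean_exp(replaced, axis=0)
    return signals


# Weights 1 and 3: the bound is log 2; replacing either sample by the other
# gives the signals log(2/3) and log(2).
signals = vimco_learning_signals(np.log([1.0, 3.0]))
```

In the full estimator each signal multiplies the gradient of log q(z_k | x), alongside a second term weighted by the normalized importance weights.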