nvil_estimator
tfsnippet.nvil_estimator(values, latent_log_joint, baseline=None, center_by_moving_average=True, decay=0.8, axis=None, keepdims=False, batch_axis=None, name=None)
Derive the gradient estimator for \(\mathbb{E}_{q(\mathbf{z}|\mathbf{x})}\big[f(\mathbf{x},\mathbf{z})\big]\) using the NVIL algorithm (Mnih and Gregor, 2014).
\[\begin{split}\begin{aligned} \nabla \, \mathbb{E}_{q(\mathbf{z}|\mathbf{x})} \big[f(\mathbf{x},\mathbf{z})\big] &= \mathbb{E}_{q(\mathbf{z}|\mathbf{x})}\Big[ \nabla f(\mathbf{x},\mathbf{z}) + f(\mathbf{x},\mathbf{z})\,\nabla\log q(\mathbf{z}|\mathbf{x})\Big] \\ &= \mathbb{E}_{q(\mathbf{z}|\mathbf{x})}\Big[ \nabla f(\mathbf{x},\mathbf{z}) + \big(f(\mathbf{x},\mathbf{z}) - C_{\psi}(\mathbf{x})-c\big)\,\nabla\log q(\mathbf{z}|\mathbf{x})\Big] \end{aligned}\end{split}\]
where \(C_{\psi}(\mathbf{x})\) is a learnable network with parameter \(\psi\), and c is a learnable constant. The second equality holds because \(\mathbb{E}_{q(\mathbf{z}|\mathbf{x})}\big[\nabla\log q(\mathbf{z}|\mathbf{x})\big] = 0\): subtracting a term that does not depend on z keeps the estimator unbiased, while it can greatly reduce its variance. \(C_{\psi}(\mathbf{x})\) and c are learned by minimizing \(\mathbb{E}_{ q(\mathbf{z}|\mathbf{x}) }\Big[\big(f(\mathbf{x},\mathbf{z}) - C_{\psi}(\mathbf{x})-c\big)^2 \Big]\).
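The effect of the second equality can be checked numerically. The sketch below is a toy setup, not part of tfsnippet: it assumes a single Bernoulli latent \(q(z) = \mathrm{Bernoulli}(\sigma(\theta))\) and a hypothetical target \(f(z) = (z - 0.3)^2 + 5\) that does not depend on \(\theta\), so the \(\nabla f\) term vanishes. The constant c here is a sample mean from an independent batch, standing in for the moving-average estimate. Both estimators should agree with the exact gradient, while the baseline version should have far lower variance.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def f(z):
    # Hypothetical target f(z); the +5 offset makes the plain
    # score-function estimator noisy, which a baseline removes.
    return (z - 0.3) ** 2 + 5.0

theta = 1.0                     # logit of q(z) = Bernoulli(sigmoid(theta))
p = sigmoid(theta)
rng = np.random.default_rng(0)
n = 200_000

z = (rng.random(n) < p).astype(np.float64)  # z ~ q(z)
score = z - p                                # d/dtheta log q(z) for a Bernoulli with logit theta

# First line of the equation (f does not depend on theta): f(z) * grad log q(z).
g_plain = f(z) * score

# Second line with C_psi(x) = 0 and a constant baseline c, estimated
# from an independent batch so that it does not depend on z above.
c = f((rng.random(n) < p).astype(np.float64)).mean()
g_baseline = (f(z) - c) * score

# Exact gradient: d/dtheta [p f(1) + (1 - p) f(0)] = p (1 - p) (f(1) - f(0)).
exact = p * (1.0 - p) * (f(1.0) - f(0.0))

print(f"exact     {exact:.4f}")
print(f"plain     mean {g_plain.mean():.4f}  var {g_plain.var():.4f}")
print(f"baseline  mean {g_baseline.mean():.4f}  var {g_baseline.var():.4f}")
```

The offset in f shifts every sample of the plain estimator by a large multiple of the score, which averages to zero but inflates the variance; subtracting c cancels that shift.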
Parameters: - values – Values of the target function given x and z, i.e., \(f(\mathbf{x},\mathbf{z})\).
- latent_log_joint – Values of \(\log q(\mathbf{z}|\mathbf{x})\).
- baseline – Values of the baseline function \(C_{\psi}(\mathbf{x})\) given input x. If not specified, this method degenerates to the REINFORCE algorithm, using only a moving-average estimate of the constant baseline c.
- center_by_moving_average (bool) – Whether to use a moving average to maintain an estimate of c in the above equations.
- decay – The decay factor for the moving average.
- axis – The sampling axes to reduce in the outputs. If not specified, no axis is reduced.
- keepdims (bool) – When axis is specified, whether to keep the reduced axes. (default False)
- batch_axis – The batch axes to reduce when computing the expectation over x. If not specified, all axes except the sampling axes are treated as batch axes.
- name (str) – Default name of the name scope. If not specified, one is generated from the method name.
Returns: The (surrogate, baseline cost).
Return type: (tf.Tensor, tf.Tensor)
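To show how the parameters relate to the returned pair, here is a minimal NumPy sketch. It is not tfsnippet's implementation: the function name nvil_terms_sketch, the moving_c argument, the exponential-moving-average update rule, and the choice to update c before subtracting it are all assumptions, and batch_axis handling is simplified to averaging over every entry. NumPy has no autodiff, so the surrogate is returned only as its forward value.

```python
import numpy as np

def nvil_terms_sketch(values, latent_log_joint, baseline=None, moving_c=0.0,
                      center_by_moving_average=True, decay=0.8,
                      axis=None, keepdims=False):
    """Hypothetical sketch: returns (surrogate, baseline_cost, new_moving_c)."""
    values = np.asarray(values, dtype=np.float64)
    latent_log_joint = np.asarray(latent_log_joint, dtype=np.float64)

    # Learning signal f(x, z) - C_psi(x) - c, built up step by step.
    signal = values
    if baseline is not None:
        signal = signal - np.asarray(baseline, dtype=np.float64)

    new_c = moving_c
    if center_by_moving_average:
        # Assumed update rule: exponential moving average of the mean signal
        # (simplification: averages over all entries instead of batch_axis).
        new_c = decay * moving_c + (1.0 - decay) * float(np.mean(signal))
        signal = signal - new_c

    baseline_cost = None
    if baseline is not None:
        # Squared error (f - C_psi(x) - c)^2 from the equation above.
        baseline_cost = signal ** 2
        if axis is not None:
            baseline_cost = np.mean(baseline_cost, axis=axis, keepdims=keepdims)

    # Forward value of the surrogate. With autodiff, `signal` must be treated
    # as a constant so that d(surrogate) = d f + signal * d log q.
    surrogate = values + signal * latent_log_joint
    if axis is not None:
        surrogate = np.mean(surrogate, axis=axis, keepdims=keepdims)
    return surrogate, baseline_cost, new_c

# Usage: 2 samples (axis 0) for a batch of 2 inputs (axis 1).
values = np.array([[1.0, 2.0], [3.0, 4.0]])
log_q = np.array([[-0.5, -1.0], [-0.2, -0.7]])
baseline = np.array([1.5, 2.5])  # C_psi(x), broadcast over the sampling axis
surrogate, cost, c = nvil_terms_sketch(values, log_q, baseline,
                                       moving_c=0.0, decay=0.8, axis=0)
print(surrogate, cost, c)
```

In a TensorFlow graph the same idea would wrap the signal in tf.stop_gradient before multiplying it with latent_log_joint, so that the gradient of the surrogate matches the second line of the equation and the baseline is trained only through the baseline cost.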