tfsnippet

tfsnippet Package

Functions

as_distribution(distribution) Convert a supported type of distribution into Distribution type.
reduce_group_ndims(operation, tensor, …[, …]) Reduce the last group_ndims dimensions in tensor, using operation.
summarize_variables(variables[, title, …]) Get a formatted summary about the variables.
auto_batch_weight(*batch_arrays) Automatically inspect the metric weight for an evaluation mini-batch.
merge_feed_dict(*feed_dicts) Merge all feed dicts into one.
resolve_feed_dict(feed_dict[, inplace]) Resolve all dynamic values in feed_dict into fixed values.
elbo_objective(log_joint, latent_log_prob[, …]) Derive the ELBO objective.
importance_sampling_log_likelihood(…[, …]) Compute \(\log p(\mathbf{x})\) by importance sampling.
iwae_estimator(log_values, axis[, keepdims, …]) Derive the gradient estimator for \(\mathbb{E}_{q(\mathbf{z}^{(1:K)}|\mathbf{x})}\Big[\log \frac{1}{K} \sum_{k=1}^K f\big(\mathbf{x},\mathbf{z}^{(k)}\big)\Big]\), by the IWAE algorithm (Burda, Y., Grosse, R., et al.).
monte_carlo_objective(log_joint, latent_log_prob) Derive the Monte-Carlo objective.
nvil_estimator(values, latent_log_joint[, …]) Derive the gradient estimator for \(\mathbb{E}_{q(\mathbf{z}|\mathbf{x})}\big[f(\mathbf{x},\mathbf{z})\big]\), by the NVIL algorithm (Mnih and Gregor, 2014).
sgvb_estimator(values[, axis, keepdims, name]) Derive the gradient estimator for \(\mathbb{E}_{q(\mathbf{z}|\mathbf{x})}\big[f(\mathbf{x},\mathbf{z})\big]\), by the SGVB algorithm (Kingma, D.P., et al.).
vimco_estimator(log_values, latent_log_joint) Derive the VIMCO gradient estimator.
get_config_defaults(config) Get the default config values of config.
register_config_arguments(config, parser[, …]) Register config to the specified argument parser.
model_variable(name[, shape, dtype, …]) Get or create a model variable.
get_model_variables([scope]) Get all model variables (i.e., variables in MODEL_VARIABLES collection).
instance_reuse([method_or_scope, _sentinel, …]) Decorate an instance method to reuse a variable scope automatically.
global_reuse([method_or_scope, _sentinel, scope]) Decorate a function to reuse a variable scope automatically.
add_histogram(tensor[, summary_name, …]) Add the histogram of tensor to the default summary collector, and to collections.
add_summary(summary[, collections]) Add the summary to the default summary collector, and to collections.
default_summary_collector() Get the SummaryCollector object at the top of context stack.
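
The two objectives named above, the ELBO and the importance-sampling estimate of \(\log p(\mathbf{x})\), can be illustrated with a minimal plain-Python sketch. It is not the tfsnippet implementation: the real functions operate on TensorFlow tensors and accept further arguments elided in the signatures above. The helper names `log_mean_exp`, `elbo_sketch` and `importance_sampling_ll_sketch` are hypothetical, and each input is assumed to be a list of K per-sample log-densities for a single observation.

```python
import math


def log_mean_exp(values):
    """Numerically stable log(mean(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def elbo_sketch(log_joint, latent_log_prob):
    """Monte-Carlo ELBO: mean over samples of log p(x, z) - log q(z|x)."""
    diffs = [lj - lq for lj, lq in zip(log_joint, latent_log_prob)]
    return sum(diffs) / len(diffs)


def importance_sampling_ll_sketch(log_joint, latent_log_prob):
    """Estimate log p(x) as log (1/K) sum_k p(x, z_k) / q(z_k|x)."""
    diffs = [lj - lq for lj, lq in zip(log_joint, latent_log_prob)]
    return log_mean_exp(diffs)


# Hypothetical log-densities for K = 3 samples of z, all for one x.
log_joint = [-1.0, -2.0, -3.0]
latent_log_prob = [-0.5, -0.5, -0.5]

elbo = elbo_sketch(log_joint, latent_log_prob)
is_ll = importance_sampling_ll_sketch(log_joint, latent_log_prob)
```

By Jensen's inequality the importance-sampling estimate is never smaller than the ELBO computed from the same samples, so `is_ll >= elbo` holds here.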

Classes

BatchToValueDistribution(distribution, ndims) Distribution that converts the last few batch_ndims into values_ndims.
Bernoulli(logits[, dtype]) Univariate Bernoulli distribution.
Categorical(logits[, dtype]) Univariate Categorical distribution.
Concrete(temperature, logits[, …]) The class of Concrete (or Gumbel-Softmax) distribution from (Maddison, 2016; Jang, 2016), serving as the continuous relaxation of the OnehotCategorical.
Discrete alias of tfsnippet.distributions.univariate.Categorical
DiscretizedLogistic(mean, log_scale, bin_size) Discretized logistic distribution (Kingma et al.).
Distribution(dtype, is_continuous, …) Base class for probability distributions.
ExpConcrete(temperature, logits[, …]) The class of ExpConcrete distribution from (Maddison, 2016), transformed from Concrete by taking logarithm.
FlowDistribution(distribution, flow) Transform a Distribution by a BaseFlow, as a new distribution.
FlowDistributionDerivedTensor(tensor, …) A combination of a FlowDistribution derived tensor, and its original stochastic tensor from the base distribution.
Mixture(categorical, components[, …]) Mixture distribution.
Normal(mean[, std, logstd, …]) Univariate Normal distribution.
OnehotCategorical(logits[, dtype]) One-hot multivariate Categorical distribution.
Uniform([minval, maxval, …]) Univariate Uniform distribution.
AnnealingVariable(name, initial_value, ratio) A non-trainable tf.Variable, whose value will be annealed as training goes by.
CheckpointSavableObject Base class for all objects that can be saved via CheckpointSaver.
CheckpointSaver(variables, save_dir[, …]) Save and restore tf.Variable, ScheduledVariable and CheckpointSavableObject with tf.train.Saver.
DefaultMetricFormatter Default training metric formatter.
EventKeys Defines event keys for TFSnippet.
MetricFormatter Base class for a training metrics formatter.
MetricLogger([summary_writer, …]) Logger for the training metrics.
ScheduledVariable(name, initial_value[, …]) A non-trainable tf.Variable, whose value might need to be changed as training goes by.
TrainLoop(param_vars[, var_groups, …]) Training loop object.
AnnealingScalar(loop, initial_value, ratio) A DynamicValue scalar, which anneals every few epochs or steps.
BaseTrainer(loop[, ensure_variables_initialized]) Base class for all trainers.
DynamicValue Dynamic values to be fed into trainers and evaluators.
Evaluator(loop, metrics, inputs, data_flow) Class to compute evaluation metrics.
LossTrainer(**kwargs) A subclass of BaseTrainer, which optimizes a single loss.
Trainer(loop, train_op, inputs, data_flow[, …]) A subclass of BaseTrainer, executing a training operation per step.
Validator(**kwargs) Class to compute validation loss and other metrics.
VariationalChain(variational, model[, …]) Chain of the variational and model nets for variational inference.
VariationalEvaluation(vi) Factory for variational evaluation outputs.
VariationalInference(log_joint, latent_log_probs) Class for variational inference.
VariationalLowerBounds(vi) Factory for variational lower-bounds.
VariationalTrainingObjectives(vi) Factory for variational training objectives.
BayesianNet([observed]) Bayesian networks.
DataFlow Data flows are objects for constructing mini-batch iterators.
DataMapper Base class for all data mappers.
SlidingWindow(data_array, window_size) DataMapper for producing sliding windows according to indices.
Config Base class for defining config values.
ConfigField(type[, default, description, …]) A config field.
GraphKeys Defines TensorFlow graph collection keys for TFSnippet.
InvertibleMatrix(size[, strict, dtype, …]) A matrix initialized to be an invertible, orthogonal matrix.
VarScopeObject([name, scope]) Base class for objects that own a variable scope.
SummaryCollector([collections, …]) Collects summaries and histograms added by tfsnippet.add_summary() and tfsnippet.add_histogram().
StochasticTensor(distribution, tensor[, …]) Samples or observations of a stochastic variable.
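
As a rough illustration of what "producing sliding windows according to indices" can mean for SlidingWindow(data_array, window_size), here is a plain-Python sketch. It does not use tfsnippet. The function name `sliding_windows` is hypothetical, and treating each index as the start position of its window is an assumption, not something this page confirms.

```python
def sliding_windows(data, window_size, indices):
    """Return data[i:i + window_size] for every start index i in indices.

    Assumption: each index marks the first element of its window.
    """
    last_start = len(data) - window_size
    for i in indices:
        if not 0 <= i <= last_start:
            raise IndexError("window starting at %d exceeds the data" % i)
    return [data[i:i + window_size] for i in indices]


# Two windows of length 3 taken from a 5-element sequence.
windows = sliding_windows([0, 1, 2, 3, 4], 3, [0, 2])
```

Keeping the indices separate from the data means the mini-batch iterator only needs to shuffle and batch integer positions, and the windows are materialized on demand.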

Class Inheritance Diagram

[Figure: inheritance diagram of the classes listed above.]