tfsnippet

tfsnippet Package

Functions

as_distribution(distribution) Convert a distribution of a supported type into a Distribution.
reduce_group_ndims(operation, tensor, …[, …]) Reduce the last group_ndims dimensions in tensor, using operation.
early_stopping(*args, **kwargs) Deprecated since version 0.1.
summarize_variables(variables[, title, …]) Get a formatted summary about the variables.
train_loop(*args, **kwargs) Deprecated since version 0.1.
auto_batch_weight(*batch_arrays) Automatically inspect the metric weight for an evaluation mini-batch.
merge_feed_dict(*feed_dicts) Merge all feed dicts into one.
resolve_feed_dict(feed_dict[, inplace]) Resolve all dynamic values in feed_dict into fixed values.
elbo_objective(log_joint, latent_log_prob[, …]) Derive the ELBO objective.
importance_sampling_log_likelihood(…[, …]) Compute \(\log p(\mathbf{x})\) by importance sampling.
iwae_estimator(log_values, axis[, keepdims, …]) Derive the gradient estimator for \(\mathbb{E}_{q(\mathbf{z}^{(1:K)}|\mathbf{x})}\Big[\log \frac{1}{K} \sum_{k=1}^K f\big(\mathbf{x},\mathbf{z}^{(k)}\big)\Big]\), by IWAE (Burda, Y., Grosse, R.
monte_carlo_objective(log_joint, latent_log_prob) Derive the Monte-Carlo objective.
sgvb_estimator(values[, axis, keepdims, name]) Derive the gradient estimator for \(\mathbb{E}_{q(\mathbf{z}|\mathbf{x})}\big[f(\mathbf{x},\mathbf{z})\big]\), by SGVB (Kingma, D.P.
model_variable(name[, shape, dtype, …]) Get or create a model variable.
get_model_variables([scope]) Get all model variables (i.e., variables in MODEL_VARIABLES collection).
get_config_defaults(config) Get the default config values of config.
register_config_arguments(config, parser[, …]) Register config to the specified argument parser.
get_reuse_stack_top() Get the top of the reuse scope stack.
instance_reuse([method_or_scope, _sentinel, …]) Decorate an instance method to reuse a variable scope automatically.
global_reuse([method_or_scope, _sentinel, scope]) Decorate a function to reuse a variable scope automatically.
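The estimators listed above (elbo_objective, importance_sampling_log_likelihood and the log-mean-exp quantity inside iwae_estimator) all combine per-sample values of log p(x, z) and log q(z|x). The sketch below illustrates that arithmetic on plain Python lists. It is an assumption-laden illustration, not tfsnippet's implementation: the helper names are hypothetical, and the real functions operate on TensorFlow tensors and take further arguments (such as axis) that are not modelled here.

```python
import math
from typing import Sequence


def log_mean_exp(values: Sequence[float]) -> float:
    """Numerically stable log((1/K) * sum_k exp(values[k]))."""
    m = max(values)  # subtract the max so exp() cannot overflow
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def elbo_estimate(log_joint: Sequence[float],
                  latent_log_prob: Sequence[float]) -> float:
    """Hypothetical helper: mean over z_k ~ q(z|x) of log p(x,z_k) - log q(z_k|x)."""
    diffs = [lj - lq for lj, lq in zip(log_joint, latent_log_prob)]
    return sum(diffs) / len(diffs)


def importance_sampling_log_likelihood_estimate(
        log_joint: Sequence[float],
        latent_log_prob: Sequence[float]) -> float:
    """Hypothetical helper: log p(x) ~= log (1/K) sum_k p(x,z_k) / q(z_k|x)."""
    log_weights = [lj - lq for lj, lq in zip(log_joint, latent_log_prob)]
    return log_mean_exp(log_weights)
```

For example, with log_joint = [-3.0, -2.0, -4.0] and latent_log_prob = [-1.0, -0.5, -1.5], elbo_estimate returns -2.0, while the importance-sampling estimate is about -1.918. Because log is concave, the log of the mean weight is never below the mean of the log weights, so on the same samples the importance-sampling estimate is always at least as large as the ELBO estimate. Working in log space with log_mean_exp avoids underflow when the log-weights are large negative numbers.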

Classes

Bernoulli(logits[, dtype]) Univariate Bernoulli distribution.
Categorical(logits[, dtype]) Univariate Categorical distribution.
Concrete(temperature, logits[, …]) The Concrete (or Gumbel-Softmax) distribution from (Maddison, 2016; Jang, 2016), which serves as a continuous relaxation of OnehotCategorical.
Discrete alias of tfsnippet.distributions.univariate.Categorical
Distribution Base class for probability distributions.
ExpConcrete(temperature, logits[, …]) The ExpConcrete distribution from (Maddison, 2016), obtained from Concrete by taking the logarithm.
FlowDistribution(distribution, flow) Transform a Distribution with a BaseFlow into a new distribution.
Normal(mean[, std, logstd, …]) Univariate Normal distribution.
OnehotCategorical(logits[, dtype]) One-hot multivariate Categorical distribution.
Uniform([minval, maxval, …]) Univariate Uniform distribution.
DefaultMetricFormatter Default training metric formatter.
EarlyStopping(param_vars[, initial_metric, …]) Early-stopping context object.
EarlyStoppingContext alias of tfsnippet.scaffold.early_stopping_.EarlyStopping
MetricFormatter Base class for a training metrics formatter.
MetricLogger([summary_writer, …]) Logger for the training metrics.
TrainLoop(param_vars[, var_groups, …]) Training loop object.
TrainLoopContext alias of tfsnippet.scaffold.train_loop_.TrainLoop
VariableSaver(variables, save_dir[, …]) Version controlled saving and restoring TensorFlow variables.
AnnealingVariable(name, initial_value, ratio) A non-trainable tf.Variable, whose value will be annealed as training proceeds.
BaseTrainer(loop) Base class for all trainers.
Evaluator(loop, metrics, inputs, data_flow) Class to compute evaluation metrics.
HookEntry(callback, freq, priority, birth) Configuration of a hook entry in HookList.
HookList() Class for managing hooks in BaseTrainer and Evaluator.
HookPriority Pre-defined hook priorities for BaseTrainer and Evaluator.
LossTrainer(**kwargs) A subclass of BaseTrainer, which optimizes a single loss.
ScheduledVariable(name, initial_value[, …]) A non-trainable tf.Variable, whose value might need to be changed as training proceeds.
Trainer(loop, train_op, inputs, data_flow[, …]) A subclass of BaseTrainer, executing a training operation per step.
Validator(**kwargs) Class to compute validation loss and other metrics.
VariationalChain(variational, model[, …]) Chain of the variational and model nets for variational inference.
VariationalEvaluation(vi) Factory for variational evaluation outputs.
VariationalInference(log_joint, latent_log_probs) Class for variational inference.
VariationalLowerBounds(vi) Factory for variational lower-bounds.
VariationalTrainingObjectives(vi) Factory for variational training objectives.
BayesianNet([observed]) Bayesian networks.
DataFlow Data flows are objects for constructing mini-batch iterators.
DataMapper Base class for all data mappers.
SlidingWindow(data_array, window_size) DataMapper for producing sliding windows according to indices.
Config() Base class for defining config values.
ConfigField(type[, default, description, …]) A config field.
VarScopeObject([name, scope]) Base class for objects that own a variable scope.
StochasticTensor(distribution, tensor[, …]) Samples or observations of a stochastic variable.
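SlidingWindow is described above as a DataMapper that produces sliding windows according to indices. The sketch below illustrates that idea in plain Python, assuming that each index is the start position of a window of length window_size. The class name and its num_windows method are hypothetical and are not tfsnippet's API; the real class works on arrays and integrates with DataFlow.

```python
from typing import Generic, List, Sequence, TypeVar

T = TypeVar("T")


class SlidingWindowSketch(Generic[T]):
    """Hypothetical stand-in illustrating index-to-window mapping."""

    def __init__(self, data_array: Sequence[T], window_size: int) -> None:
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        self.data_array = data_array
        self.window_size = window_size

    def num_windows(self) -> int:
        # Number of start indices whose window fits entirely inside data_array.
        return max(len(self.data_array) - self.window_size + 1, 0)

    def __call__(self, indices: Sequence[int]) -> List[List[T]]:
        # Map each start index i to the window data_array[i : i + window_size].
        windows = []
        for i in indices:
            if not 0 <= i < self.num_windows():
                raise IndexError(f"window starting at {i} does not fit")
            windows.append(list(self.data_array[i:i + self.window_size]))
        return windows
```

For example, SlidingWindowSketch([10, 11, 12, 13, 14], 3) has 3 valid windows, and calling it with the indices [0, 2] yields [[10, 11, 12], [12, 13, 14]]. Mapping indices to windows on demand means only the original array is stored, rather than one copy of every window, and a mini-batch of start indices can be shuffled and iterated like any other index array.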

Class Inheritance Diagram

Inheritance diagram of the classes listed above, which are defined in the submodules tfsnippet.distributions, tfsnippet.scaffold, tfsnippet.trainer, tfsnippet.variational, tfsnippet.bayes, tfsnippet.dataflows, tfsnippet.utils and tfsnippet.stochastic.