tfsnippet.utils

class tfsnippet.utils.Extractor(archive_file)

Bases: object

The base class for all archive extractors.

from tfsnippet.utils import Extractor, maybe_close

with Extractor.open('a.zip') as archive_file:
    for name, f in archive_file:
        with maybe_close(f):  # This file object may not be closeable,
                              # thus we surround it by ``maybe_close()``
            print('the content of {} is:'.format(name))
            print(f.read())
__init__(archive_file)

Initialize the base Extractor class.

Parameters:archive_file – The archive file object.
close()

Close the extractor.

iter_extract()

Extract files from the archive with an iterator.

You may simply iterate over an Extractor object, which is the same as calling this method.

Yields:(str, file-like) – Tuples of (name, file-like object), the filename and the corresponding file-like object for each file in the archive. The returned file-like object may or may not be closeable; you may surround it with maybe_close().

static open(file_path)

Create an Extractor instance for the given archive file.

Parameters:file_path (str) – The path of the archive file.
Returns:The specified extractor instance.
Return type:Extractor
Raises:IOError – If the file_path is not a supported archive.
class tfsnippet.utils.TarExtractor(fpath)

Bases: tfsnippet.utils.archive_file.Extractor

Extractor for “.tar”, “.tar.gz”, “.tgz”, “.tar.bz2”, “.tbz”, “.tbz2”, “.tb2”, “.tar.xz”, “.txz” files.

iter_extract()

Extract files from the archive with an iterator.

You may simply iterate over an Extractor object, which is the same as calling this method.

Yields:(str, file-like) – Tuples of (name, file-like object), the filename and the corresponding file-like object for each file in the archive. The returned file-like object may or may not be closeable; you may surround it with maybe_close().

class tfsnippet.utils.ZipExtractor(fpath)

Bases: tfsnippet.utils.archive_file.Extractor

Extractor for “.zip” files.

iter_extract()

Extract files from the archive with an iterator.

You may simply iterate over an Extractor object, which is the same as calling this method.

Yields:(str, file-like) – Tuples of (name, file-like object), the filename and the corresponding file-like object for each file in the archive. The returned file-like object may or may not be closeable; you may surround it with maybe_close().

class tfsnippet.utils.RarExtractor(fpath)

Bases: tfsnippet.utils.archive_file.Extractor

Extractor for “.rar” files.

iter_extract()

Extract files from the archive with an iterator.

You may simply iterate over an Extractor object, which is the same as calling this method.

Yields:(str, file-like) – Tuples of (name, file-like object), the filename and the corresponding file-like object for each file in the archive. The returned file-like object may or may not be closeable; you may surround it with maybe_close().

tfsnippet.utils.get_cache_root()

Get the cache root directory.

Returns:Path of the cache root directory.
Return type:str
tfsnippet.utils.set_cache_root(cache_root)

Set the root cache directory.

Parameters:cache_root (str) – The cache root directory. It will be normalized to absolute path.
class tfsnippet.utils.CacheDir(name, cache_root=None)

Bases: object

Class to manipulate a cache directory.

__init__(name, cache_root=None)

Construct a new CacheDir.

Parameters:
  • name (str) – The name of the sub-directory under cache_root.
  • cache_root (str or None) – The cache root directory. If not specified, use get_cache_root().
cache_root

Get the cache root directory.

download(uri, filename=None, show_progress=None, progress_file=sys.stderr)

Download a file into this CacheDir.

Parameters:
  • uri (str) – The URI to be retrieved.
  • filename (str) – The filename to use as the downloaded file. If filename already exists in this CacheDir, will not download uri. Default None, will automatically infer filename according to uri.
  • show_progress (bool) – Whether or not to show an interactive progress bar? If not specified, will show progress only if progress_file is sys.stdout or sys.stderr, and progress_file.isatty() is True.
  • progress_file – The file object where to write the progress. (default sys.stderr)
Returns:

The absolute path of the downloaded file.

Return type:

str

Raises:

ValueError – If filename cannot be inferred.

download_and_extract(uri, filename=None, extract_dir=None, show_progress=None, progress_file=sys.stderr)

Download a file into this CacheDir, and extract it.

Parameters:
  • uri (str) – The URI to be retrieved.
  • filename (str) – The filename to use as the downloaded file. If filename already exists in this CacheDir, will not download uri. Default None, will automatically infer filename according to uri.
  • extract_dir (str) – The name to use as the extracted directory. If extract_dir already exists in this CacheDir, will not extract archive_file. Default None, will automatically infer extract_dir according to filename.
  • show_progress (bool) – Whether or not to show an interactive progress bar? If not specified, will show progress only if progress_file is sys.stdout or sys.stderr, and progress_file.isatty() is True.
  • progress_file – The file object where to write the progress. (default sys.stderr)
Returns:

The absolute path of the extracted directory.

Return type:

str

Raises:

ValueError – If filename or extract_dir cannot be inferred.

extract_file(archive_file, extract_dir=None, show_progress=None, progress_file=sys.stderr)

Extract an archive file into this CacheDir.

Parameters:
  • archive_file (str) – The path of the archive file.
  • extract_dir (str) – The name to use as the extracted directory. If extract_dir already exists in this CacheDir, will not extract archive_file. Default None, will automatically infer extract_dir according to archive_file.
  • show_progress (bool) – Whether or not to show an interactive progress bar? If not specified, will show progress only if progress_file is sys.stdout or sys.stderr, and progress_file.isatty() is True.
  • progress_file – The file object where to write the progress. (default sys.stderr)
Returns:

The absolute path of the extracted directory.

Return type:

str

Raises:

ValueError – If extract_dir cannot be inferred.

name

Get the name of this cache directory under cache_root.

path

Get the absolute path of this cache directory.

purge_all()

Delete everything in this CacheDir.

resolve(sub_path)

Resolve a sub path relative to self.path.

Parameters:sub_path – The sub path to resolve.
Returns:The resolved absolute path of sub_path.
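
A minimal usage sketch; the URI and the file names below are purely illustrative, while the inference of the downloaded file name and the extracted directory name follows the descriptions above:

from tfsnippet.utils import CacheDir

cache_dir = CacheDir('my-dataset')   # sub-directory under get_cache_root()
# Download the archive (unless already cached) and extract it; both the
# downloaded file name and the extracted directory name are inferred.
data_path = cache_dir.download_and_extract('http://example.com/data.tar.gz')
print(data_path)                          # absolute path of the extracted directory
print(cache_dir.resolve('README.txt'))    # absolute path of a file under this cache dir
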
class tfsnippet.utils.AutoInitAndCloseable

Bases: object

Classes with init() to initialize their internal states, and close() to destroy these states. The init() method can be called repeatedly, but it performs the initialization only on the first call; other methods may therefore call init() at the beginning, which brings auto-initialization to the class.

A context manager is implemented: init() is called when entering the context, while close() is called when exiting the context.

__enter__()

Ensure the internal states are initialized.

__exit__(exc_type, exc_val, exc_tb)

Cleanup the internal states.

_close()

Override this method to destroy the internal states.

_init()

Override this method to initialize the internal states.

close()

Ensure the internal states are destroyed.

init()

Ensure the internal states are initialized.
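
A minimal sketch of a hypothetical subclass, showing how _init() and _close() pair with the lazy initialization described above:

from tfsnippet.utils import AutoInitAndCloseable

class LazyFile(AutoInitAndCloseable):
    """Hypothetical subclass that opens a file handle lazily."""

    def __init__(self, path):
        self._path = path
        self._handle = None

    def _init(self):
        # Runs only once, on the first call to `init()`.
        self._handle = open(self._path, 'rb')

    def _close(self):
        # Runs when `close()` is called or when the context exits.
        self._handle.close()
        self._handle = None

    def read_all(self):
        self.init()   # auto-initialize before use
        return self._handle.read()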

class tfsnippet.utils.Disposable

Bases: object

Classes which can only be used once.

_check_usage_and_set_used()

Check the usage flag, ensure the object has not been used, and then mark it as used.

class tfsnippet.utils.NoReentrantContext

Bases: object

Base class for contexts which are not reentrant (i.e., once a context has been opened by __enter__ and __exit__ has not yet been called, __enter__ cannot be called again).

_enter()

Enter the context. Subclasses should override this instead of the true __enter__ method.

_exit(exc_type, exc_val, exc_tb)

Exit the context. Subclasses should override this instead of the true __exit__ method.

_require_entered()

Require the context to be entered.

Raises:RuntimeError – If the context is not entered.
class tfsnippet.utils.DisposableContext

Bases: tfsnippet.utils.concepts.NoReentrantContext

Base class for contexts which can only be entered once.
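
A brief sketch of a hypothetical subclass; only _enter() and _exit() are overridden, while the non-reentrance check is inherited from the base class:

import time

from tfsnippet.utils import NoReentrantContext

class Timer(NoReentrantContext):
    """Hypothetical context that measures the elapsed wall-clock time."""

    def _enter(self):
        self._start = time.time()
        return self

    def _exit(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.time() - self._start

timer = Timer()
with timer:
    time.sleep(0.1)   # re-entering `timer` while it is open is rejected by the base class
print(timer.elapsed)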

tfsnippet.utils.minibatch_slices_iterator(length, batch_size, skip_incomplete=False)

Iterate through all the mini-batch slices.

Parameters:
  • length (int) – Total length of data in an epoch.
  • batch_size (int) – Size of each mini-batch.
  • skip_incomplete (bool) – If True, discard the final batch if it contains fewer than batch_size items. (default False)
Yields:slice – Slices of each mini-batch. The last mini-batch may contain fewer indices than batch_size.
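
For example, with 5 items and a batch size of 2, three slices are yielded, the last covering only the remaining item (it would be dropped with skip_incomplete=True):

from tfsnippet.utils import minibatch_slices_iterator

for s in minibatch_slices_iterator(length=5, batch_size=2):
    print(s)   # slices covering indices 0-1, 2-3 and 4
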
tfsnippet.utils.split_numpy_arrays(arrays, portion=None, size=None, shuffle=True, random_state=None)

Split numpy arrays into two halves, by portion or by size.

Parameters:
  • arrays (Iterable[np.ndarray]) – Numpy arrays to be split.
  • portion (float) – Portion of the second half. Ignored if size is specified.
  • size (int) – Size of the second half.
  • shuffle (bool) – Whether or not to shuffle before splitting?
  • random_state (RandomState) – Optional numpy RandomState for shuffling data. (default None, use the global RandomState).
Returns:

The two split halves of arrays.

Return type:

(tuple[np.ndarray], tuple[np.ndarray])

tfsnippet.utils.split_numpy_array(array, portion=None, size=None, shuffle=True)

Split numpy array into two halves, by portion or by size.

Parameters:
  • array (np.ndarray) – A numpy array to be split.
  • portion (float) – Portion of the second half. Ignored if size is specified.
  • size (int) – Size of the second half.
  • shuffle (bool) – Whether or not to shuffle before splitting?
Returns:

The two split halves of the array.

Return type:

tuple[np.ndarray]
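
A small sketch of splitting one array; the order of the returned tuple (remaining items first, the portion-sized split second) is an assumption here, not confirmed by the signature above:

import numpy as np

from tfsnippet.utils import split_numpy_array

x = np.arange(10)
rest, held_out = split_numpy_array(x, portion=0.2, shuffle=False)
# `held_out` is assumed to be the "second half" holding ~20% of the items,
# and `rest` the remaining ~80%.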

tfsnippet.utils.DocInherit(kclass)

Class decorator to enable kclass and all its sub-classes to automatically inherit docstrings from base classes.

Usage:

import six


@DocInherit
class Parent(object):
    """Docstring of the parent class."""

    def some_method(self):
        """Docstring of the method."""
        ...

class Child(Parent):
    # inherits the docstring of :meth:`Parent`

    def some_method(self):
        # inherits the docstring of :meth:`Parent.some_method`
        ...
Parameters:kclass (Type) – The class to decorate.
Returns:The decorated class.
class tfsnippet.utils.TemporaryDirectory(suffix=None, prefix=None, dir=None)

Bases: object

Create and return a temporary directory. This has the same behavior as mkdtemp but can be used as a context manager. For example:

with TemporaryDirectory() as tmpdir:
    ...

Upon exiting the context, the directory and everything contained in it are removed.

cleanup()
tfsnippet.utils.makedirs(name, mode=511, exist_ok=False)
tfsnippet.utils.humanize_duration(seconds, short_units=True)

Format specified time duration as human readable text.

Parameters:
  • seconds – Number of seconds of the time duration.
  • short_units (bool) – Whether or not to use short units (“d”, “h”, “m”, “s”) instead of long units (“day”, “hour”, “minute”, “second”)? (default True)
Returns:

The formatted time duration.

Return type:

str

tfsnippet.utils.camel_to_underscore(name)

Convert a camel-case name to underscore.

tfsnippet.utils.get_valid_scope_name(name, cls_or_instance=None)

Generate a valid scope name for the given method.

Parameters:
  • name (str) – The base name.
  • cls_or_instance – The class or the instance object, optional.
Returns:

The generated scope name.

Return type:

str

tfsnippet.utils.maybe_close(*args, **kwds)

Enter a context, and if obj has a .close() method, close it when exiting the context.

Parameters:obj – The object maybe to close.
Yields:The specified obj.
tfsnippet.utils.iter_files(root_dir, sep='/')

Iterate through all files in root_dir, returning the relative paths of each file. The sub-directories will not be yielded.

Parameters:
  • root_dir (str) – The root directory, from which to iterate.
  • sep (str) – The separator for the relative paths.
Yields:str – The relative paths of each file.
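
A short example; the directory path is illustrative, and only files (never sub-directories) are yielded as relative paths:

from tfsnippet.utils import iter_files

for rel_path in iter_files('/tmp/my-data'):   # illustrative directory
    print(rel_path)   # e.g. 'train/0001.png', joined by `sep`
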
class tfsnippet.utils.ETA(take_initial_snapshot=True)

Bases: object

Class to help compute the Estimated Time Ahead (ETA).

__init__(take_initial_snapshot=True)

Construct a new ETA.

Parameters:take_initial_snapshot (bool) – Whether or not to take the initial snapshot (0., time.time())? (default True)
get_eta(progress, now=None, take_snapshot=True)

Get the Estimated Time Ahead (ETA).

Parameters:
  • progress – The current progress, range in [0, 1].
  • now – The current timestamp in seconds. If not specified, use time.time().
  • take_snapshot (bool) – Whether or not to take a snapshot of the specified (progress, now)? (default True)
Returns:

The remaining seconds, or None if the ETA cannot be estimated.

Return type:

float or None

take_snapshot(progress, now=None)

Take a snapshot of (progress, now), for later computing ETA.

Parameters:
  • progress – The current progress, range in [0, 1].
  • now – The current timestamp in seconds. If not specified, use time.time().
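
A minimal sketch of tracking progress over a loop; the sleep below stands in for a real workload, and progress is the completed fraction in [0, 1]:

import time

from tfsnippet.utils import ETA

eta = ETA()   # takes the initial (0., time.time()) snapshot
for step in range(1, 101):
    time.sleep(0.05)   # stand-in for the real workload
    remaining = eta.get_eta(progress=step / 100.0)
    if remaining is not None:
        print('about {:.1f} seconds remaining'.format(remaining))
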
tfsnippet.utils.auto_reuse_variables(*args, **kwds)

Open a variable scope as a context, automatically choosing reuse flag.

The reuse flag will be set to False if the variable scope is opened for the first time, and it will be set to True each time the variable scope is opened again.

Parameters:
  • name_or_scope (str or tf.VariableScope) – The name of the variable scope, or the variable scope to open.
  • reopen_name_scope (bool) – Whether or not to re-open the original name scope of name_or_scope? This option takes effect only if name_or_scope is actually an instance of tf.VariableScope.
  • **kwargs – Named arguments for opening the variable scope.
Yields:

tf.VariableScope – The opened variable scope.
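
For example, opening the same scope twice returns the same variable, because the reuse flag is turned on automatically at the second opening:

import tensorflow as tf

from tfsnippet.utils import auto_reuse_variables

with auto_reuse_variables('my_scope'):
    a = tf.get_variable('w', shape=[2, 3])   # first opening: reuse=False, `w` is created

with auto_reuse_variables('my_scope'):
    b = tf.get_variable('w', shape=[2, 3])   # second opening: reuse=True, `w` is reused

assert a is b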

tfsnippet.utils.instance_reuse(method=None, scope=None)

Decorate an instance method within auto_reuse_variables() context.

This decorator should be applied to unbound instance methods, and the instance that owns the methods should have a variable_scope attribute. For example:

class Foo(object):

    def __init__(self, name):
        with tf.variable_scope(name) as vs:
            self.variable_scope = vs

    @instance_reuse
    def foo(self):
        return tf.get_variable('bar', ...)

The above example is then equivalent to the following code:

class Foo(object):

    def __init__(self, name):
        with tf.variable_scope(name) as vs:
            self.variable_scope = vs

    def foo(self):
        with reopen_variable_scope(self.variable_scope):
            with auto_reuse_variables('foo'):
                return tf.get_variable('bar', ...)

By default the name of the variable scope should be equal to the name of the decorated method, and the name scope within the context should be equal to the variable scope name, plus some suffix to make it unique. The variable scope name can be set by scope argument, for example:

class Foo(object):

    @instance_reuse(scope='scope_name')
    def foo(self):
        return tf.get_variable('bar', ...)

Note that the variable reusing is based on the name of the variable scope, rather than the method. As a result, two methods with the same scope argument will reuse the same set of variables. For example:

class Foo(object):

    @instance_reuse(scope='foo')
    def foo_1(self):
        return tf.get_variable('bar', ...)

    @instance_reuse(scope='foo')
    def foo_2(self):
        return tf.get_variable('bar', ...)

These two methods will return the same bar variable.

Parameters:scope (str) – The name of the variable scope. If not set, will use the method name as the scope name. This argument must be specified as a named argument.
tfsnippet.utils.global_reuse(method=None, scope=None)

Decorate a function within auto_reuse_variables() scope globally.

Any function or method applied with this decorator will be called within a variable scope opened first by root_variable_scope(), then by auto_reuse_variables(). That is to say, the following code:

@global_reuse
def foo():
    return tf.get_variable('bar', ...)

bar = foo()

is equivalent to:

with root_variable_scope():
    with auto_reuse_variables('foo'):
        bar = tf.get_variable('bar', ...)

By default the name of the variable scope should be equal to the name of the decorated method, and the name scope within the context should be equal to the variable scope name, plus some suffix to make it unique. The variable scope name can be set by scope argument, for example:

@global_reuse(scope='dense')
def dense_layer(inputs):
    w = tf.get_variable('w', ...)
    b = tf.get_variable('b', ...)
    return tf.matmul(w, inputs) + b

Note that the variable reusing is based on the name of the variable scope, rather than the function object. As a result, two functions with the same name, or with the same scope argument, will reuse the same set of variables. For example:

@global_reuse(scope='foo')
def foo_1():
    return tf.get_variable('bar', ...)

@global_reuse(scope='foo')
def foo_2():
    return tf.get_variable('bar', ...)

These two functions will return the same bar variable.

Parameters:scope (str) – The name of the variable scope. If not set, will use the function name as the scope name. This argument must be specified as a named argument.
tfsnippet.utils.reopen_variable_scope(*args, **kwds)

Reopen the specified var_scope and its original name scope.

Unlike tf.variable_scope(), which does not open the original name scope even if a stored tf.VariableScope instance is specified, this method opens exactly the same name scope as the original one.

Parameters:
  • var_scope (tf.VariableScope) – The variable scope instance.
  • **kwargs – Named arguments for opening the variable scope.
tfsnippet.utils.root_variable_scope(*args, **kwds)

Open the root variable scope and its name scope.

Parameters:**kwargs – Named arguments for opening the root variable scope.
class tfsnippet.utils.VarScopeObject(name=None, scope=None)

Bases: object

Base class for object that owns a variable scope. It is typically used along with instance_reuse().

__init__(name=None, scope=None)

Construct the VarScopeObject.

Parameters:
  • name (str) – Name of this object. A unique variable scope name will be picked according to this argument if scope is not specified. If neither this argument nor scope is specified, the underscored class name will be used as name. This argument is stored and can be accessed via the name attribute of the instance. If not specified, name will be None.
  • scope (str) – Scope of this object. If specified, it will be used as the variable scope name, even if another object has already taken the same scope. That is to say, these two objects will share the same variable scope.
__repr__() <==> repr(x)
get_variables_as_dict(sub_scope=None, collection='variables', strip_sub_scope=True)

Get the variables created inside this VarScopeObject.

This method will collect variables from specified collection, which are created in the variable_scope of this object (or in the sub_scope of variable_scope, if sub_scope is not None).

Parameters:
  • sub_scope (str) – The sub-scope of variable_scope.
  • collection (str) – The collection from which to collect variables. (default tf.GraphKeys.GLOBAL_VARIABLES).
  • strip_sub_scope (bool) – Whether or not to also strip the common prefix of sub_scope? (default True)
Returns:

Dict which maps from the relative names of variables to variable objects. By relative names we mean the full names of variables, without the common prefix of variable_scope (and sub_scope if strip_sub_scope is True).

Return type:

dict[str, tf.Variable]

name

Get the name of this object.

variable_scope

Get the variable scope of this object.
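
A brief sketch pairing VarScopeObject with instance_reuse(), as suggested above; the layer name and sizes are arbitrary:

import tensorflow as tf

from tfsnippet.utils import VarScopeObject, instance_reuse

class DenseLayer(VarScopeObject):
    """Hypothetical layer that owns its own variable scope."""

    @instance_reuse
    def output(self, inputs):
        w = tf.get_variable('w', shape=[16, 8])
        b = tf.get_variable('b', shape=[8])
        return tf.matmul(inputs, w) + b

layer = DenseLayer(name='dense')
x = tf.placeholder(tf.float32, shape=[None, 16])
y1 = layer.output(x)
y2 = layer.output(x)   # reuses the same 'w' and 'b' variables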

tfsnippet.utils.create_session(lock_memory=True, log_device_placement=False, allow_soft_placement=True, **kwargs)

A convenient method to create a TensorFlow session.

Parameters:
  • lock_memory (True or False or float) –
    • If True, lock all free memory.
    • If False, set allow_growth to True, i.e., do not lock all free memory.
    • If float, lock this portion of memory.

    (default True)

  • log_device_placement (bool) – Whether to log the placement of graph nodes. (default False)
  • allow_soft_placement (bool) – Whether or not to allow soft placement? (default True)
  • **kwargs – Other named parameters to be passed to tf.ConfigProto.
Returns:

The TensorFlow session.

Return type:

tf.Session
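
For example, a session that does not pre-allocate all GPU memory (per the description above, lock_memory=False sets allow_growth to True):

import tensorflow as tf

from tfsnippet.utils import create_session

session = create_session(lock_memory=False)
with session.as_default():
    print(session.run(tf.constant(42)))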

tfsnippet.utils.get_default_session_or_error()

Get the default session.

Returns:The default session.
Return type:tf.Session
Raises:RuntimeError – If there’s no active session.
tfsnippet.utils.get_variables_as_dict(scope=None, collection='variables')

Get TensorFlow variables as dict.

Parameters:
  • scope (str or tf.VariableScope or None) – If None, will collect all the variables within current graph. If a str or a tf.VariableScope, will collect the variables only from this scope. (default None)
  • collection (str) – Collect the variables only from this collection. (default tf.GraphKeys.GLOBAL_VARIABLES)
Returns:

Dict which maps from names to TensorFlow variables. The names will be the full names of variables if scope is not specified, or the relative names within the scope otherwise. By relative names we mean the variable names without the common scope name prefix.

Return type:

dict[str, tf.Variable]

class tfsnippet.utils.VariableSaver(variables, save_dir, max_versions=2, filename='variables.dat', latest_file='latest', save_meta=True, name=None, scope=None)

Bases: tfsnippet.utils.scope.VarScopeObject

Version-controlled saving and restoring of TensorFlow variables.

__init__(variables, save_dir, max_versions=2, filename='variables.dat', latest_file='latest', save_meta=True, name=None, scope=None)

Construct the VariableSaver.

Parameters:
  • variables (collections.Iterable[tf.Variable] or dict[str, any]) – List of variables, or dict of variables with explicit keys, which should be saved and restored.
  • save_dir (str) – Directory where to place the saved variables.
  • max_versions (int) – Maximum versions to keep in the directory (Default is 2). At least 2 versions should be kept, in order to prevent corrupted checkpoint files caused by IO failure.
  • filename (str) – Name of the files of variable values (default is variables.dat).
  • latest_file (str) – Name of the file which organizes the checkpoint versions (default is latest).
  • save_meta (bool) – Whether or not to save meta graph (default is True).
  • name (str) – Optional name of this VariableSaver (argument of VarScopeObject).
  • scope (str) – Optional scope of this VariableSaver (argument of VarScopeObject).
get_latest_file()

Get the latest available checkpoint file.

restore(ignore_non_exist=False)

Restore the checkpoint from file if it exists.

Parameters:ignore_non_exist (bool) – Whether or not to ignore error if the checkpoint file does not exist? (default False)
Raises:IOError – If the checkpoint files do not exist, and ignore_non_exist is not True.
save(global_step=None)

Save the checkpoint to file.

Parameters:global_step (int or tf.Tensor) – The global step counter.
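
A small sketch of saving and restoring variables; the checkpoint directory is illustrative, and save() / restore() are assumed to operate on the current default session:

import tensorflow as tf

from tfsnippet.utils import VariableSaver, get_variables_as_dict

w = tf.get_variable('w', shape=[3])
saver = VariableSaver(get_variables_as_dict(), save_dir='/tmp/ckpt')   # illustrative path

sess = tf.Session()
with sess.as_default():
    sess.run(tf.global_variables_initializer())
    saver.save()                          # writes a new checkpoint version
    saver.restore(ignore_non_exist=True)  # restores the latest version if it exists
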
tfsnippet.utils.get_uninitialized_variables(variables=None, name=None)

Get uninitialized variables as a list.

Parameters:
  • variables (list[tf.Variable]) – Collect only uninitialized variables within this list. If not specified, will collect all uninitialized variables within tf.GraphKeys.GLOBAL_VARIABLES collection.
  • name (str) – Name of this operation in TensorFlow graph.
Returns:

Uninitialized variables.

Return type:

list[tf.Variable]

tfsnippet.utils.ensure_variables_initialized(variables=None, name=None)

Ensure variables are initialized.

Parameters:
  • variables (list[tf.Variable] or dict[str, tf.Variable]) – Ensure only the variables within this collection to be initialized. If not specified, will ensure all variables within the collection tf.GraphKeys.GLOBAL_VARIABLES to be initialized.
  • name (str) – Name of this operation in TensorFlow graph. (default ensure_variables_initialized)
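
For example, assuming both helpers operate on the current default session, one can initialize only the variables that are still uninitialized:

import tensorflow as tf

from tfsnippet.utils import (ensure_variables_initialized,
                             get_uninitialized_variables)

a = tf.get_variable('a', shape=[])
b = tf.get_variable('b', shape=[])

sess = tf.Session()
with sess.as_default():
    print(get_uninitialized_variables())   # both `a` and `b` before initialization
    ensure_variables_initialized()         # initializes whatever is uninitialized
    print(get_uninitialized_variables())   # empty list afterwards
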
tfsnippet.utils.int_shape(tensor)

Get the int shape tuple of the specified tensor.

Parameters:tensor – The tensor object.
Returns:The int shape tuple, or None if the tensor shape is None.
Return type:tuple[int or None] or None
tfsnippet.utils.flatten(x, k, name=None)

Flatten the front dimensions of x, such that the resulting tensor will have at most k dimensions.

Parameters:
  • x (Tensor) – The tensor to be flattened.
  • k (int) – The maximum number of dimensions for the resulting tensor.
  • name (str or None) – Name of this operation.
Returns:

(The flattened tensor, the static front shape, and the front shape), or (the original tensor, None, None)

Return type:

(tf.Tensor, tuple[int or None], tuple[int] or tf.Tensor) or (tf.Tensor, None, None)

tfsnippet.utils.unflatten(x, static_front_shape, front_shape, name=None)

The inverse transformation of flatten().

If both static_front_shape is None and front_shape is None, x will be returned without any change.

Parameters:
  • x (Tensor) – The tensor to be unflattened.
  • static_front_shape (tuple[int or None] or None) – The static front shape.
  • front_shape (tuple[int] or tf.Tensor or None) – The front shape.
  • name (str or None) – Name of this operation.
Returns:

The unflattened x.

Return type:

tf.Tensor
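
A round-trip sketch: the front dimensions of a [batch, time, features] tensor are flattened down to two dimensions and then restored; the shapes are arbitrary:

import tensorflow as tf

from tfsnippet.utils import flatten, unflatten

x = tf.placeholder(tf.float32, shape=[None, 10, 3])     # [batch, time, features]
x_flat, static_front, front = flatten(x, 2)             # shape becomes [batch * 10, 3]
y = unflatten(x_flat, static_front, front)              # back to [batch, 10, 3]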

tfsnippet.utils.get_batch_size(tensor, axis=0, name=None)

Infer the mini-batch size according to tensor.

Parameters:
  • tensor (tf.Tensor) – The input placeholder.
  • axis (int) – The axis of mini-batches. Default is 0.
  • name (str or None) – Name of this operation.
Returns:

The batch size.

Return type:

int or tf.Tensor

class tfsnippet.utils.StatisticsCollector(shape=())

Bases: object

Computing \(\mathrm{E}[X]\) and \(\operatorname{Var}[X]\) online.

__init__(shape=())

Construct the StatisticsCollector.

Parameters:shape – Shape of the values. The statistics will be collected per element of the values. (default is ()).
collect(values, weight=1.0)

Update the statistics from values.

This method uses the following equation to update mean and square:

\[\frac{\sum_{i=1}^n w_i f(x_i)}{\sum_{j=1}^n w_j} = \frac{\sum_{i=1}^m w_i f(x_i)}{\sum_{j=1}^m w_j} + \frac{\sum_{i=m+1}^n w_i}{\sum_{j=1}^n w_j} \Bigg( \frac{\sum_{i=m+1}^n w_i f(x_i)}{\sum_{j=m+1}^n w_j} - \frac{\sum_{i=1}^m w_i f(x_i)}{\sum_{j=1}^m w_j} \Bigg)\]
Parameters:
  • values – Values to be collected in batch, numpy array or scalar whose shape ends with self.shape. The leading shape in front of self.shape is regarded as the batch shape.
  • weight – Weights of the values, should be broadcastable against the batch shape. (default is 1)
Raises:

ValueError – If the shape of values does not end with self.shape.

counter

Get the counter of collected values.

has_value

Whether or not any value has been collected?

mean

Get the mean of the values, i.e., \(\mathrm{E}[X]\).

reset()

Reset the collector to initial state.

shape

Get the shape of the values.

square

Get \(\mathrm{E}[X^2]\) of the values.

stddev

Get the std of the values, i.e., \(\sqrt{\operatorname{Var}[X]}\).

var

Get the variance of the values, i.e., \(\operatorname{Var}[X]\).

weight_sum

Get the weight summation.
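
An example of collecting scalar values online and querying the running statistics:

import numpy as np

from tfsnippet.utils import StatisticsCollector

collector = StatisticsCollector()             # scalar statistics, shape == ()
collector.collect(np.array([1., 2., 3.]))     # a batch of three values
collector.collect(4., weight=2.)              # a single value with weight 2

print(collector.mean, collector.var, collector.stddev)
collector.reset()                             # discard everything collected so far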

class tfsnippet.utils.TensorWrapper

Bases: object

Tensor-like object that wraps a tf.Tensor instance.

This class is typically used to implement super-tensor classes, adding auxiliary methods to a tf.Tensor. The derived classes should call register_tensor_wrapper_class() to register themselves into the TensorFlow type system.

Access to any undefined attributes, properties and methods will be transparently proxied to the wrapped tensor. Also, TensorWrapper can be directly used in mathematical expressions and most TensorFlow arithmetic functions. For example, TensorWrapper(...) + tf.exp(TensorWrapper(...)).

On the other hand, TensorWrapper is neither tf.Tensor nor a sub-class of tf.Tensor, i.e., isinstance(TensorWrapper(...), tf.Tensor) == False. This is essential for sub-classes of TensorWrapper to be converted correctly to tf.Tensor by tf.convert_to_tensor(), using the official type conversion system of TensorFlow.

All the attributes defined in sub-classes of TensorWrapper must have names starting with _self_. The properties and methods are not restricted by this rule.

An example of inheriting TensorWrapper is shown as follows:

class MyTensorWrapper(TensorWrapper):

    def __init__(self, wrapped, flag):
        super(MyTensorWrapper, self).__init__()
        self._self_wrapped = wrapped
        self._self_flag = flag

    @property
    def tensor(self):
        return self._self_wrapped

    @property
    def flag(self):
        return self._self_flag

register_tensor_wrapper_class(MyTensorWrapper)

# tests
t = MyTensorWrapper(tf.constant(0., dtype=tf.float32), flag=123)
assert(t.dtype == tf.float32)
assert(t.flag == 123)
tensor

Get the wrapped tf.Tensor. Derived classes must override this to return the actual wrapped tensor.

Returns:The wrapped tensor.
Return type:tf.Tensor
tfsnippet.utils.register_tensor_wrapper_class(cls)

Register a sub-class of TensorWrapper into TensorFlow type system.

Parameters:cls – The subclass of TensorWrapper to be registered.
tfsnippet.utils.is_tensorflow_version_higher_or_equal(version)

Check whether the version of TensorFlow is higher than or equal to version.

Parameters:version (str) – Expected version of TensorFlow.
Returns:True if higher or equal to, False if not.
Return type:bool
tfsnippet.utils.is_integer(x)

Test whether or not x is a Python or NumPy integer.

Parameters:x – The object to be tested.
Returns:A boolean indicating whether x is a Python or NumPy integer.
Return type:bool
tfsnippet.utils.is_float(x)

Test whether or not x is a Python or NumPy float.

Parameters:x – The object to be tested.
Returns:A boolean indicating whether x is a Python or NumPy float.
Return type:bool
tfsnippet.utils.is_tensor_object(x)

Test whether or not x is a tensor object.

tf.Tensor, tf.Variable, TensorWrapper and zhusuan.StochasticTensor are considered to be tensor objects.

Parameters:x – The object to be tested.
Returns:A boolean indicating whether x is a tensor object.
Return type:bool
class tfsnippet.utils.TensorArgValidator(name)

Bases: object

Class to validate argument values of tensors.

__init__(name)

Construct the TensorArgValidator.

Parameters:name (str) – Name of the argument to be validated, used in error messages.
require_int32(value)

Require value to be a 32-bit integer.

Parameters:value – Value to be validated. If is_tensor_object(value) == True, it will be cast to a tf.Tensor with dtype as tf.int32. If otherwise is_integer(value) == True, the type will not be cast, but its value will be checked to ensure it falls within -2**31 ~ 2**31-1.
Returns:The validated value.
Raises:TypeError – If the specified value cannot be cast to int32, or the value is out of range.
require_non_negative(value)

Require value to be non-negative, i.e., value >= 0.

Parameters:value – Value to be validated. If is_tensor_object(value) == True, additional assertion will be added to value. Otherwise it will be validated against value >= 0 immediately.
Returns:The validated value.
Raises:ValueError – If specified value is not non-negative.
require_positive(value)

Require value to be positive, i.e., value > 0.

Parameters:value – Value to be validated. If is_tensor_object(value) == True, additional assertion will be added to value. Otherwise it will be validated against value > 0 immediately.
Returns:The validated value.
Raises:ValueError – If the specified value is not positive.
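
A short example of validating plain Python arguments; tensor arguments would instead receive casts and runtime assertions as described above:

from tfsnippet.utils import TensorArgValidator

v = TensorArgValidator('batch_size')
batch_size = v.require_positive(v.require_int32(64))   # passes, returns 64
# v.require_int32(2 ** 40) would raise TypeError (out of the int32 range),
# v.require_positive(0) would raise ValueError.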