tfsnippet.utils
-
class tfsnippet.utils.Extractor(archive_file)
Bases: object

The base class for all archive extractors. For example:

    from tfsnippet.utils import Extractor, maybe_close

    with Extractor.open('a.zip') as archive_file:
        for name, f in archive_file:
            with maybe_close(f):
                # This file object may not be closeable,
                # thus we surround it by ``maybe_close()``
                print('the content of {} is:'.format(name))
                print(f.read())

-
__init__(archive_file)
Initialize the base Extractor class.
Parameters: archive_file – The archive file object.
-
close()
Close the extractor.
-
iter_extract()
Extract files from the archive with an iterator.
You may simply iterate over an Extractor object, which is the same as calling this method.
Yields: (str, file-like) – Tuples of (name, file-like object), the filename and corresponding file-like object for each file in the archive. The returned file-like object may or may not be closeable, so you may surround it by maybe_close().
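To make the iter_extract() contract concrete, here is a dependency-free sketch of what iterating over a ZIP archive might yield, built on Python's standard zipfile module. The helper iter_zip_members() is hypothetical and is not tfsnippet's ZipExtractor; it only mirrors the documented (name, file-like object) tuples.

```python
import io
import zipfile


def iter_zip_members(zip_bytes):
    """Yield (name, file-like) pairs for each regular file in a ZIP archive.

    Hypothetical helper; it only mirrors the documented iter_extract() contract.
    """
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        for info in zf.infolist():
            if info.is_dir():  # directories carry no content, skip them
                continue
            with zf.open(info) as f:
                yield info.filename, f


# Build a small archive in memory to iterate over.
buf = io.BytesIO()
with zipfile.ZipFile(buf, 'w') as zf:
    zf.writestr('a.txt', 'hello')
    zf.writestr('sub/b.txt', 'world')

contents = {name: f.read() for name, f in iter_zip_members(buf.getvalue())}
```

Each file object is only valid until the generator advances, which is why the example reads it immediately.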
-
-
class tfsnippet.utils.TarExtractor(fpath)
Bases: tfsnippet.utils.archive_file.Extractor

Extractor for “.tar”, “.tar.gz”, “.tgz”, “.tar.bz2”, “.tbz”, “.tbz2”, “.tb2”, “.tar.xz”, “.txz” files.
-
iter_extract()
Extract files from the archive with an iterator.
You may simply iterate over an Extractor object, which is the same as calling this method.
Yields: (str, file-like) – Tuples of (name, file-like object), the filename and corresponding file-like object for each file in the archive. The returned file-like object may or may not be closeable, so you may surround it by maybe_close().
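For tar-family archives, the standard tarfile module can open every compressed variant listed above with mode 'r:*'. The sketch below skips non-regular members and yields the file object returned by TarFile.extractfile(); the helper iter_tar_members() is hypothetical and is not TarExtractor's implementation.

```python
import io
import tarfile


def iter_tar_members(tar_bytes):
    """Yield (name, file-like) pairs for each regular file in a tar archive.

    Hypothetical helper; it only mirrors the documented iter_extract() contract.
    """
    # 'r:*' lets tarfile detect gzip, bz2 or xz compression transparently
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode='r:*') as archive:
        for member in archive:
            if not member.isfile():  # skip directories, links, devices, ...
                continue
            yield member.name, archive.extractfile(member)


# Build a small .tar.gz archive in memory.
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode='w:gz') as archive:
    data = b'hello'
    info = tarfile.TarInfo('a.txt')
    info.size = len(data)
    archive.addfile(info, io.BytesIO(data))

contents = {name: f.read() for name, f in iter_tar_members(buf.getvalue())}
```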
-
-
class tfsnippet.utils.ZipExtractor(fpath)
Bases: tfsnippet.utils.archive_file.Extractor

Extractor for “.zip” files.
-
iter_extract()
Extract files from the archive with an iterator.
You may simply iterate over an Extractor object, which is the same as calling this method.
Yields: (str, file-like) – Tuples of (name, file-like object), the filename and corresponding file-like object for each file in the archive. The returned file-like object may or may not be closeable, so you may surround it by maybe_close().
-
-
class tfsnippet.utils.RarExtractor(fpath)
Bases: tfsnippet.utils.archive_file.Extractor

Extractor for “.rar” files.
-
iter_extract()
Extract files from the archive with an iterator.
You may simply iterate over an Extractor object, which is the same as calling this method.
Yields: (str, file-like) – Tuples of (name, file-like object), the filename and corresponding file-like object for each file in the archive. The returned file-like object may or may not be closeable, so you may surround it by maybe_close().
-
-
tfsnippet.utils.get_cache_root()
Get the cache root directory.
Returns: Path of the cache root directory.
Return type: str
-
tfsnippet.utils.set_cache_root(cache_root)
Set the root cache directory.
Parameters: cache_root (str) – The cache root directory. It will be normalized to an absolute path.
-
class tfsnippet.utils.CacheDir(name, cache_root=None)
Bases: object

Class to manipulate a cache directory.
-
cache_root
Get the cache root directory.
-
download(uri, filename=None, show_progress=None, progress_file=sys.stderr)
Download a file into this CacheDir.
Parameters:
- uri (str) – The URI to be retrieved.
- filename (str) – The filename to use as the downloaded file. If filename already exists in this CacheDir, uri will not be downloaded. If None (the default), filename will be inferred from uri.
- show_progress (bool) – Whether or not to show an interactive progress bar? If not specified, progress will be shown only if progress_file is sys.stdout or sys.stderr, and progress_file.isatty() is True.
- progress_file – The file object where to write the progress. (default sys.stderr)
Returns: The absolute path of the downloaded file.
Return type: str
Raises: ValueError – If filename cannot be inferred.
-
download_and_extract(uri, filename=None, extract_dir=None, show_progress=None, progress_file=sys.stderr)
Download a file into this CacheDir, and extract it.
Parameters:
- uri (str) – The URI to be retrieved.
- filename (str) – The filename to use as the downloaded file. If filename already exists in this CacheDir, uri will not be downloaded. If None (the default), filename will be inferred from uri.
- extract_dir (str) – The name to use as the extracted directory. If extract_dir already exists in this CacheDir, the downloaded archive will not be extracted. If None (the default), extract_dir will be inferred from filename.
- show_progress (bool) – Whether or not to show an interactive progress bar? If not specified, progress will be shown only if progress_file is sys.stdout or sys.stderr, and progress_file.isatty() is True.
- progress_file – The file object where to write the progress. (default sys.stderr)
Returns: The absolute path of the extracted directory.
Return type: str
Raises: ValueError – If filename or extract_dir cannot be inferred.
-
extract_file(archive_file, extract_dir=None, show_progress=None, progress_file=sys.stderr)
Extract an archive file into this CacheDir.
Parameters:
- archive_file (str) – The path of the archive file.
- extract_dir (str) – The name to use as the extracted directory. If extract_dir already exists in this CacheDir, archive_file will not be extracted. If None (the default), extract_dir will be inferred from archive_file.
- show_progress (bool) – Whether or not to show an interactive progress bar? If not specified, progress will be shown only if progress_file is sys.stdout or sys.stderr, and progress_file.isatty() is True.
- progress_file – The file object where to write the progress. (default sys.stderr)
Returns: The absolute path of the extracted directory.
Return type: str
Raises: ValueError – If extract_dir cannot be inferred.
-
name
Get the name of this cache directory under cache_root.
-
path
Get the absolute path of this cache directory.
-
resolve(sub_path)
Resolve a sub path relative to self.path.
Parameters: sub_path – The sub path to resolve.
Returns: The resolved absolute path of sub_path.
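download(), download_and_extract() and extract_file() all fall back to inferring filename or extract_dir when these are not given, raising ValueError on failure. The inference rules are not spelled out here, so the following is only a sketch under assumed rules: the filename is the last path component of the URI, and the extracted directory name is the filename with one of the archive suffixes listed for TarExtractor, ZipExtractor and RarExtractor stripped. infer_filename() and infer_extract_dir() are hypothetical helpers, not part of CacheDir.

```python
import os
from urllib.parse import urlparse

# Suffixes listed for TarExtractor, ZipExtractor and RarExtractor above,
# longest first so that ".tar.gz" is tried before ".tar".
ARCHIVE_SUFFIXES = sorted(
    ['.tar', '.tar.gz', '.tgz', '.tar.bz2', '.tbz', '.tbz2', '.tb2',
     '.tar.xz', '.txz', '.zip', '.rar'],
    key=len, reverse=True)


def infer_filename(uri):
    """Hypothetical rule: use the last path component of the URI."""
    name = os.path.basename(urlparse(uri).path)
    if not name:
        raise ValueError('`filename` cannot be inferred: {!r}'.format(uri))
    return name


def infer_extract_dir(filename):
    """Hypothetical rule: strip a known archive suffix (case-insensitively)."""
    lower = filename.lower()
    for suffix in ARCHIVE_SUFFIXES:
        if lower.endswith(suffix) and len(filename) > len(suffix):
            return filename[:-len(suffix)]
    raise ValueError('`extract_dir` cannot be inferred: {!r}'.format(filename))
```

For example, a URI ending in "/mnist.tar.gz?raw=1" would yield the filename "mnist.tar.gz" and the directory name "mnist" under these assumptions.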
-
-
class tfsnippet.utils.AutoInitAndCloseable
Bases: object

Classes with init() to initialize their internal states, and close() to destroy these states. The init() method can be called repeatedly, but only the first call actually performs the initialization. Thus other methods may always call init() at the beginning, which brings auto-initialization to the class.
A context manager is implemented: init() is explicitly called when entering the context, while close() is called when exiting the context.
-
__enter__()
Ensure the internal states are initialized.
-
__exit__(exc_type, exc_val, exc_tb)
Clean up the internal states.
-
_close()
Override this method to destroy the internal states.
-
_init()
Override this method to initialize the internal states.
-
close()
Ensure the internal states are destroyed.
-
init()
Ensure the internal states are initialized.
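The contract above can be sketched in plain Python as follows. AutoInitAndCloseableSketch, its _initialized flag and the Resource subclass are hypothetical names chosen for illustration, not tfsnippet's implementation; the sketch only shows that repeated init() calls initialize once and that leaving the context calls close().

```python
class AutoInitAndCloseableSketch(object):
    """Hypothetical re-creation of the documented init()/close() contract."""

    _initialized = False

    def _init(self):
        raise NotImplementedError()

    def _close(self):
        raise NotImplementedError()

    def init(self):
        # only the first call actually initializes the internal states
        if not self._initialized:
            self._init()
            self._initialized = True

    def close(self):
        if self._initialized:
            try:
                self._close()
            finally:
                self._initialized = False

    def __enter__(self):
        self.init()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class Resource(AutoInitAndCloseableSketch):
    def __init__(self):
        self.events = []

    def _init(self):
        self.events.append('init')

    def _close(self):
        self.events.append('close')

    def read(self):
        self.init()  # auto-initialization on first use
        return 'data'


r = Resource()
with r:
    r.read()
    r.read()
```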
-
-
class tfsnippet.utils.Disposable
Bases: object

Classes which can only be used once.
-
_check_usage_and_set_used()
Check the usage flag to ensure the object has not been used, and then mark it as used.
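A minimal sketch of how a subclass might guard its one-shot method with _check_usage_and_set_used(). The flag name _already_used and the choice of RuntimeError for a second use are assumptions; the documentation above does not state which exception is raised.

```python
class DisposableSketch(object):
    """Hypothetical re-creation of the one-shot usage check."""

    _already_used = False

    def _check_usage_and_set_used(self):
        if self._already_used:
            raise RuntimeError('Disposable object cannot be used twice.')
        self._already_used = True


class OneShotJob(DisposableSketch):
    def run(self):
        self._check_usage_and_set_used()  # guard placed at the entry point
        return 'done'


job = OneShotJob()
first = job.run()
try:
    job.run()
    second_failed = False
except RuntimeError:
    second_failed = True
```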
-
-
class tfsnippet.utils.NoReentrantContext
Bases: object

Base class for contexts which are not reentrant, i.e., once the context has been opened by __enter__, __enter__ cannot be called again until __exit__ has been called.
-
_enter()
Enter the context. Subclasses should override this instead of the true __enter__ method.
-
_exit(exc_type, exc_val, exc_tb)
Exit the context. Subclasses should override this instead of the true __exit__ method.
-
_require_entered()
Require the context to be entered.
Raises: RuntimeError – If the context is not entered.
-
-
class tfsnippet.utils.DisposableContext
Bases: tfsnippet.utils.concepts.NoReentrantContext

Base class for contexts which can only be entered once.
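The difference between the two context base classes can be illustrated with plain-Python stand-ins. NoReentrantContextSketch, DisposableContextSketch and their flag names are hypothetical, and raising RuntimeError on a forbidden entry is an assumption (the documentation only states it for _require_entered()). The non-reentrant class rejects nested entry but may be entered again later; the disposable one also rejects any entry after the first.

```python
class NoReentrantContextSketch(object):
    """Hypothetical stand-in: __enter__ is rejected while already entered."""

    _is_entered = False

    def _enter(self):
        raise NotImplementedError()

    def _exit(self, exc_type, exc_val, exc_tb):
        raise NotImplementedError()

    def _require_entered(self):
        if not self._is_entered:
            raise RuntimeError('Context is required to be entered.')

    def __enter__(self):
        if self._is_entered:
            raise RuntimeError('Context is not reentrant.')
        ret = self._enter()
        self._is_entered = True
        return ret

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._is_entered:
            self._is_entered = False
            return self._exit(exc_type, exc_val, exc_tb)


class DisposableContextSketch(NoReentrantContextSketch):
    """Hypothetical stand-in: additionally, the context can be entered once."""

    _has_entered = False

    def __enter__(self):
        if self._has_entered:
            raise RuntimeError('A disposable context cannot be entered twice.')
        ret = super(DisposableContextSketch, self).__enter__()
        self._has_entered = True
        return ret


class Reusable(NoReentrantContextSketch):
    def _enter(self):
        return self

    def _exit(self, exc_type, exc_val, exc_tb):
        return None  # never suppress exceptions


class OneShot(DisposableContextSketch):
    def _enter(self):
        return self

    def _exit(self, exc_type, exc_val, exc_tb):
        return None


def try_enter(ctx):
    """Return True if `with ctx:` succeeds, False on RuntimeError."""
    try:
        with ctx:
            pass
        return True
    except RuntimeError:
        return False


reusable, one_shot = Reusable(), OneShot()
with reusable:
    reusable._require_entered()             # passes while inside the context
    nested_reusable = try_enter(reusable)   # False: not reentrant
again_reusable = try_enter(reusable)        # True: may be entered again later

with one_shot:
    pass
again_one_shot = try_enter(one_shot)        # False: disposable
```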
-
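tfsnippet.utils.minibatch_slices_iterator(length, batch_size, skip_incomplete=False)
Iterate through all the mini-batch slices.
Parameters:
- length (int) – Total length of the data.
- batch_size (int) – Size of each mini-batch.
- skip_incomplete (bool) – If True, discard the last mini-batch if it contains fewer than batch_size indices. (default False)
Yields: slice – Slices of each mini-batch. The last mini-batch may contain fewer indices than batch_size, unless skip_incomplete is True.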
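The slicing rule can be sketched as follows; minibatch_slices() is a hypothetical re-implementation written from the description above, not tfsnippet's source.

```python
def minibatch_slices(length, batch_size, skip_incomplete=False):
    """Yield slice objects covering range(length) in steps of batch_size.

    Hypothetical re-implementation written from the documentation above.
    """
    start = 0
    full_stop = (length // batch_size) * batch_size  # end of the last full batch
    while start < full_stop:
        yield slice(start, start + batch_size)
        start += batch_size
    if not skip_incomplete and start < length:
        yield slice(start, length)  # the final, smaller mini-batch
```

Because each item is a slice, the same slice can index every array of that length, e.g. x[s] and y[s] for s in minibatch_slices(len(x), 32).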
-
tfsnippet.utils.split_numpy_arrays(arrays, portion=None, size=None, shuffle=True, random_state=None)
Split numpy arrays into two halves, by portion or by size.
Parameters:
- arrays (Iterable[np.ndarray]) – Numpy arrays to be split.
- portion (float) – Portion of the second half. Ignored if size is specified.
- size (int) – Size of the second half.
- shuffle (bool) – Whether or not to shuffle before splitting?
- random_state (RandomState) – Optional numpy RandomState for shuffling data. (default None, use the global RandomState)
Returns: The two split halves of arrays.
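The splitting logic can be sketched without NumPy by permuting the indices once and applying the same permutation to every sequence, which keeps corresponding items aligned across arrays. split_sequences() is hypothetical, uses plain lists in place of np.ndarray and random.Random in place of np.random.RandomState, and assumes the second-half size is int(length * portion) when only portion is given.

```python
import random


def split_sequences(arrays, portion=None, size=None, shuffle=True, rng=None):
    """Split equally long sequences into (first_halves, second_halves).

    Hypothetical stand-in for split_numpy_arrays() working on plain lists.
    """
    arrays = tuple(arrays)
    if not arrays:
        return (), ()
    length = len(arrays[0])
    if any(len(a) != length for a in arrays):
        raise ValueError('all arrays must have the same length')
    if size is None:
        if portion is None:
            raise ValueError('one of `portion` and `size` must be specified')
        size = int(length * portion)  # rounding rule is an assumption
    if not 0 <= size <= length:
        raise ValueError('`size` must be within [0, {}]'.format(length))

    indices = list(range(length))
    if shuffle:
        (rng or random).shuffle(indices)  # one permutation shared by all arrays
    first_idx, second_idx = indices[:length - size], indices[length - size:]
    first = tuple([a[i] for i in first_idx] for a in arrays)
    second = tuple([a[i] for i in second_idx] for a in arrays)
    return first, second


x = list(range(10))
y = [v * 10 for v in x]
(x1, y1), (x2, y2) = split_sequences([x, y], portion=0.3, rng=random.Random(0))
```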
-
tfsnippet.utils.split_numpy_array(array, portion=None, size=None, shuffle=True)
Split numpy array into two halves, by portion or by size.
Parameters:
- array (np.ndarray) – Numpy array to be split.
- portion, size, shuffle – Same as in split_numpy_arrays().
Returns: The two split halves of the array.
Return type: tuple[np.ndarray]
-
tfsnippet.utils.DocInherit(kclass)
Class decorator to enable kclass and all its sub-classes to automatically inherit docstrings from base classes.
Usage:

    from tfsnippet.utils import DocInherit

    @DocInherit
    class Parent(object):
        """Docstring of the parent class."""

        def some_method(self):
            """Docstring of the method."""
            ...

    class Child(Parent):
        # inherits the docstring of :class:`Parent`

        def some_method(self):
            # inherits the docstring of :meth:`Parent.some_method`
            ...

Parameters: kclass (Type) – The class to decorate.
Returns: The decorated class.
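One possible way to get this behavior in modern Python is sketched below. doc_inherit_sketch() is hypothetical and uses __init_subclass__ (Python 3.6+) so that sub-classes are processed as well; tfsnippet's DocInherit may be implemented differently. Only missing docstrings are filled, from the nearest base class (other than object) that has one.

```python
import inspect


def doc_inherit_sketch(kclass):
    """Hypothetical class decorator: fill missing docstrings from base classes."""

    def fill_docs(cls):
        bases = [b for b in cls.__mro__[1:] if b is not object]
        if cls.__doc__ is None:  # class docstring from the nearest base
            for base in bases:
                if base.__doc__:
                    cls.__doc__ = base.__doc__
                    break
        for name, attr in vars(cls).items():
            if inspect.isfunction(attr) and attr.__doc__ is None:
                for base in bases:
                    base_attr = getattr(base, name, None)
                    if base_attr is not None and base_attr.__doc__:
                        attr.__doc__ = base_attr.__doc__
                        break

    def __init_subclass__(cls, **kwargs):
        super(kclass, cls).__init_subclass__(**kwargs)
        fill_docs(cls)  # runs for every sub-class, at any depth

    kclass.__init_subclass__ = classmethod(__init_subclass__)
    fill_docs(kclass)
    return kclass


@doc_inherit_sketch
class Parent(object):
    """Docstring of the parent class."""

    def some_method(self):
        """Docstring of the method."""


class Child(Parent):
    def some_method(self):  # no docstring here
        return 42


class GrandChild(Child):
    def some_method(self):
        return 0
```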
-
class tfsnippet.utils.TemporaryDirectory(suffix=None, prefix=None, dir=None)
Bases: object

Create and return a temporary directory. This has the same behavior as mkdtemp but can be used as a context manager. For example:

    with TemporaryDirectory() as tmpdir:
        ...

Upon exiting the context, the directory and everything contained in it are removed.
-
cleanup()
-
tfsnippet.utils.makedirs(name, mode=511, exist_ok=False)
-
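tfsnippet.utils.humanize_duration(seconds, short_units=True)
Format the specified time duration as human-readable text.
Parameters:
- seconds – The time duration in seconds.
- short_units (bool) – Whether or not to use abbreviated unit names? (default True)
Returns: The formatted time duration.
Return type: str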
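The exact output format is not documented here, so the following is only a sketch under an assumed format: days, hours and minutes are shown as whole numbers, any remainder is shown in seconds, and short_units switches between abbreviations such as "1h" and full words such as "1 hour". humanize_duration_sketch() is hypothetical and only accepts non-negative durations.

```python
def humanize_duration_sketch(seconds, short_units=True):
    """Format a non-negative duration; the output format is an assumption."""
    if seconds < 0:
        raise ValueError('`seconds` must be non-negative')
    units = [(86400, 'd', 'day'), (3600, 'h', 'hour'), (60, 'm', 'minute')]
    parts = []
    remaining = seconds
    for unit_size, short_name, long_name in units:
        count = int(remaining // unit_size)
        if count:
            remaining -= count * unit_size
            if short_units:
                parts.append('{}{}'.format(count, short_name))
            else:
                parts.append('{} {}{}'.format(
                    count, long_name, '' if count == 1 else 's'))
    # whatever is left (possibly fractional) is reported in seconds
    if remaining or not parts:
        value = round(remaining, 3)
        if short_units:
            parts.append('{:g}s'.format(value))
        else:
            parts.append('{:g} second{}'.format(
                value, '' if value == 1 else 's'))
    return ' '.join(parts)
```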
-
tfsnippet.utils.camel_to_underscore(name)
Convert a camel-case name to its underscore-separated form.
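A common regex-based approach to this conversion is sketched below. camel_to_underscore_sketch() is hypothetical, and splitting acronym runs (e.g. "HTTPServer" into "http_server") is an assumed behavior that tfsnippet's own function may not share.

```python
import re


def camel_to_underscore_sketch(name):
    """Convert CamelCase to lower_underscore; acronym handling is assumed."""
    # insert "_" between a lower-case letter or digit and an upper-case letter
    s = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', name)
    # split acronym runs such as "HTTPServer" -> "HTTP_Server"
    s = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', s)
    return s.lower()
```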
-
tfsnippet.utils.get_valid_scope_name(name, cls_or_instance=None)
Generate a valid scope name for the given method.
Parameters:
- name (str) – The base name.
- cls_or_instance – The class or the instance object, optional.
Returns: The generated scope name.
Return type: str
-
tfsnippet.utils.maybe_close(*args, **kwds)
Enter a context, and if obj has a .close() method, close it when exiting the context.
Parameters: obj – The object which may be closed.
Yields: The specified obj.
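maybe_close() behaves like a context manager that closes its argument only when the argument can be closed. The sketch below builds that behavior with contextlib.contextmanager; maybe_close_sketch() is hypothetical, and closing the object even when the body raises (via finally) is an assumption.

```python
import io
from contextlib import contextmanager


@contextmanager
def maybe_close_sketch(obj):
    """Yield obj, then call obj.close() on exit if such a method exists."""
    try:
        yield obj
    finally:
        close = getattr(obj, 'close', None)
        if callable(close):
            close()


buf = io.StringIO('text')
with maybe_close_sketch(buf) as f:
    content = f.read()       # buf is closed after this block

with maybe_close_sketch(42) as n:
    doubled = n * 2          # ints have no close(), nothing happens on exit
```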
-
tfsnippet.utils.iter_files(root_dir, sep='/')
Iterate through all files in root_dir, yielding the relative path of each file. Sub-directories themselves are not yielded.
Parameters:
- root_dir (str) – The root directory, from which to iterate.
- sep (str) – The separator for the relative paths.
Yields: str – The relative paths of each file.
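A recursive sketch of the documented behavior, built on os.listdir(). iter_files_sketch() is hypothetical, and the alphabetical ordering it produces is an assumption made so that the output is deterministic.

```python
import os
import tempfile


def iter_files_sketch(root_dir, sep='/'):
    """Yield the relative path of every file below root_dir, joined by sep.

    Hypothetical helper; directories themselves are never yielded.
    """
    def walk(path, prefix):
        for name in sorted(os.listdir(path)):  # sorted: deterministic order
            full_path = os.path.join(path, name)
            rel_path = prefix + sep + name if prefix else name
            if os.path.isdir(full_path):
                for sub_path in walk(full_path, rel_path):
                    yield sub_path
            else:
                yield rel_path
    return walk(root_dir, '')


with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, 'a', 'b'))
    for parts in [('x.txt',), ('a', 'y.txt'), ('a', 'b', 'z.txt')]:
        with open(os.path.join(root, *parts), 'w') as f:
            f.write('')
    found = list(iter_files_sketch(root))
```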
-
class tfsnippet.utils.ETA(take_initial_snapshot=True)
Bases: object

Class to help compute the Estimated Time Ahead (ETA).
-
__init__(take_initial_snapshot=True)
Construct a new ETA.
Parameters: take_initial_snapshot (bool) – Whether or not to take the initial snapshot (0., time.time())? (default True)
-
get_eta(progress, now=None, take_snapshot=True)
Get the Estimated Time Ahead (ETA).
Parameters:
- progress – The current progress, in the range [0, 1].
- now – The current timestamp in seconds. If not specified, use time.time().
- take_snapshot (bool) – Whether or not to take a snapshot of (progress, now)? (default True)
Returns: The remaining seconds, or None if the ETA cannot be estimated.
Return type: float or None
-
take_snapshot(progress, now=None)
Take a snapshot of (progress, now), for later computing ETA.
Parameters:
- progress – The current progress, in the range [0, 1].
- now – The current timestamp in seconds. If not specified, use time.time().
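One simple way to estimate the remaining time is linear extrapolation: if progress advanced by dp over dt seconds since the earliest snapshot, the remaining 1 - progress is expected to take dt / dp * (1 - progress) seconds. ETASketch below implements that rule; both the class and the choice of the earliest snapshot as the reference point are assumptions, not tfsnippet's documented algorithm.

```python
import time


class ETASketch(object):
    """Linear-extrapolation ETA (hypothetical; not tfsnippet's algorithm)."""

    def __init__(self, take_initial_snapshot=True):
        self._snapshots = []  # list of (progress, timestamp)
        if take_initial_snapshot:
            self.take_snapshot(0.)

    def take_snapshot(self, progress, now=None):
        if now is None:
            now = time.time()
        self._snapshots.append((progress, now))

    def get_eta(self, progress, now=None, take_snapshot=True):
        if now is None:
            now = time.time()
        eta = None
        if self._snapshots:
            first_progress, first_time = self._snapshots[0]
            delta_progress = progress - first_progress
            if delta_progress > 0:
                # seconds per unit of progress, times the remaining progress
                eta = (now - first_time) / delta_progress * (1. - progress)
        if take_snapshot:
            self.take_snapshot(progress, now)
        return eta


eta = ETASketch(take_initial_snapshot=False)
eta.take_snapshot(0., now=100.)
remaining = eta.get_eta(0.25, now=110.)  # 10 s for 25%, so 30 s for the rest
```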
-
-
tfsnippet.utils.auto_reuse_variables(*args, **kwds)
Open a variable scope as a context, automatically choosing the reuse flag.
The reuse flag will be set to False if the variable scope is opened for the first time, and it will be set to True each time the variable scope is opened again.
Parameters:
- name_or_scope (str or tf.VariableScope) – The name of the variable scope, or the variable scope to open.
- reopen_name_scope (bool) – Whether or not to re-open the original name scope of name_or_scope? This option takes effect only if name_or_scope is actually an instance of tf.VariableScope.
- **kwargs – Named arguments for opening the variable scope.
Yields: tf.VariableScope – The opened variable scope.
-
tfsnippet.utils.instance_reuse(method=None, scope=None)
Decorate an instance method within an auto_reuse_variables() context.
This decorator should be applied to unbound instance methods, and the instance that owns the methods should have a variable_scope attribute. For example:

    import tensorflow as tf
    from tfsnippet.utils import instance_reuse

    class Foo(object):
        def __init__(self, name):
            with tf.variable_scope(name) as vs:
                self.variable_scope = vs

        @instance_reuse
        def foo(self):
            return tf.get_variable('bar', ...)

The above example is then equivalent to the following code:

    class Foo(object):
        def __init__(self, name):
            with tf.variable_scope(name) as vs:
                self.variable_scope = vs

        def foo(self):
            with reopen_variable_scope(self.variable_scope):
                with auto_reuse_variables('foo'):
                    return tf.get_variable('bar', ...)

By default, the name of the variable scope is the name of the decorated method, and the name scope within the context is the variable scope name plus a suffix that makes it unique. The variable scope name can be set by the scope argument, for example:

    class Foo(object):
        @instance_reuse(scope='scope_name')
        def foo(self):
            return tf.get_variable('bar', ...)

Note that the variable reusing is based on the name of the variable scope, rather than the method. As a result, two methods with the same scope argument will reuse the same set of variables. For example:

    class Foo(object):
        @instance_reuse(scope='foo')
        def foo_1(self):
            return tf.get_variable('bar', ...)

        @instance_reuse(scope='foo')
        def foo_2(self):
            return tf.get_variable('bar', ...)

These two methods will return the same bar variable.
Parameters: scope (str) – The name of the variable scope. If not set, the method name will be used as the scope name. This argument must be specified as a named argument.
-
tfsnippet.utils.global_reuse(method=None, scope=None)
Decorate a function within an auto_reuse_variables() scope globally.
Any function or method decorated with this decorator will be called within a variable scope opened first by root_variable_scope(), then by auto_reuse_variables(). That is to say, the following code:

    import tensorflow as tf
    from tfsnippet.utils import global_reuse

    @global_reuse
    def foo():
        return tf.get_variable('bar', ...)

    bar = foo()

is equivalent to:

    with root_variable_scope():
        with auto_reuse_variables('foo'):
            bar = tf.get_variable('bar', ...)

By default, the name of the variable scope is the name of the decorated function, and the name scope within the context is the variable scope name plus a suffix that makes it unique. The variable scope name can be set by the scope argument, for example:

    @global_reuse(scope='dense')
    def dense_layer(inputs):
        w = tf.get_variable('w', ...)
        b = tf.get_variable('b', ...)
        return tf.matmul(inputs, w) + b

Note that the variable reusing is based on the name of the variable scope, rather than the function object. As a result, two functions with the same name, or with the same scope argument, will reuse the same set of variables. For example:

    @global_reuse(scope='foo')
    def foo_1():
        return tf.get_variable('bar', ...)

    @global_reuse(scope='foo')
    def foo_2():
        return tf.get_variable('bar', ...)

These two functions will return the same bar variable.
Parameters: scope (str) – The name of the variable scope. If not set, the function name will be used as the scope name. This argument must be specified as a named argument.
-
tfsnippet.utils.reopen_variable_scope(*args, **kwds)
Reopen the specified var_scope and its original name scope.
Unlike tf.variable_scope(), which does not open the original name scope even if a stored tf.VariableScope instance is specified, this method opens exactly the same name scope as the original one.
Parameters:
- var_scope (tf.VariableScope) – The variable scope instance.
- **kwargs – Named arguments for opening the variable scope.
-
tfsnippet.utils.root_variable_scope(*args, **kwds)
Open the root variable scope and its name scope.
Parameters: **kwargs – Named arguments for opening the root variable scope.
-
class tfsnippet.utils.VarScopeObject(name=None, scope=None)
Bases: object

Base class for objects that own a variable scope. It is typically used along with instance_reuse().
-
__init__(name=None, scope=None)
Construct the VarScopeObject.
Parameters:
- name (str) – Name of this object. A unique variable scope name will be picked according to this argument if scope is not specified. If neither this argument nor scope is specified, the underscored class name will be used in place of name when picking the scope name. This argument will be stored and can be accessed via the name attribute of the instance. If not specified, name will be None.
- scope (str) – Scope of this object. If specified, it will be used as the variable scope name, even if another object has already taken the same scope. That is to say, these two objects will share the same variable scope.
-
__repr__() <==> repr(x)
-
get_variables_as_dict(sub_scope=None, collection='variables', strip_sub_scope=True)
Get the variables created inside this VarScopeObject.
This method will collect variables from the specified collection, which are created in the variable_scope of this object (or in the sub_scope of variable_scope, if sub_scope is not None).
Parameters:
- sub_scope (str) – The sub-scope of variable_scope.
- collection (str) – The collection from which to collect variables. (default tf.GraphKeys.GLOBAL_VARIABLES)
- strip_sub_scope (bool) – Whether or not to also strip the common prefix of sub_scope? (default True)
Returns: Dict which maps from the relative names of variables to variable objects. By relative names we mean the full names of variables, without the common prefix of variable_scope (and sub_scope if strip_sub_scope is True).
Return type: dict[str, tf.Variable]
-
name
Get the name of this object.
-
variable_scope
Get the variable scope of this object.
-
-
tfsnippet.utils.create_session(lock_memory=True, log_device_placement=False, allow_soft_placement=True, **kwargs)
A convenient method to create a TensorFlow session.
Parameters:
- lock_memory (True or False or float) –
  - If True, lock all free memory.
  - If False, set allow_growth to True, i.e., do not lock all free memory.
  - If float, lock this portion of memory.
  (default True)
- log_device_placement (bool) – Whether or not to log the placement of graph nodes? (default False)
- allow_soft_placement (bool) – Whether or not to allow soft placement? (default True)
- **kwargs – Other named parameters to be passed to tf.ConfigProto.
Returns: The TensorFlow session.
Return type: tf.Session
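How the three lock_memory cases might translate into GPU options can be sketched with a plain function. gpu_options_for() is hypothetical and returns a dict of keyword arguments for tf.GPUOptions (allow_growth and per_process_gpu_memory_fraction are fields of that TensorFlow 1.x message), so it runs without TensorFlow; whether create_session() builds its config exactly this way is an assumption.

```python
def gpu_options_for(lock_memory):
    """Map a lock_memory value onto tf.GPUOptions keyword arguments.

    Hypothetical helper; the result could be used as
    tf.ConfigProto(gpu_options=tf.GPUOptions(**gpu_options_for(0.5)), ...)
    """
    if lock_memory is True:
        # TensorFlow 1.x maps (almost) all free GPU memory by default
        return {}
    if lock_memory is False:
        return {'allow_growth': True}
    if isinstance(lock_memory, float) and 0. < lock_memory <= 1.:
        return {'per_process_gpu_memory_fraction': lock_memory}
    raise ValueError('`lock_memory` must be True, False or a float in (0, 1]')
```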
-
tfsnippet.utils.get_default_session_or_error()
Get the default session.
Returns: The default session.
Return type: tf.Session
Raises: RuntimeError – If there is no active session.
-
tfsnippet.utils.get_variables_as_dict(scope=None, collection='variables')
Get TensorFlow variables as a dict.
Parameters:
- scope (str or tf.VariableScope or None) – If None, will collect all the variables within the current graph. If a str or a tf.VariableScope, will collect the variables only from this scope. (default None)
- collection (str) – Collect the variables only from this collection. (default tf.GraphKeys.GLOBAL_VARIABLES)
Returns: Dict which maps from names to TensorFlow variables. The names will be the full names of variables if scope is not specified, or the relative names within the scope otherwise. By relative names we mean the variable names without the common scope name prefix.
Return type: dict[str, tf.Variable]
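The naming rule in the return value can be sketched on plain name strings, with each full name standing in for the variable object. relative_variable_names() is hypothetical; in particular, dropping the ":0" output suffix that TensorFlow appends to variable names, and matching the scope by a "scope/" prefix, are assumptions about how the keys are formed.

```python
def relative_variable_names(full_names, scope=None):
    """Map variable names to keys as described for get_variables_as_dict().

    Hypothetical helper working on name strings only.
    """
    result = {}
    prefix = scope.rstrip('/') + '/' if scope else ''
    for full in full_names:
        name = full.rsplit(':', 1)[0]  # "a/b/w:0" -> "a/b/w"
        if prefix:
            if not name.startswith(prefix):
                continue  # outside the requested scope
            key = name[len(prefix):]
        else:
            key = name
        result[key] = full
    return result


names = ['model/dense/kernel:0', 'model/dense/bias:0', 'model_2/w:0']
in_model = relative_variable_names(names, 'model')
```

Matching on "model/" rather than "model" keeps the sibling scope "model_2" out of the result.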
-
class tfsnippet.utils.VariableSaver(variables, save_dir, max_versions=2, filename='variables.dat', latest_file='latest', save_meta=True, name=None, scope=None)
Bases: tfsnippet.utils.scope.VarScopeObject

Version-controlled saving and restoring of TensorFlow variables.
-
__init__(variables, save_dir, max_versions=2, filename='variables.dat', latest_file='latest', save_meta=True, name=None, scope=None)
Construct the VariableSaver.
Parameters:
- variables (collections.Iterable[tf.Variable] or dict[str, any]) – List of variables, or dict of variables with explicit keys, which should be saved and restored.
- save_dir (str) – Directory where to place the saved variables.
- max_versions (int) – Maximum number of versions to keep in the directory. At least 2 versions should be kept, in order to prevent corrupted checkpoint files caused by IO failure. (default 2)
- filename (str) – Name of the files of variable values. (default variables.dat)
- latest_file (str) – Name of the file which organizes the checkpoint versions. (default latest)
- save_meta (bool) – Whether or not to save the meta graph? (default True)
- name (str) – Optional name of this VariableSaver (argument of VarScopeObject).
- scope (str) – Optional scope of this VariableSaver (argument of VarScopeObject).
-
get_latest_file()
Get the latest available checkpoint file.
-
-
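tfsnippet.utils.get_uninitialized_variables(variables=None, name=None)
Get uninitialized variables as a list.
Parameters:
- variables (list[tf.Variable]) – Collect only uninitialized variables within this list. If not specified, will collect uninitialized variables within the collection tf.GraphKeys.GLOBAL_VARIABLES.
- name (str) – Name of this operation in the TensorFlow graph.
Returns: Uninitialized variables.
Return type: list[tf.Variable]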
-
tfsnippet.utils.ensure_variables_initialized(variables=None, name=None)
Ensure variables are initialized.
Parameters:
- variables (list[tf.Variable] or dict[str, tf.Variable]) – Ensure only the variables within this collection are initialized. If not specified, will ensure all variables within the collection tf.GraphKeys.GLOBAL_VARIABLES are initialized.
- name (str) – Name of this operation in the TensorFlow graph. (default ensure_variables_initialized)
-
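tfsnippet.utils.int_shape(tensor)
Get the int shape tuple of the specified tensor.
Parameters: tensor – The tensor object.
Returns: The static shape as a tuple, with None for each unknown dimension, or None if the rank of tensor is unknown.
Return type: tuple[int or None] or None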
-
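tfsnippet.utils.flatten(x, k, name=None)
Flatten the front dimensions of x, such that the resulting tensor will have at most k dimensions.
Parameters:
- x (tf.Tensor) – The tensor to be flattened.
- k (int) – The maximum number of dimensions of the resulting tensor.
- name (str) – Optional name of this operation.
Returns: (The flattened tensor, the static front shape, and the front shape), or (the original tensor, None, None) if x already has at most k dimensions.
Return type: (tf.Tensor, tuple[int or None], tuple[int] or tf.Tensor) or (tf.Tensor, None, None)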
-
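tfsnippet.utils.unflatten(x, static_front_shape, front_shape, name=None)
The inverse transformation of flatten().
If both static_front_shape and front_shape are None, x will be returned without any change.
Parameters:
- x (tf.Tensor) – The tensor to be unflattened.
- static_front_shape (tuple[int or None] or None) – The static front shape, as returned by flatten().
- front_shape (tuple[int] or tf.Tensor or None) – The front shape, as returned by flatten().
- name (str) – Optional name of this operation.
Returns: The unflattened x.
Return type: tf.Tensor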
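The shape bookkeeping behind flatten() and unflatten() can be illustrated on plain tuples: all but the last k - 1 dimensions are merged into one, and the merged dimension is later split back into the saved front shape while the trailing dimensions are taken from the (possibly transformed) tensor. flattened_shape() and unflattened_shape() are hypothetical shape-only helpers, not the tensor functions themselves.

```python
def flattened_shape(shape, k):
    """Return (new_shape, front_shape) for flattening `shape` to at most k dims.

    Returns (shape, None) when no flattening is needed.
    """
    if k < 1:
        raise ValueError('`k` must be at least 1')
    shape = tuple(shape)
    if len(shape) <= k:
        return shape, None
    # keep the last k - 1 dims and merge everything in front into one dim
    split = len(shape) - (k - 1)
    front, keep = shape[:split], shape[split:]
    merged = 1
    for dim in front:
        merged *= dim
    return (merged,) + keep, front


def unflattened_shape(shape, front_shape):
    """Inverse of flattened_shape(): split the first dim back into front_shape."""
    if front_shape is None:
        return tuple(shape)
    return tuple(front_shape) + tuple(shape[1:])
```

For instance, a (2, 3, 4, 5) tensor flattened with k=2 becomes (24, 5); if a layer then maps the last dimension to 7, unflattening yields (2, 3, 4, 7).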
-
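tfsnippet.utils.get_batch_size(tensor, axis=0, name=None)
Infer the mini-batch size according to tensor.
Parameters:
- tensor (tf.Tensor) – The tensor from which to infer the batch size.
- axis (int) – The axis of the batch dimension. (default 0)
- name (str) – Optional name of this operation.
Returns: The batch size, as an int if it is statically known, or as a tf.Tensor otherwise.
Return type: int or tf.Tensor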
-
class tfsnippet.utils.StatisticsCollector(shape=())
Bases: object

Computing \(\mathrm{E}[X]\) and \(\operatorname{Var}[X]\) online.
-
__init__(shape=())
Construct the StatisticsCollector.
Parameters: shape – Shape of the values. The statistics will be collected for each element of the values. (default ())
-
collect(values, weight=1.0)
Update the statistics from values.
This method uses the following equation to update mean and square:
\[\frac{\sum_{i=1}^n w_i f(x_i)}{\sum_{j=1}^n w_j} = \frac{\sum_{i=1}^m w_i f(x_i)}{\sum_{j=1}^m w_j} + \frac{\sum_{i=m+1}^n w_i}{\sum_{j=1}^n w_j} \Bigg( \frac{\sum_{i=m+1}^n w_i f(x_i)}{\sum_{j=m+1}^n w_j} - \frac{\sum_{i=1}^m w_i f(x_i)}{\sum_{j=1}^m w_j} \Bigg)\]
where \(x_1, \dots, x_m\) are the previously collected values, \(x_{m+1}, \dots, x_n\) are the new values, and \(f(x) = x\) for the mean or \(f(x) = x^2\) for the square.
Parameters:
- values – Values to be collected in batch, numpy array or scalar whose shape ends with self.shape. The leading shape in front of self.shape is regarded as the batch shape.
- weight – Weights of the values, should be broadcastable against the batch shape. (default 1)
Raises: ValueError – If the shape of values does not end with self.shape.
-
counter
Get the counter of collected values.
-
has_value
Whether or not any value has been collected?
-
mean
Get the mean of the values, i.e., \(\mathrm{E}[X]\).
-
reset()
Reset the collector to initial state.
-
shape
Get the shape of the values.
-
square
Get \(\mathrm{E}[X^2]\) of the values.
-
stddev
Get the std of the values, i.e., \(\sqrt{\operatorname{Var}[X]}\).
-
var
Get the variance of the values, i.e., \(\operatorname{Var}[X]\).
-
weight_sum
Get the weight summation.
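The update equation can be applied to both the running mean (\(f(x) = x\)) and the running square (\(f(x) = x^2\)), after which \(\operatorname{Var}[X] = \mathrm{E}[X^2] - \mathrm{E}[X]^2\). ScalarStatsSketch is a hypothetical, dependency-free version for scalar values only (no shape argument and no NumPy broadcasting); it is not tfsnippet's implementation.

```python
import math


class ScalarStatsSketch(object):
    """Online weighted mean and variance of scalars (hypothetical sketch)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.mean = 0.        # running E[X]
        self.square = 0.      # running E[X^2]
        self.weight_sum = 0.
        self.counter = 0

    def collect(self, values, weight=1.):
        values = [float(v) for v in values]
        if isinstance(weight, (int, float)):
            weights = [float(weight)] * len(values)
        else:
            weights = [float(w) for w in weight]
        batch_weight = sum(weights)
        if batch_weight <= 0.:
            return  # nothing to update
        batch_mean = sum(w * v for w, v in zip(weights, values)) / batch_weight
        batch_square = sum(w * v * v for w, v in zip(weights, values)) / batch_weight
        self.weight_sum += batch_weight
        # ratio = sum_{i=m+1}^n w_i / sum_{j=1}^n w_j from the update equation
        ratio = batch_weight / self.weight_sum
        self.mean += ratio * (batch_mean - self.mean)
        self.square += ratio * (batch_square - self.square)
        self.counter += len(values)

    @property
    def var(self):
        # clip tiny negative values caused by floating-point rounding
        return max(self.square - self.mean ** 2, 0.)

    @property
    def stddev(self):
        return math.sqrt(self.var)


stats = ScalarStatsSketch()
stats.collect([1., 2.])
stats.collect([3., 4.])
```

Collecting in two batches gives the same mean and variance as collecting all four values at once, which is the point of the incremental form.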
-
-
class tfsnippet.utils.TensorWrapper
Bases: object

Tensor-like object that wraps a tf.Tensor instance.
This class is typically used to implement super-tensor classes, adding auxiliary methods to a tf.Tensor. The derived classes should call register_tensor_wrapper_class() to register themselves into the TensorFlow type system.
Access to any undefined attributes, properties and methods will be transparently proxied to the wrapped tensor. Also, TensorWrapper can be directly used in mathematical expressions and most TensorFlow arithmetic functions. For example, TensorWrapper(...) + tf.exp(TensorWrapper(...)).
On the other hand, TensorWrapper objects are neither tf.Tensor nor sub-classes of tf.Tensor, i.e., isinstance(TensorWrapper(...), tf.Tensor) == False. This is essential for sub-classes of TensorWrapper to be converted correctly to tf.Tensor by tf.convert_to_tensor(), using the official type conversion system of TensorFlow.
All the attributes defined in sub-classes of TensorWrapper must have names starting with _self_. The properties and methods are not restricted by this rule.
An example of inheriting TensorWrapper is shown as follows:

    import tensorflow as tf
    from tfsnippet.utils import TensorWrapper, register_tensor_wrapper_class

    class MyTensorWrapper(TensorWrapper):

        def __init__(self, wrapped, flag):
            super(MyTensorWrapper, self).__init__()
            self._self_wrapped = wrapped
            self._self_flag = flag

        @property
        def tensor(self):
            return self._self_wrapped

        @property
        def flag(self):
            return self._self_flag

    register_tensor_wrapper_class(MyTensorWrapper)

    # tests
    t = MyTensorWrapper(tf.constant(0., dtype=tf.float32), flag=123)
    assert t.dtype == tf.float32
    assert t.flag == 123

-
tensor
Get the wrapped tf.Tensor. Derived classes must override this to return the actual wrapped tensor.
Returns: The wrapped tensor.
Return type: tf.Tensor
-
-
tfsnippet.utils.register_tensor_wrapper_class(cls)
Register a sub-class of TensorWrapper into the TensorFlow type system.
Parameters: cls – The subclass of TensorWrapper to be registered.
-
tfsnippet.utils.is_tensorflow_version_higher_or_equal(version)
Check whether the version of TensorFlow is higher than or equal to version.
Parameters: version (str) – Expected version of TensorFlow.
Returns: True if higher than or equal to version, False if not.
Return type: bool
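Version strings should be compared component by component as integers, not as plain strings ("1.10" sorts before "1.9" lexicographically). The sketch below is hypothetical: version_tuple() keeps only the leading numeric components and ignores suffixes such as "rc1", which means a pre-release compares equal to its final release; a real implementation might rely on a dedicated version-parsing library instead.

```python
import re


def version_tuple(version):
    """Parse leading numeric components, e.g. "1.10.0-rc1" -> (1, 10, 0)."""
    parts = []
    for piece in version.split('.'):
        match = re.match(r'\d+', piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # stop at a suffix such as "0rc1"
            break
    return tuple(parts)


def is_version_higher_or_equal(actual, expected):
    """Return True if `actual` >= `expected` (hypothetical helper)."""
    a, e = version_tuple(actual), version_tuple(expected)
    # pad with zeros so that "1.10" and "1.10.0" compare equal
    n = max(len(a), len(e))
    return a + (0,) * (n - len(a)) >= e + (0,) * (n - len(e))
```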
-
tfsnippet.utils.is_integer(x)
Test whether or not x is a Python or NumPy integer.
Parameters: x – The object to be tested.
Returns: A boolean indicating whether x is a Python or NumPy integer.
Return type: bool
-
tfsnippet.utils.is_float(x)
Test whether or not x is a Python or NumPy float.
Parameters: x – The object to be tested.
Returns: A boolean indicating whether x is a Python or NumPy float.
Return type: bool
-
tfsnippet.utils.is_tensor_object(x)
Test whether or not x is a tensor object.
tf.Tensor, tf.Variable, TensorWrapper and zhusuan.StochasticTensor are considered to be tensor objects.
Parameters: x – The object to be tested.
Returns: A boolean indicating whether x is a tensor object.
Return type: bool
-
class tfsnippet.utils.TensorArgValidator(name)
Bases: object

Class to validate argument values of tensors.
-
__init__(name)
Construct the TensorArgValidator.
Parameters: name (str) – Name of the argument to be validated, used in error messages.
-
require_int32(value)
Require value to be a 32-bit integer.
Parameters: value – Value to be validated. If is_tensor_object(value) == True, it will be cast into a tf.Tensor with dtype tf.int32. Otherwise, if is_integer(value) == True, the type will not be cast, but its value will be checked to ensure it falls within [-2**31, 2**31 - 1].
Returns: The validated value.
Raises: TypeError – If the specified value cannot be cast into int32, or the value is out of range.
-
require_non_negative(value)
Require value to be non-negative, i.e., value >= 0.
Parameters: value – Value to be validated. If is_tensor_object(value) == True, an additional assertion will be added to value. Otherwise it will be validated against value >= 0 immediately.
Returns: The validated value.
Raises: ValueError – If the specified value is not non-negative.
-
require_positive(value)
Require value to be positive, i.e., value > 0.
Parameters: value – Value to be validated. If is_tensor_object(value) == True, an additional assertion will be added to value. Otherwise it will be validated against value > 0 immediately.
Returns: The validated value.
Raises: ValueError – If the specified value is not positive.
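For non-tensor inputs, the checks above reduce to plain Python comparisons. The hypothetical helpers below cover only that branch (tensor inputs, which would receive tf.int32 casts and runtime assertions, are out of scope); accepting only Python int and rejecting bool are assumptions, since the real is_integer() also accepts NumPy integers.

```python
INT32_MIN, INT32_MAX = -2 ** 31, 2 ** 31 - 1


def require_int32_value(name, value):
    """Non-tensor branch of require_int32(): an in-range Python int."""
    if isinstance(value, bool) or not isinstance(value, int):
        raise TypeError('{} cannot be converted to int32: {!r}'.format(name, value))
    if not INT32_MIN <= value <= INT32_MAX:
        raise TypeError('{} is out of the int32 range: {!r}'.format(name, value))
    return value


def require_positive_value(name, value):
    """Non-tensor branch of require_positive(): value must be > 0."""
    if not value > 0:
        raise ValueError('{} must be positive: got {!r}'.format(name, value))
    return value
```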
-