TensorWrapper

class tfsnippet.utils.TensorWrapper

Bases: object

Tensor-like object that wraps a tf.Tensor instance.

This class is typically used to implement super-tensor classes, which add auxiliary methods to a tf.Tensor. Derived classes should call register_tensor_wrapper_class to register themselves with the TensorFlow type system.

Access to any attributes, properties and methods not defined on the wrapper is transparently proxied to the wrapped tensor. TensorWrapper instances can also be used directly in mathematical expressions and with most TensorFlow arithmetic functions, e.g., TensorWrapper(...) + tf.exp(TensorWrapper(...)).
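The proxying described above can be illustrated with a minimal pure-Python sketch. `ToyWrapper` and `_unwrap` are hypothetical names, and a built-in `complex` number stands in for the wrapped tf.Tensor so the example runs without TensorFlow; it is not tfsnippet's actual implementation. One detail worth noting: `__getattr__` is consulted only when normal attribute lookup fails, and Python looks up operator methods such as `__add__` on the class rather than through `__getattr__`, so arithmetic has to be forwarded explicitly.

```python
class ToyWrapper(object):
    """Hypothetical stand-in for TensorWrapper, wrapping any Python object."""

    def __init__(self, wrapped):
        self._self_wrapped = wrapped

    @property
    def tensor(self):
        # The wrapped object (a tf.Tensor in the real class).
        return self._self_wrapped

    def __getattr__(self, name):
        # Called only for names not found on the wrapper itself.
        if name.startswith('_self_'):
            # Avoid infinite recursion if wrapper state is missing.
            raise AttributeError(name)
        return getattr(self.tensor, name)

    # Operators bypass __getattr__, so forward them explicitly.
    def __add__(self, other):
        return self.tensor + _unwrap(other)

    def __radd__(self, other):
        return _unwrap(other) + self.tensor

    def __mul__(self, other):
        return self.tensor * _unwrap(other)

    def __rmul__(self, other):
        return _unwrap(other) * self.tensor


def _unwrap(value):
    """Return the wrapped object if `value` is a ToyWrapper, else `value`."""
    return value.tensor if isinstance(value, ToyWrapper) else value


w = ToyWrapper(3 + 4j)
print(w.real, w.imag)   # proxied attributes: 3.0 4.0
print(w.conjugate())    # proxied method: (3-4j)
print(w + 1, 2 * w)     # forwarded operators: (4+4j) (6+8j)
```

Because `ToyWrapper` does not subclass `complex`, `isinstance(w, complex)` is `False`, mirroring how TensorWrapper relates to tf.Tensor.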

On the other hand, TensorWrapper instances are neither tf.Tensor objects nor instances of tf.Tensor sub-classes, i.e., isinstance(TensorWrapper(...), tf.Tensor) == False. This is essential for instances of TensorWrapper sub-classes to be converted correctly to tf.Tensor by tf.convert_to_tensor(), which relies on TensorFlow's official type conversion system.
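Why the isinstance check must be `False` can be illustrated with a toy, priority-ordered conversion registry, loosely modeled on `tf.register_tensor_conversion_function`. Everything here (`Tensor`, `Wrapper`, `BadWrapper`, `register_conversion`, `convert`) is hypothetical, and the sketch assumes that objects already of the tensor type are matched first by an identity conversion; that is an assumption about TensorFlow's internals, not something stated on this page. Under that assumption, a wrapper that subclassed the tensor type would be passed through unchanged instead of being unwrapped.

```python
class Tensor(object):
    """Hypothetical stand-in for tf.Tensor."""

    def __init__(self, value):
        self.value = value


class Wrapper(object):
    """TensorWrapper-like: holds a Tensor but is NOT a Tensor subclass."""

    def __init__(self, tensor):
        self.tensor = tensor


class BadWrapper(Tensor):
    """Hypothetical wrapper that (wrongly) subclasses Tensor."""

    def __init__(self, tensor):
        super(BadWrapper, self).__init__(tensor.value)
        self.tensor = tensor


# Entries are (priority, registration_order, base_type, conversion_func);
# lower priority values are tried first.
_registry = []


def register_conversion(base_type, func, priority=100):
    _registry.append((priority, len(_registry), base_type, func))
    _registry.sort(key=lambda entry: entry[:2])


def convert(value):
    """Return a Tensor for `value` using the first matching conversion."""
    for _, _, base_type, func in _registry:
        if isinstance(value, base_type):
            return func(value)
    raise TypeError('cannot convert %r to Tensor' % (value,))


# Objects that already are Tensors pass through unchanged, checked first.
register_conversion(Tensor, lambda t: t, priority=0)
# Wrapper classes register a conversion that unwraps them.
register_conversion(Wrapper, lambda w: w.tensor)
register_conversion(BadWrapper, lambda w: w.tensor)

t = Tensor(1.0)
print(convert(Wrapper(t)) is t)      # True: unwrapped correctly
print(convert(BadWrapper(t)) is t)   # False: identity rule matched first
```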

All attributes defined in sub-classes of TensorWrapper must have names starting with _self_. Properties and methods are not subject to this rule.
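One way to see why such a naming rule is useful: a proxy that forwards attribute access must also decide where assignments go. The following sketch (hypothetical `ProxySketch` and `Inner` classes, not tfsnippet's code) assumes `__setattr__` keeps `_self_`-prefixed names on the wrapper and forwards every other assignment to the wrapped object. Properties and methods are defined on the class, so reading them is resolved before `__getattr__` is ever consulted.

```python
class Inner(object):
    """Hypothetical wrapped object."""


class ProxySketch(object):
    """Hypothetical proxy enforcing the `_self_` naming rule."""

    def __init__(self, wrapped):
        self._self_wrapped = wrapped  # routed to the proxy by __setattr__

    def __getattr__(self, name):
        # Reached only when normal lookup on the proxy fails.
        if name.startswith('_self_'):
            raise AttributeError(name)
        return getattr(self._self_wrapped, name)

    def __setattr__(self, name, value):
        if name.startswith('_self_'):
            # Proxy-owned state lives on the proxy itself.
            object.__setattr__(self, name, value)
        else:
            # Any other assignment is forwarded to the wrapped object.
            setattr(self._self_wrapped, name, value)


inner = Inner()
p = ProxySketch(inner)
p._self_flag = 123   # stored on the proxy
p.label = 'x'        # stored on the wrapped object
print(sorted(vars(p)))   # ['_self_flag', '_self_wrapped']
print(vars(inner))       # {'label': 'x'}
```

In this sketch, an assignment such as `self.flag = flag` inside a sub-class constructor would silently land on the wrapped object instead of the wrapper.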

An example of inheriting from TensorWrapper is shown below:

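import tensorflow as tf
from tfsnippet.utils import TensorWrapper, register_tensor_wrapper_class


class MyTensorWrapper(TensorWrapper):

    def __init__(self, wrapped, flag):
        super(MyTensorWrapper, self).__init__()
        # Attributes must use the `_self_` prefix.
        self._self_wrapped = wrapped
        self._self_flag = flag

    @property
    def tensor(self):
        # Required override: return the wrapped tf.Tensor.
        return self._self_wrapped

    @property
    def flag(self):
        return self._self_flag

# Register the class with the TensorFlow type system.
register_tensor_wrapper_class(MyTensorWrapper)

# tests
t = MyTensorWrapper(tf.constant(0., dtype=tf.float32), flag=123)
assert t.dtype == tf.float32  # `dtype` is proxied to the wrapped tensor
assert t.flag == 123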

Attributes Summary

tensor: Get the wrapped tf.Tensor.

Attributes Documentation

tensor

Get the wrapped tf.Tensor. Derived classes must override this to return the actual wrapped tensor.

Returns: The wrapped tensor.
Return type: tf.Tensor