get_batch_size

tfsnippet.utils.get_batch_size(tensor, axis=0, name=None)

Infer the mini-batch size of tensor along the given axis.

Parameters:
  • tensor (tf.Tensor) – The input tensor (for example, a placeholder).
  • axis (int) – The axis of mini-batches. Default is 0.
  • name (str) – Default name of the name scope. If not specified, one is generated from the method name.
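Returns:
  The batch size along axis.
Return type:
  int or tf.Tensor (an int when the size along axis is known from the static shape, otherwise a tf.Tensor evaluated at run time)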
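The "int or tf.Tensor" return type points to a static-first rule: use the size recorded in the tensor's static shape when it is known, and otherwise fall back to a dynamic lookup such as tf.shape(tensor)[axis]. The pure-Python sketch below illustrates that rule under stated assumptions. It imports neither TensorFlow nor tfsnippet, represents the static shape as a list with None for unknown dimensions (or None for unknown rank), and uses a hypothetical callable dynamic_size in place of tf.shape. The names infer_batch_size and dynamic_size are illustrative only; this is not the library's implementation.

```python
from typing import Callable, Optional, Sequence, Union

# Assumed stand-in for a TensorFlow static shape: unknown dimensions are None,
# and an entirely unknown rank is represented by None instead of a list.
StaticShape = Optional[Sequence[Optional[int]]]


def infer_batch_size(
    static_shape: StaticShape,
    dynamic_size: Callable[[int], object],
    axis: int = 0,
) -> Union[int, object]:
    """Hypothetical sketch: prefer the static size along `axis`, else go dynamic.

    `dynamic_size(axis)` plays the role of `tf.shape(tensor)[axis]` and is only
    called when the static size is unavailable.
    """
    batch_size = None
    if static_shape is not None:  # the rank is known, so index the static shape
        batch_size = static_shape[axis]
    if batch_size is None:  # size (or rank) unknown until run time
        batch_size = dynamic_size(axis)
    return batch_size


# Static size known: a plain int comes back.
print(infer_batch_size([32, 784], lambda a: f"tf.shape(x)[{a}]"))
# Batch dimension unknown: the dynamic stand-in is used instead.
print(infer_batch_size([None, 784], lambda a: f"tf.shape(x)[{a}]"))
```

Returning a plain int whenever possible lets callers use the batch size in ordinary Python arithmetic and static shape checks, while a placeholder built with an unknown leading dimension still yields a usable symbolic value.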