我们从 Python 开源项目中提取了以下 48 个代码示例，用于说明如何使用 `tensorflow.python.ops.array_ops.rank()`。
def _transpose_batch_time(x):
    """Transpose the batch and time dimensions of a Tensor.

    Retains as much of the static shape information as possible.

    Args:
        x: A tensor of rank 2 or higher.

    Returns:
        x transposed along the first two dimensions.

    Raises:
        ValueError: if `x` is rank 1 or lower.
    """
    x_static_shape = x.get_shape()
    if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
        raise ValueError(
            "Expected input tensor %s to have rank at least 2, but saw shape: %s"
            % (x, x_static_shape))
    x_rank = array_ops.rank(x)
    # Permutation [1, 0, 2, ..., rank-1]: swap the two leading axes, keep the
    # rest. Built from the dynamic rank so unknown-rank inputs still work.
    x_t = array_ops.transpose(
        x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0))
    # Indexing a TensorShape yields a `Dimension` (TF1 behavior) or an
    # int/None (TF2 behavior). TensorShape accepts either, so `.value` is not
    # needed (and would fail on a plain int under TF2 behavior).
    x_t.set_shape(
        tensor_shape.TensorShape([x_static_shape[1], x_static_shape[0]])
        .concatenate(x_static_shape[2:]))
    return x_t
def get_ndims(self, x, name="get_ndims"):
    """Return the number of dimensions (rank) of `x` as an int32 `Tensor`.

    When the static rank of `x` is known it is returned as a constant;
    otherwise a dynamic `rank` op is emitted.

    Args:
        x: `Tensor`.
        name: `String`. The name to give this op.

    Returns:
        ndims: Scalar number of dimensions associated with a `Tensor`.
    """
    with self._name_scope(name, values=[x]):
        x = ops.convert_to_tensor(x, name="x")
        static_ndims = x.get_shape().ndims
        if static_ndims is not None:
            return ops.convert_to_tensor(
                static_ndims, dtype=dtypes.int32, name="ndims")
        return array_ops.rank(x, name="ndims")
def rank(self, name="rank"):
    """Tensor rank. Equivalent to `tf.rank(A)`. Will equal `n + 2`.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nn, k, k]`, the `rank` is `n + 2`.

    Args:
        name: A name scope to use for ops added by this method.

    Returns:
        `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented:
    # the rank is the number of entries in the shape vector.
    with ops.name_scope(self.name), ops.name_scope(name, values=self.inputs):
        full_shape = self.shape()
        return array_ops.size(full_shape)
def batch_shape(self, name="batch_shape"):
    """Shape of batches associated with this operator.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nn, k, k]`, the `batch_shape` is `[N1,...,Nn]`.

    Args:
        name: A name scope to use for ops added by this method.

    Returns:
        `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented:
    # take every entry of the shape except the trailing [k, k].
    with ops.name_scope(self.name), ops.name_scope(name, values=self.inputs):
        full_shape = self.shape()
        num_batch_dims = self.rank() - 2
        return array_ops.slice(full_shape, [0], [num_batch_dims])
def vector_space_dimension(self, name="vector_space_dimension"):
    """Dimension of vector space on which this acts.

    The `k` in `R^k`. If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nn, k, k]`, the `vector_space_dimension` is `k`.

    Args:
        name: A name scope to use for ops added by this method.

    Returns:
        `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented:
    # `k` is the last entry of the shape vector.
    with ops.name_scope(self.name), ops.name_scope(name, values=self.inputs):
        full_shape = self.shape()
        last_index = self.rank() - 1
        return array_ops.gather(full_shape, last_index)
def extract_batch_shape(x, num_event_dims, name="extract_batch_shape"):
    """Extract the batch shape from `x`.

    Assuming `x.shape = batch_shape + event_shape`, when `event_shape` has
    `num_event_dims` dimensions. This `Op` returns the batch shape `Tensor`.

    Args:
        x: `Tensor` with rank at least `num_event_dims`. If rank is not high
            enough this `Op` will fail.
        num_event_dims: `int32` scalar `Tensor`. The number of trailing
            dimensions in `x` to be considered as part of `event_shape`.
        name: A name to prepend to created `Ops`.

    Returns:
        batch_shape: `1-D` `int32` `Tensor`
    """
    with ops.name_scope(name, values=[x]):
        x = ops.convert_to_tensor(x, name="x")
        x_shape = array_ops.shape(x)
        # Keep the leading rank(x) - num_event_dims entries of the shape.
        num_batch_dims = array_ops.rank(x) - num_event_dims
        return array_ops.slice(x_shape, [0], [num_batch_dims])
def _get_identity_operator(self, v):
    """Get an `OperatorPDIdentity` to play the role of `D` in `VDV^T`.

    The identity has shape `batch_shape(v) + [r, r]`, where `r` is the last
    dimension of `v` and `batch_shape(v)` is everything but the last two.
    """
    with ops.name_scope("get_identity_operator", values=[v]):
        static_shape = v.get_shape()
        if not static_shape.is_fully_defined():
            # Shape only known at runtime: assemble [batch..., r, r] with ops.
            dyn_shape = array_ops.shape(v)
            dyn_rank = array_ops.rank(v)
            batch_dims = array_ops.slice(dyn_shape, [0], [dyn_rank - 2])
            last_dim = array_ops.gather(dyn_shape, dyn_rank - 1)  # Last dim of v
            id_shape = array_ops.concat(0, (batch_dims, [last_dim, last_dim]))
        else:
            dims = static_shape.as_list()
            id_shape = dims[:-2] + [dims[-1], dims[-1]]
        return operator_pd_identity.OperatorPDIdentity(
            id_shape, v.dtype, verify_pd=self._verify_pd)
def _check_chol(self, chol):
    """Verify that `chol` is proper.

    When `self.verify_pd` is False, `chol` is returned (as a Tensor)
    unchecked. Otherwise it is returned with runtime dependencies asserting
    rank >= 2, equal last two dimensions, and a strictly positive diagonal.
    """
    chol = ops.convert_to_tensor(chol, name="chol")
    if not self.verify_pd:
        return chol
    chol_shape = array_ops.shape(chol)
    chol_rank = array_ops.rank(chol)
    deps = [
        check_ops.assert_rank_at_least(chol, 2),
        # Square: the last two dimensions must agree.
        check_ops.assert_equal(
            array_ops.gather(chol_shape, chol_rank - 2),
            array_ops.gather(chol_shape, chol_rank - 1)),
    ]
    diag = array_ops.matrix_diag_part(chol)
    deps.append(check_ops.assert_positive(diag))
    return control_flow_ops.with_dependencies(deps, chol)
def _sample_n(self, n, seed=None):
    """Draw `n` samples.

    Standard-normal noise of shape `vector_shape + [n]` is correlated through
    `self._cov.sqrt_matmul`, the trailing axis is moved to the front, and
    `self.mu` is added.
    """
    # Recall _assert_valid_mu ensures mu and self._cov have same batch shape.
    noise_shape = array_ops.concat(0, [self._cov.vector_shape(), [n]])
    white_samples = random_ops.random_normal(
        shape=noise_shape, mean=0, stddev=1, dtype=self.dtype, seed=seed)
    correlated_samples = self._cov.sqrt_matmul(white_samples)

    # Move the last dimension to the front: perm = [rank-1, 0, ..., rank-2].
    front_axis = array_ops.pack([array_ops.rank(correlated_samples) - 1])
    rest_axes = math_ops.range(0, array_ops.rank(correlated_samples) - 1)
    perm = array_ops.concat(0, (front_axis, rest_axes))

    # TODO(ebrevdo): Once we get a proper tensor contraction op,
    # perform the inner product using that instead of batch_matmul
    # and this slow transpose can go away!
    correlated_samples = array_ops.transpose(correlated_samples, perm)
    return correlated_samples + self.mu
def __init__(self, label_name, weight_column_name, enable_centered_bias,
             head_name, thresholds):
    """Binary SVM head: hinge loss on labels of shape [batch_size(, 1)]."""

    def _hinge_loss_fn(logits, labels):
        # Runtime check that labels have rank <= 2, then reshape them to
        # [batch_size, 1] before computing the hinge loss.
        rank_ok = math_ops.less_equal(array_ops.rank(labels), 2)
        check_shape_op = control_flow_ops.Assert(
            rank_ok,
            ["labels shape should be either [batch_size, 1] or [batch_size]"])
        with ops.control_dependencies([check_shape_op]):
            batch_size = array_ops.shape(labels)[0]
            labels = array_ops.reshape(labels, shape=[batch_size, 1])
        return losses.hinge_loss(logits, labels)

    super(_BinarySvmHead, self).__init__(
        train_loss_fn=_hinge_loss_fn,
        eval_loss_fn=_hinge_loss_fn,
        n_classes=2,
        label_name=label_name,
        weight_column_name=weight_column_name,
        enable_centered_bias=enable_centered_bias,
        head_name=head_name,
        thresholds=thresholds)
def _forward(self, x):
    """Append a zero to the last dimension of `x`, then apply softmax."""
    # Pad the last dim with a zeros vector. We need this because it lets us
    # infer the scale in the inverse function.
    if self._static_event_ndims == 0:
        y = array_ops.expand_dims(x, dim=-1)
    else:
        y = x
    ndims = y.get_shape().ndims
    if ndims is None:
        ndims = array_ops.rank(y)
    # No padding on any axis except one trailing element on the last axis.
    paddings = array_ops.concat(0, (
        array_ops.zeros((ndims - 1, 2), dtype=dtypes.int32),
        [[0, 1]]))
    y = array_ops.pad(y, paddings=paddings)

    # Set shape hints.
    if x.get_shape().ndims is not None:
        dims = x.get_shape().as_list()
        if self._static_event_ndims == 0:
            dims.append(2)
        elif dims[-1] is not None:
            dims[-1] += 1
        hint = tensor_shape.TensorShape(dims)
        y.get_shape().assert_is_compatible_with(hint)
        y.set_shape(hint)

    # Since we only support event_ndims in [0, 1] and we do padding, we always
    # reduce over the last dimension, i.e., dim=-1 (which is the default).
    return nn_ops.softmax(y)
def _sample_n(self, n, seed=None):
    """Draw `n` samples.

    Standard-normal noise of shape `vector_shape + [n]` is correlated through
    `self._cov.sqrt_matmul`, the trailing axis is moved to the front, and
    `self.mu` is added.
    """
    # Recall _assert_valid_mu ensures mu and self._cov have same batch shape.
    noise_shape = array_ops.concat(0, [self._cov.vector_shape(), [n]])
    white_samples = random_ops.random_normal(
        shape=noise_shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
    correlated_samples = self._cov.sqrt_matmul(white_samples)

    # Move the last dimension to the front: perm = [rank-1, 0, ..., rank-2].
    front_axis = array_ops.pack([array_ops.rank(correlated_samples) - 1])
    rest_axes = math_ops.range(0, array_ops.rank(correlated_samples) - 1)
    perm = array_ops.concat(0, (front_axis, rest_axes))

    # TODO(ebrevdo): Once we get a proper tensor contraction op,
    # perform the inner product using that instead of batch_matmul
    # and this slow transpose can go away!
    correlated_samples = array_ops.transpose(correlated_samples, perm)
    return correlated_samples + self.mu
def transpose_batch_time(x):
    """Transpose the batch and time dimensions of a Tensor.

    Retains as much of the static shape information as possible.

    Args:
        x: A tensor of rank 2 or higher.

    Returns:
        x transposed along the first two dimensions.

    Raises:
        ValueError: if `x` is rank 1 or lower.
    """
    x_static_shape = x.get_shape()
    if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
        raise ValueError(
            "Expected input tensor %s to have rank at least 2, but saw shape: %s"
            % (x, x_static_shape))
    x_rank = array_ops.rank(x)
    # Permutation [1, 0, 2, ..., rank-1]: swap the two leading axes only.
    x_t = array_ops.transpose(
        x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0))
    # Use the public `tf.TensorShape`. Indexing a TensorShape yields a
    # `Dimension` (TF1 behavior) or an int/None (TF2 behavior); TensorShape
    # accepts either, whereas `.value` fails on a plain int.
    x_t.set_shape(
        tf.TensorShape([x_static_shape[1], x_static_shape[0]])
        .concatenate(x_static_shape[2:]))
    return x_t
def batch_shape(self, name="batch_shape"):
    """Shape of batches associated with this operator.

    If this operator represents the batch matrix `A` with
    `A.shape = [N1,...,Nn, k, k]`, the `batch_shape` is `[N1,...,Nn]`.

    Args:
        name: A name scope to use for ops added by this method.

    Returns:
        `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented:
    # take every entry of the shape except the trailing [k, k].
    with ops.name_scope(self.name), ops.name_scope(name, values=self.inputs):
        full_shape = self.shape()
        num_batch_dims = self.rank() - 2
        return array_ops.strided_slice(full_shape, [0], [num_batch_dims])
def extract_batch_shape(x, num_event_dims, name="extract_batch_shape"):
    """Extract the batch shape from `x`.

    Assuming `x.shape = batch_shape + event_shape`, when `event_shape` has
    `num_event_dims` dimensions. This `Op` returns the batch shape `Tensor`.

    Args:
        x: `Tensor` with rank at least `num_event_dims`. If rank is not high
            enough this `Op` will fail.
        num_event_dims: `int32` scalar `Tensor`. The number of trailing
            dimensions in `x` to be considered as part of `event_shape`.
        name: A name to prepend to created `Ops`.

    Returns:
        batch_shape: `1-D` `int32` `Tensor`
    """
    with ops.name_scope(name, values=[x]):
        x = ops.convert_to_tensor(x, name="x")
        x_shape = array_ops.shape(x)
        # Keep the leading rank(x) - num_event_dims entries of the shape.
        batch_end = array_ops.rank(x) - num_event_dims
        return array_ops.strided_slice(x_shape, [0], [batch_end])
def _get_identity_operator(self, v):
    """Get an `OperatorPDIdentity` to play the role of `D` in `VDV^T`.

    The identity has shape `batch_shape(v) + [r, r]`, where `r` is the last
    dimension of `v` and `batch_shape(v)` is everything but the last two.
    """
    with ops.name_scope("get_identity_operator", values=[v]):
        static_shape = v.get_shape()
        if not static_shape.is_fully_defined():
            # Shape only known at runtime: assemble [batch..., r, r] with ops.
            dyn_shape = array_ops.shape(v)
            dyn_rank = array_ops.rank(v)
            batch_dims = array_ops.strided_slice(dyn_shape, [0], [dyn_rank - 2])
            last_dim = array_ops.gather(dyn_shape, dyn_rank - 1)  # Last dim of v
            id_shape = array_ops.concat((batch_dims, [last_dim, last_dim]), 0)
        else:
            dims = static_shape.as_list()
            id_shape = dims[:-2] + [dims[-1], dims[-1]]
        return operator_pd_identity.OperatorPDIdentity(
            id_shape, v.dtype, verify_pd=self._verify_pd)
def _process_matrix(self, matrix, min_rank, event_ndims):
    """Helper to __init__ which gets matrix in batch-ready form.

    Side effect: sets `self._rank_two_event_ndims_one` to a boolean `Tensor`
    that is True when `rank(matrix) == min_rank` and `event_ndims == 1`. In
    that case a leading dimension of size 1 is prepended to `matrix`;
    otherwise it is reshaped to its own shape.
    """
    # Pad the matrix so that matmul works in the case of a matrix and vector
    # input. Keep track if the matrix was padded, to distinguish between a
    # rank 3 tensor and a padded rank 2 tensor.
    # TODO(srvasude): Remove side-effects from functions. Its currently unbroken
    # but error-prone since the function call order may change in the future.
    at_min_rank = math_ops.equal(array_ops.rank(matrix), min_rank)
    is_vector_event = math_ops.equal(event_ndims, 1)
    self._rank_two_event_ndims_one = math_ops.logical_and(
        at_min_rank, is_vector_event)
    # Number of leading 1s to prepend (1 or 0).
    num_leading_ones = array_ops.where(self._rank_two_event_ndims_one, 1, 0)
    leading_ones = array_ops.ones([num_leading_ones], dtype=dtypes.int32)
    padded_shape = array_ops.concat(
        [leading_ones, array_ops.shape(matrix)], 0)
    return array_ops.reshape(matrix, padded_shape)
def _sample_n(self, n, seed=None):
    """Draw `n` samples.

    Standard-normal noise of shape `vector_shape + [n]` is correlated through
    `self._cov.sqrt_matmul`, the trailing axis is moved to the front, and
    `self.mu` is added.
    """
    # Recall _assert_valid_mu ensures mu and self._cov have same batch shape.
    noise_shape = array_ops.concat([self._cov.vector_shape(), [n]], 0)
    white_samples = random_ops.random_normal(
        shape=noise_shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
    correlated_samples = self._cov.sqrt_matmul(white_samples)

    # Move the last dimension to the front: perm = [rank-1, 0, ..., rank-2].
    front_axis = array_ops.stack([array_ops.rank(correlated_samples) - 1])
    rest_axes = math_ops.range(0, array_ops.rank(correlated_samples) - 1)
    perm = array_ops.concat((front_axis, rest_axes), 0)

    # TODO(ebrevdo): Once we get a proper tensor contraction op,
    # perform the inner product using that instead of batch_matmul
    # and this slow transpose can go away!
    correlated_samples = array_ops.transpose(correlated_samples, perm)
    return correlated_samples + self.mu
def _is_rank(expected_rank, actual_tensor):
    """Returns whether actual_tensor's rank is expected_rank.

    Args:
        expected_rank: Integer defining the expected rank, or tensor of same.
        actual_tensor: Tensor to test.

    Returns:
        New boolean tensor, named after the enclosing 'is_rank' scope.
    """
    with ops.name_scope('is_rank', values=[actual_tensor]) as scope:
        want = ops.convert_to_tensor(expected_rank, name='expected')
        got = array_ops.rank(actual_tensor, name='actual')
        return math_ops.equal(want, got, name=scope)
def _check_rank(value, expected_rank):
    """Check the rank of Tensor `value`, via shape inference and assertions.

    Args:
        value: A Tensor, possibly with shape associated shape information.
        expected_rank: int32 scalar (optionally a `Tensor`).

    Returns:
        new_value: A Tensor matching `value`. Accessing this tensor tests
            assertions on its rank. If expected_rank is not a `Tensor`, then
            new_value's shape's rank has been set.

    Raises:
        ValueError: if `expected_rank` is not a `Tensor` and the rank of `value`
            is known and is not equal to `expected_rank`.
    """
    assert isinstance(value, ops.Tensor)
    rank_matches = math_ops.equal(expected_rank, array_ops.rank(value))
    error_data = [
        string_ops.string_join([
            "Rank of tensor %s should be: " % value.name,
            string_ops.as_string(expected_rank),
            ", shape received:"]),
        array_ops.shape(value),
    ]
    rank_assert = control_flow_ops.Assert(rank_matches, error_data)
    with ops.control_dependencies([rank_assert]):
        new_value = array_ops.identity(value, name="rank_checked")

    # A Tensor expected_rank with a constant value can still be used
    # statically, so fold it to a Python int when possible.
    if isinstance(expected_rank, ops.Tensor):
        static_rank = tensor_util.constant_value(expected_rank)
        if static_rank is not None:
            expected_rank = int(static_rank)

    if isinstance(expected_rank, ops.Tensor):
        # Rank only known at runtime; rely on the assertion alone.
        return new_value

    try:
        new_value.set_shape(new_value.get_shape().with_rank(expected_rank))
    except ValueError as e:
        raise ValueError("Rank check failed for %s: %s" % (value.name, str(e)))
    return new_value
def _log_loss_with_two_classes(logits, target):
    """Sigmoid cross-entropy between `logits` and binary `target`.

    `target` is checked at runtime to have rank <= 2, reshaped to
    [batch_size, 1], and cast to float.
    """
    rank_ok = math_ops.less_equal(array_ops.rank(target), 2)
    check_shape_op = control_flow_ops.Assert(
        rank_ok,
        ["target's shape should be either [batch_size, 1] or [batch_size]"])
    with ops.control_dependencies([check_shape_op]):
        batch_size = array_ops.shape(target)[0]
        target = array_ops.reshape(target, shape=[batch_size, 1])
    float_target = math_ops.to_float(target)
    return nn.sigmoid_cross_entropy_with_logits(logits, float_target)
def _softmax_cross_entropy_loss(logits, target):
    """Sparse softmax cross-entropy between `logits` and class-id `target`.

    `target` is checked at runtime to have rank <= 2 and is flattened to
    shape [batch_size].
    """
    rank_ok = math_ops.less_equal(array_ops.rank(target), 2)
    check_shape_op = control_flow_ops.Assert(
        rank_ok,
        ["target's shape should be either [batch_size, 1] or [batch_size]"])
    with ops.control_dependencies([check_shape_op]):
        batch_size = array_ops.shape(target)[0]
        target = array_ops.reshape(target, shape=[batch_size])
    return nn.sparse_softmax_cross_entropy_with_logits(logits, target)
def __init__(self, label_name, weight_column_name):
    """Binary SVM target column: hinge loss on [batch_size(, 1)] targets."""

    def _hinge_loss_fn(logits, target):
        # Runtime check that target has rank <= 2, then reshape it to
        # [batch_size, 1] before computing the hinge loss.
        rank_ok = math_ops.less_equal(array_ops.rank(target), 2)
        check_shape_op = control_flow_ops.Assert(
            rank_ok,
            ["target's shape should be either [batch_size, 1] or [batch_size]"])
        with ops.control_dependencies([check_shape_op]):
            batch_size = array_ops.shape(target)[0]
            target = array_ops.reshape(target, shape=[batch_size, 1])
        return losses.hinge_loss(logits, target)

    super(_BinarySvmTargetColumn, self).__init__(
        loss_fn=_hinge_loss_fn,
        n_classes=2,
        label_name=label_name,
        weight_column_name=weight_column_name)
def _check_labels_and_scores(boolean_labels, scores, check_shape):
    """Check the rank of labels/scores, return tensor versions.

    Args:
        boolean_labels: Labels convertible to a `Tensor`; must have dtype bool.
        scores: Scores convertible to a `Tensor`.
        check_shape: Python bool. If True, runtime assertions that both inputs
            have rank 1 are attached to the returned tensors.

    Returns:
        `(boolean_labels, scores)` as `Tensor`s. When `check_shape` is True
        they are `identity` copies that only evaluate after the rank
        assertions have run.

    Raises:
        ValueError: If `boolean_labels` does not have dtype bool.
    """
    with ops.name_scope('_check_labels_and_scores',
                        values=[boolean_labels, scores]):
        boolean_labels = ops.convert_to_tensor(boolean_labels,
                                               name='boolean_labels')
        scores = ops.convert_to_tensor(scores, name='scores')

        if boolean_labels.dtype != dtypes.bool:
            # Format the message here; passing the dtype as a second
            # ValueError argument would leave a literal '%s' in the text.
            raise ValueError(
                'Argument boolean_labels should have dtype bool. Found: %s' %
                boolean_labels.dtype)

        if not check_shape:
            return boolean_labels, scores

        labels_rank_1 = control_flow_ops.Assert(
            math_ops.equal(1, array_ops.rank(boolean_labels)),
            ['Argument boolean_labels should have rank 1. Found: ',
             boolean_labels.name, array_ops.shape(boolean_labels)])
        scores_rank_1 = control_flow_ops.Assert(
            math_ops.equal(1, array_ops.rank(scores)),
            ['Argument scores should have rank 1. Found: ', scores.name,
             array_ops.shape(scores)])
        # control_dependencies only affects ops *created* inside the block.
        # Returning the existing tensors would silently drop the assertions,
        # so wrap them in new identity ops that depend on the checks.
        with ops.control_dependencies([labels_rank_1, scores_rank_1]):
            return (array_ops.identity(boolean_labels),
                    array_ops.identity(scores))
def _strict_conv1d(x, h):
    """Return x * h for rank 1 tensors x and h.

    Implemented with a stride-1, 'SAME'-padded `conv2d`, so the result is a
    rank 1 tensor with the same length as `x`.

    NOTE(review): `conv2d` computes a cross-correlation, so this matches a
    true convolution only when `h` is symmetric; confirm callers rely on that.
    """
    with ops.name_scope('strict_conv1d', values=[x, h]):
        # Input as NHWC: batch 1, height len(x), width 1, one channel.
        signal = array_ops.reshape(x, (1, -1, 1, 1))
        # Filter as [height, width, in_channels, out_channels].
        kernel = array_ops.reshape(h, (-1, 1, 1, 1))
        conv = nn_ops.conv2d(signal, kernel, [1, 1, 1, 1], 'SAME')
        return array_ops.reshape(conv, [-1])
def _dispatch_based_on_batch(self, batch_method, singleton_method, **args):
    """Helper to automatically call batch or singleton operation.

    `batch_method` is used when the rank exceeds 2, otherwise
    `singleton_method`. With a known static rank the choice is made in
    Python; with an unknown rank it is deferred to a runtime `cond`.
    """
    static_ndims = self.get_shape().ndims
    if static_ndims is None:
        return control_flow_ops.cond(
            self.rank() > 2,
            lambda: batch_method(**args),
            lambda: singleton_method(**args))
    chosen = batch_method if static_ndims > 2 else singleton_method
    return chosen(**args)
def _flip_matrix_to_vector_dynamic(mat, batch_shape):
    """Flip matrix to vector with dynamic shapes.

    With `mat.shape = matrix_batch_shape + [k, M]`, the last axis is moved to
    the front and the result is reshaped to `batch_shape + [k]`.
    """
    ndims = array_ops.rank(mat)
    k = array_ops.gather(array_ops.shape(mat), ndims - 2)
    final_shape = array_ops.concat(0, (batch_shape, [k]))
    # Permutation corresponding to [M] + matrix_batch_shape + [k].
    last_to_front = array_ops.concat(
        0, ([ndims - 1], math_ops.range(0, ndims - 1)))
    flipped = array_ops.transpose(mat, perm=last_to_front)
    return array_ops.reshape(flipped, final_shape)
def _flip_vector_to_matrix_dynamic(vec, batch_shape):
    """flip_vector_to_matrix with dynamic shapes.

    Per the permutation below, `vec` has shape `[M1,...,Mm] + [N1,...,Nn] + [k]`
    with `batch_shape = [N1,...,Nn]`. The result has shape
    `batch_shape + [k, M1*...*Mm]`, where the last entry is 1 when `m == 0`.
    """
    n_batch = array_ops.size(batch_shape)
    vec = ops.convert_to_tensor(vec, name="vec")
    vec_shape = array_ops.shape(vec)
    vec_rank = array_ops.rank(vec)
    # m = number of leading dims of vec beyond batch_shape + [k].
    m = (vec_rank - 1) - n_batch
    # [M1,...,Mm], or [] when m == 0.
    leading_dims = array_ops.slice(vec_shape, [0], [m])
    # reduce_prod([]) == 1, so this is [1] when there are no leading dims.
    condensed_shape = [math_ops.reduce_prod(leading_dims)]
    k = array_ops.gather(vec_shape, vec_rank - 1)
    new_shape = array_ops.concat(0, (batch_shape, [k], condensed_shape))

    def _move_leading_dims_last():
        # Permutation corresponding to [N1,...,Nn] + [k, M1,...,Mm].
        perm = array_ops.concat(
            0, (math_ops.range(m, vec_rank), math_ops.range(0, m)))
        return array_ops.transpose(vec, perm=perm)

    flipped = control_flow_ops.cond(
        math_ops.less(0, m),
        _move_leading_dims_last,
        lambda: array_ops.expand_dims(vec, -1))
    return array_ops.reshape(flipped, new_shape)
def _sqrt_solve(self, rhs):
    """Return `S^{-1} rhs`, where `S = M + V D V^T` is this operator's sqrt.

    Uses the Woodbury formula:

        (M + VDV^T)^{-1}
          = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
          = M^{-1} - M^{-1} V C^{-1} V^T M^{-1}

    where C is the capacitance matrix, solved via its Cholesky factor.
    """
    # TODO(jvdillon) Determine if recursively applying rank-1 updates is more
    # efficient. May not be possible because a general n x n matrix can be
    # represented as n rank-1 updates, and solving with this matrix is always
    # done in O(n^3) time.
    base_op = self._operator
    factor = self._v
    capacitance_chol = self._chol_capacitance(batch_mode=False)
    # The operators will use batch/singleton mode automatically. We don't
    # override.

    # M^{-1} rhs
    base_solved = base_op.solve(rhs)
    # V^T M^{-1} rhs
    projected = math_ops.matmul(factor, base_solved, transpose_a=True)
    # C^{-1} V^T M^{-1} rhs
    cap_solved = linalg_ops.cholesky_solve(capacitance_chol, projected)
    # M^{-1} V C^{-1} V^T M^{-1} rhs
    correction = base_op.solve(math_ops.matmul(factor, cap_solved))
    return base_solved - correction
def _expand_sample_shape(self, sample_shape):
    """Helper to `sample` which ensures sample_shape is 1D.

    Args:
        sample_shape: `Tensor` of sample dimensions (scalar or vector).

    Returns:
        `(sample_shape, total)`: `sample_shape` as a 1-D tensor, and the
        product of its entries. `total` is a `Tensor` when the value is not
        statically known, otherwise a numpy int32 scalar.

    Raises:
        ValueError: if `sample_shape` has a constant value but unknown rank.
    """
    # `as_numpy_dtype` is a property that already returns the numpy type
    # (np.int32). Calling it, as the old code did, instantiated a scalar
    # that only worked as a dtype by accident.
    np_int32 = dtypes.int32.as_numpy_dtype
    sample_shape_static_val = tensor_util.constant_value(sample_shape)
    ndims = sample_shape.get_shape().ndims
    if sample_shape_static_val is None:
        if ndims is None or not sample_shape.get_shape().is_fully_defined():
            ndims = array_ops.rank(sample_shape)
        # Presumably picks [1] when sample_shape is a scalar, else its own
        # shape, so the reshape below always yields a vector.
        expanded_shape = distribution_util.pick_vector(
            math_ops.equal(ndims, 0),
            np.array((1,), dtype=np_int32),
            array_ops.shape(sample_shape))
        sample_shape = array_ops.reshape(sample_shape, expanded_shape)
        total = math_ops.reduce_prod(sample_shape)  # reduce_prod([]) == 1
    else:
        if ndims is None:
            raise ValueError(
                "Shouldn't be here; ndims cannot be none when we have a "
                "tf.constant shape.")
        if ndims == 0:
            sample_shape_static_val = np.reshape(sample_shape_static_val, [1])
        sample_shape = ops.convert_to_tensor(
            sample_shape_static_val, dtype=dtypes.int32, name="sample_shape")
        total = np.prod(sample_shape_static_val, dtype=np_int32)
    return sample_shape, total
def _log_prob(self, k):
    """Log-probability of the integer outcomes `k` under `self.logits`."""
    k = ops.convert_to_tensor(k, name="k")
    # Broadcast logits against k (with a trailing axis appended to k).
    ones_for_k = array_ops.ones_like(
        array_ops.expand_dims(k, -1), dtype=self.logits.dtype)
    logits = self.logits * ones_for_k
    # Broadcast k to every dimension of logits except the last.
    leading_shape = array_ops.slice(
        array_ops.shape(logits), [0], [array_ops.rank(logits) - 1])
    k *= array_ops.ones(leading_shape, dtype=k.dtype)
    k.set_shape(tensor_shape.TensorShape(logits.get_shape()[:-1]))
    return -nn_ops.sparse_softmax_cross_entropy_with_logits(logits, k)
def _event_shape(self):
    """Return the last dimension of `self._mean_val` as a 1-element tensor."""
    mean = self._mean_val
    mean_shape = array_ops.shape(mean)
    last_axis = array_ops.rank(mean) - 1
    return array_ops.gather(mean_shape, [last_axis])
def _event_shape(self):
    """Return the last dimension of `self.alpha` as a 1-element tensor."""
    alpha = self.alpha
    alpha_shape = array_ops.shape(alpha)
    last_axis = array_ops.rank(alpha) - 1
    return array_ops.gather(alpha_shape, [last_axis])
def _assert_valid_mu(self, mu):
    """Return `mu` after validity checks, possibly with assertions attached.

    Raises:
        TypeError: if `mu` and the covariance have different dtypes.
        ValueError: if both shapes are fully known and `mu.shape` differs from
            `cov.shape[:-1]`.
    """
    cov = self._cov
    if mu.dtype != cov.dtype:
        raise TypeError(
            "mu and cov must have the same dtype. Found mu.dtype = %s, "
            "cov.dtype = %s" % (mu.dtype, cov.dtype))

    # Try to validate with static checks.
    mu_shape = mu.get_shape()
    cov_shape = cov.get_shape()
    both_static = (mu_shape.is_fully_defined()
                   and cov_shape.is_fully_defined())
    if both_static:
        if mu_shape != cov_shape[:-1]:
            raise ValueError(
                "mu.shape and cov.shape[:-1] should match. Found: mu.shape=%s, "
                "cov.shape=%s" % (mu_shape, cov_shape))
        return mu

    # Static checks could not be run, so possibly do dynamic checks.
    if not self.validate_args:
        return mu

    assert_same_rank = check_ops.assert_equal(
        array_ops.rank(mu) + 1,
        cov.rank(),
        data=["mu should have rank 1 less than cov. Found: rank(mu) = ",
              array_ops.rank(mu), " rank(cov) = ", cov.rank()],
    )
    with ops.control_dependencies([assert_same_rank]):
        assert_same_shape = check_ops.assert_equal(
            array_ops.shape(mu),
            cov.vector_shape(),
            data=["mu.shape and cov.shape[:-1] should match. "
                  "Found: shape(mu) = ",
                  array_ops.shape(mu), " shape(cov) = ", cov.shape()],
        )
        return control_flow_ops.with_dependencies([assert_same_shape], mu)
def flip_matrix_to_vector(mat, batch_shape, static_batch_shape):
    """Flip dims to reshape batch matrix `mat` to a vector with given batch shape.

    ```python
    mat = tf.random_normal([2, 3, 4, 6])

    # Flip the trailing dimension around to the front.
    flip_matrix_to_vector(mat, [6, 2, 3], [6, 2, 3])  # Shape [6, 2, 3, 4]

    # Flip the trailing dimension around then reshape batch indices to batch_shape
    flip_matrix_to_vector(mat, [6, 3, 2], [6, 3, 2])  # Shape [6, 3, 2, 4]
    flip_matrix_to_vector(mat, [2, 3, 2, 3], [2, 3, 2, 3])  # Shape [2, 3, 2, 3, 4]
    ```

    Assume `mat.shape = matrix_batch_shape + [k, M]`. The returned vector is
    generated in two steps:

    1. Flip the final dimension to the front, giving a tensor of shape
       `[M] + matrix_batch_shape + [k]`.
    2. Reshape the leading dimensions, giving final shape = `batch_shape + [k]`.

    The reshape in step 2 will fail if the number of elements is not equal,
    i.e. `M*prod(matrix_batch_shape) != prod(batch_shape)`.

    See also: flip_vector_to_matrix.

    Args:
        mat: `Tensor` with rank `>= 2`.
        batch_shape: `int32` `Tensor` giving leading "batch" shape of result.
        static_batch_shape: `TensorShape` object giving batch shape of result.

    Returns:
        `Tensor` with same elements as `mat` but with shape `batch_shape + [k]`.
    """
    mat = ops.convert_to_tensor(mat, name="mat")
    # Use the static implementation only when every shape is known up front.
    if (static_batch_shape.is_fully_defined()
            and mat.get_shape().is_fully_defined()):
        return _flip_matrix_to_vector_static(mat, static_batch_shape)
    else:
        return _flip_matrix_to_vector_dynamic(mat, batch_shape)