我们从Python开源项目中,提取了以下30个代码示例,用于说明如何使用keras.backend.switch()。
def _drop_path(self, inputs):
    """Merge the input paths into one tensor, weighted by sampled drop flags.

    The drop vector is the global pattern when ``self.is_global`` holds and
    the local pattern (sampled with ``self.p``) otherwise.  The weighted sum
    of the paths is divided by the number of kept paths.

    Args:
        inputs: sequence of path tensors to merge.

    Returns:
        The weighted average of the kept paths; the raw weighted sum when
        no path was kept.
    """
    n_paths = len(inputs)
    keep = K.switch(
        self.is_global,
        self._gen_global_path(n_paths),
        self._gen_local_drops(n_paths, self.p),
    )

    merged = K.zeros(shape=self.average_shape)
    for idx, path in enumerate(inputs):
        merged += path * keep[idx]

    kept_total = K.sum(keep)
    # A global drop pattern can leave zero kept paths; skip the division
    # in that case instead of dividing by zero.
    return K.switch(K.not_equal(kept_total, 0.), merged / kept_total, merged)
def clip_norm(g, c, n):
    """Rescale gradient ``g`` so that its norm does not exceed ``c``.

    Args:
        g: gradient tensor.
        c: maximum allowed norm; clipping is disabled unless ``c > 0``.
        n: norm the gradient is compared against.

    Returns:
        ``g`` unchanged when clipping is disabled, otherwise a tensor equal
        to ``g * c / n`` where ``n >= c`` and to ``g`` elsewhere.
    """
    if not c > 0:
        return g

    if K.backend() != 'tensorflow':
        return K.switch(n >= c, g * c / n, g)

    import tensorflow as tf
    import copy

    exceeds = n >= c
    scaled = tf.scalar_mul(c / n, g)

    # Remember the shape of the scaled gradient: the result of `cond` gets
    # this shape re-attached below.
    if hasattr(scaled, 'get_shape'):
        shape_source = 'get_shape'
        saved_shape = copy.copy(scaled.get_shape())
    elif hasattr(scaled, 'dense_shape'):
        shape_source = 'dense_shape'
        saved_shape = copy.copy(scaled.dense_shape)
    else:
        shape_source = None

    if exceeds.dtype != tf.bool:
        exceeds = tf.cast(exceeds, 'bool')

    clipped = K.tensorflow_backend.control_flow_ops.cond(
        exceeds, lambda: scaled, lambda: g)

    if shape_source == 'get_shape':
        clipped.set_shape(saved_shape)
    elif shape_source == 'dense_shape':
        clipped._dense_shape = saved_shape
    return clipped
def fscore(y_true, y_pred, average='samples', beta=2):
    """Mean F-beta score built from Keras backend ops.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor.
        average: ``'samples'`` reduces the counts over axis 1; any other
            value reduces over axis 0.
        beta: weight of recall relative to precision.

    Returns:
        Scalar tensor: the mean of the F-beta scores.
    """
    axis = 1 if average == 'samples' else 0
    eps = K.epsilon

    # Weighted counts.
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    tp_sum = K.sum(hits, axis=axis)
    pred_sum = K.sum(y_pred, axis=axis)
    true_sum = K.sum(y_true, axis=axis)

    beta2 = beta ** 2
    precision = tp_sum / (pred_sum + eps())
    recall = tp_sum / (true_sum + eps())

    # The epsilon in every denominator keeps the divisions finite, so a zero
    # true-positive count gives a score of 0 without an explicit K.switch.
    numerator = (1 + beta2) * precision * recall
    denominator = beta2 * precision + recall + eps()
    return K.mean(numerator / denominator)
def f1_score_keras(y_true, y_pred):
    """Mean per-class F1 score (Theano backend, reached through ``K.T``).

    Args:
        y_true: ground-truth indicator tensor; axis 0 is summed over and
            axis -1 indexes the classes.
        y_pred: predicted scores with the same layout.

    Returns:
        Scalar tensor: the mean of the per-class F1 values.
    """
    def _ratio_or_zero(numerator, denominator):
        # 0 where the denominator is 0, the plain ratio elsewhere.
        return K.T.switch(K.T.eq(denominator, 0), 0, numerator / denominator)

    # Turn the scores into hard 0/1 predictions: put a 1 at each row's
    # arg-max column.
    row_ids = K.T.arange(y_true.shape[0])
    best_class = K.argmax(y_pred, axis=-1)
    hard_pred = K.T.set_subtensor(K.zeros_like(y_true)[row_ids, best_class], 1)

    true_positive = K.sum(y_true * hard_pred, axis=0)  # truth and prediction both 1
    predicted_count = K.sum(hard_pred, axis=0)         # predictions per class
    gold_count = K.sum(y_true, axis=0)                 # true members per class

    precision = _ratio_or_zero(true_positive, predicted_count)
    recall = _ratio_or_zero(true_positive, gold_count)
    f1_per_class = _ratio_or_zero(2 * (precision * recall), precision + recall)

    return K.mean(f1_per_class)
def f1_score_taskB(y_true, y_pred):
    """Per-class F1 score computed with Keras backend ops.

    Args:
        y_true: ground-truth indicator tensor; axis 0 is summed over and
            axis -1 indexes the classes.
        y_pred: predicted scores with the same layout.

    Returns:
        Tensor holding one F1 value per class (the values are NOT averaged).
    """
    # Convert the scores to hard 0/1 predictions.  Backend tensors do not
    # support item assignment, so the arg-max class of every sample is
    # one-hot encoded instead of being written into a zero tensor.
    y_pred_ones = K.one_hot(K.argmax(y_pred, axis=-1), K.shape(y_pred)[-1])
    y_pred_ones = K.cast(y_pred_ones, K.dtype(y_true))

    # Where y_true == 1 and y_pred == 1 -> true positive.
    y_true_pred = K.sum(y_true * y_pred_ones, axis=0)
    # For each class: how many samples were classified as that class.
    pred_cnt = K.sum(y_pred_ones, axis=0)
    # For each class: how many samples truly belong to that class.
    gold_cnt = K.sum(y_true, axis=0)

    # A tensor of zeros (rather than the Python scalar 0) keeps both
    # K.switch branches the same shape, which the TensorFlow backend needs.
    zeros = K.zeros_like(y_true_pred)

    # Precision and recall for each class; 0 where the count is 0.
    precision = K.switch(K.equal(pred_cnt, 0), zeros, y_true_pred / pred_cnt)
    recall = K.switch(K.equal(gold_cnt, 0), zeros, y_true_pred / gold_cnt)

    # F1 for each class; 0 where precision + recall is 0.
    f1_class = K.switch(K.equal(precision + recall, 0),
                        zeros,
                        2 * (precision * recall) / (precision + recall))
    return f1_class
def f1_score_semeval(y_true, y_pred):
    """Average of the F1 scores of class 0 and class 2 (Theano backend).

    Args:
        y_true: ground-truth indicator tensor; axis 0 is summed over and
            axis -1 indexes the classes.
        y_pred: predicted scores with the same layout.

    Returns:
        Scalar tensor ``(f1[0] + f1[2]) / 2``.
    """
    def _ratio_or_zero(numerator, denominator):
        # 0 where the denominator is 0, the plain ratio elsewhere.
        return K.T.switch(K.T.eq(denominator, 0), 0, numerator / denominator)

    # Hard 0/1 predictions: a 1 at each row's arg-max column.
    row_ids = K.T.arange(y_true.shape[0])
    best_class = K.argmax(y_pred, axis=-1)
    hard_pred = K.T.set_subtensor(K.zeros_like(y_true)[row_ids, best_class], 1)

    true_positive = K.sum(y_true * hard_pred, axis=0)  # truth and prediction both 1
    predicted_count = K.sum(hard_pred, axis=0)         # predictions per class
    gold_count = K.sum(y_true, axis=0)                 # true members per class

    precision = _ratio_or_zero(true_positive, predicted_count)
    recall = _ratio_or_zero(true_positive, gold_count)
    f1_per_class = _ratio_or_zero(2 * (precision * recall), precision + recall)

    # Only classes 0 and 2 enter the final score.
    return (f1_per_class[0] + f1_per_class[2])/2.0
def precision_keras(y_true, y_pred):
    """Mean per-class precision computed with Keras backend ops.

    Args:
        y_true: ground-truth indicator tensor; axis 0 is summed over and
            axis -1 indexes the classes.
        y_pred: predicted scores with the same layout.

    Returns:
        Scalar tensor: the mean of the per-class precision values.
    """
    # Convert the scores to hard 0/1 predictions.  Backend tensors do not
    # support item assignment, so the arg-max class of every sample is
    # one-hot encoded instead of being written into a zero tensor.
    y_pred_ones = K.one_hot(K.argmax(y_pred, axis=-1), K.shape(y_pred)[-1])
    y_pred_ones = K.cast(y_pred_ones, K.dtype(y_true))

    # Where y_true == 1 and y_pred == 1 -> true positive.
    y_true_pred = K.sum(y_true * y_pred_ones, axis=0)
    # For each class: how many samples were classified as that class.
    pred_cnt = K.sum(y_pred_ones, axis=0)

    # Precision for each class; 0 where nothing was predicted.  The zero
    # branch is a tensor so both K.switch branches share one shape.
    precision = K.switch(K.equal(pred_cnt, 0),
                         K.zeros_like(y_true_pred),
                         y_true_pred / pred_cnt)

    # Average precision over all classes.
    return K.mean(precision)
def f1_score_taskB(y_true, y_pred):
    """Per-class F1 score computed with Keras backend ops.

    Args:
        y_true: ground-truth indicator tensor; axis 0 is summed over and
            axis -1 indexes the classes.
        y_pred: predicted scores with the same layout.

    Returns:
        Tensor holding one F1 value per class (the values are NOT averaged).
    """
    # Convert the scores to hard 0/1 predictions.  Backend tensors do not
    # support item assignment, so the arg-max class of every sample is
    # one-hot encoded instead of being written into a zero tensor.
    y_pred_ones = K.one_hot(K.argmax(y_pred, axis=-1), K.shape(y_pred)[-1])
    y_pred_ones = K.cast(y_pred_ones, K.dtype(y_true))

    # Where y_true == 1 and y_pred == 1 -> true positive.
    y_true_pred = K.sum(y_true * y_pred_ones, axis=0)
    # For each class: how many samples were classified as that class.
    pred_cnt = K.sum(y_pred_ones, axis=0)
    # For each class: how many samples truly belong to that class.
    gold_cnt = K.sum(y_true, axis=0)

    # A tensor of zeros (rather than the Python scalar 0) keeps both
    # K.switch branches the same shape, which the TensorFlow backend needs.
    zeros = K.zeros_like(y_true_pred)

    # Precision and recall for each class; 0 where the count is 0.
    precision = K.switch(K.equal(pred_cnt, 0), zeros, y_true_pred / pred_cnt)
    recall = K.switch(K.equal(gold_cnt, 0), zeros, y_true_pred / gold_cnt)

    # F1 for each class; 0 where precision + recall is 0.
    f1_class = K.switch(K.equal(precision + recall, 0),
                        zeros,
                        2 * (precision * recall) / (precision + recall))
    return f1_class
def precision_keras(y_true, y_pred):
    """Mean per-class precision computed with Keras backend ops.

    Args:
        y_true: ground-truth indicator tensor; axis 0 is summed over and
            axis -1 indexes the classes.
        y_pred: predicted scores with the same layout.

    Returns:
        Scalar tensor: the mean of the per-class precision values.
    """
    # Convert the scores to hard 0/1 predictions.  Backend tensors do not
    # support item assignment, so the arg-max class of every sample is
    # one-hot encoded instead of being written into a zero tensor.
    y_pred_ones = K.one_hot(K.argmax(y_pred, axis=-1), K.shape(y_pred)[-1])
    y_pred_ones = K.cast(y_pred_ones, K.dtype(y_true))

    # Where y_true == 1 and y_pred == 1 -> true positive.
    y_true_pred = K.sum(y_true * y_pred_ones, axis=0)
    # For each class: how many samples were classified as that class.
    pred_cnt = K.sum(y_pred_ones, axis=0)

    # Precision for each class; 0 where nothing was predicted.  The zero
    # branch is a tensor so both K.switch branches share one shape.
    precision = K.switch(K.equal(pred_cnt, 0),
                         K.zeros_like(y_true_pred),
                         y_true_pred / pred_cnt)

    # Average precision over all classes.
    return K.mean(precision)
def call(self, x, mask=None):
    """Walk the node tree and return, per sample, the index of the chosen leaf.

    At every internal node the sample is sent left when the node's first
    output column exceeds one minus that column, and right otherwise.

    Args:
        x: input tensor; its first dimension sizes the result.
        mask: unused.

    Returns:
        ``int32`` tensor with one leaf index (from ``self.node_indices``)
        per sample.
    """
    def weights_of(node):
        return self.W[self.node_indices[node], :, :]

    def bias_of(node):
        return self.b[self.node_indices[node], :]

    def decode(batch, node=self.root_node):
        is_leaf = not hasattr(node, 'left')
        if is_leaf:
            # Constant vector filled with this leaf's index.
            return zeros((K.shape(batch)[0],)) + self.node_indices[node]

        scores = K.dot(x, weights_of(node))
        if self.bias:
            scores += bias_of(node)
        go_left = scores[:, 0]
        go_right = 1 - scores[:, 0]
        from_left = decode(batch, node.left)
        from_right = decode(batch, node.right)
        return K.switch(go_left > go_right, from_left, from_right)

    return K.cast(decode(x), 'int32')
def call(self, inputs, mask=None):
    """Compute the piecewise-linear openness of the time gate.

    With ``period``, ``shift`` and ``r_on`` taken from ``|self.timegate|``
    and ``phi = ((t - shift) mod period) / period``, the gate value is::

        2 * phi / r_on          if phi < r_on / 2          (rising)
        2 - 2 * phi / r_on      if r_on / 2 <= phi < r_on  (falling)
        self.alpha * phi        otherwise                  (leak)

    Args:
        inputs: tensor of time stamps ``t``.
        mask: unused.

    Returns:
        Tensor of gate values with the same shape as ``phi``.
    """
    t = inputs
    timegate = K.abs(self.timegate)
    period = timegate[0]
    shift = timegate[1]
    r_on = timegate[2]
    phi = ((t - shift) % period) / period

    # K.switch is not consistent between the Theano and TensorFlow backends,
    # so every segment is selected by multiplying with a 0/1 indicator.
    # The three indicators partition the phase range: exactly one of them is
    # 1 for any phi, so the leak term is never added on top of the falling
    # segment and phi == r_on / 2 belongs to the falling segment.
    half_on = r_on * 0.5
    rising = K.cast(K.less(phi, half_on), K.floatx())
    falling = (
        K.cast(K.greater_equal(phi, half_on), K.floatx())
        * K.cast(K.less(phi, r_on), K.floatx())
    )
    closed = K.cast(K.greater_equal(phi, r_on), K.floatx())

    up = rising * 2 * phi / r_on
    mid = falling * (2 - (2 * phi / r_on))
    end = closed * self.alpha * phi
    k = up + mid + end
    return k
def time_distributed_nonzero_max_pooling(x, mask_value=0.0):
    """Compute the maximum along axis 1, ignoring entries equal to mask_value.

    Args:
        x: input; a 3D Theano tensor (per the original author's note).
        mask_value: value to mask out; if None then no masking is applied;
            0.0 by default.

    Returns:
        The maximum of ``x`` over axis 1.  Positions where every entry along
        that axis was masked hold ``mask_value``.
    """
    import theano.tensor as T

    if mask_value is None:
        return x.max(axis=1)

    # Masked entries become -inf so they can never win the maximum.
    x = T.switch(T.eq(x, mask_value), -numpy.inf, x)
    masked_max_x = x.max(axis=1)
    # A result of -inf means everything was masked: put mask_value back.
    masked_max_x = T.switch(T.eq(masked_max_x, -numpy.inf), mask_value, masked_max_x)
    return masked_max_x
def time_distributed_masked_max(x, m):
    """Maximum of ``x`` along axis 1 over the positions the mask leaves on.

    Args:
        x: input; a 3D tensor (per the original author's note).
        m: mask; positions equal to 0.0 are treated as switched off.

    Returns:
        The maximum over axis 1, with 0.0 wherever the maximum is ``-inf``.
    """
    fill_value = 0.0
    # Offset that is -inf where the mask is off and 0 elsewhere; adding it
    # pushes masked positions out of the running for the maximum.
    offset = K.switch(K.equal(m, 0.0), -numpy.inf, 0.0)
    shifted = x + K.expand_dims(offset)
    peak = K.max(shifted, axis=1)
    # Swap a -inf maximum for the fill value.
    return K.switch(K.equal(peak, -numpy.inf), fill_value, peak)


## classes ##
# Transforms existing layers to masked layers
def get_split_averages(input_tensor, input_mask, indices):
    """Split a sequence at an index and average what lies on either side.

    Args:
        input_tensor: (batch_size, input_length, input_dim)
        input_mask: (batch_size, input_length) or None
        indices: (batch_size, 1)

    Returns:
        A 3-tuple: the average of the values before the index (cast to
        float32), the values at the index, and the average of the values
        after the index (cast to float32); each (batch_size, input_dim).
    """
    batch_size = K.shape(input_tensor)[0]
    input_length = K.shape(input_tensor)[1]

    # Position ids 0..input_length-1, repeated for every sample, next to the
    # split index repeated for every position; both (batch_size, input_length).
    positions = K.repeat_elements(
        K.expand_dims(K.arange(input_length), dim=0), batch_size, 0)
    split_at = K.repeat_elements(indices, input_length, 1)

    after = K.greater(positions, split_at)   # (batch_size, input_length)
    before = K.lesser(positions, split_at)   # (batch_size, input_length)
    at = K.equal(positions, split_at)        # (batch_size, input_length)

    # The before/after selectors are additionally restricted by the input mask.
    if input_mask is not None:
        after = switch(input_mask, after, K.zeros_like(after))
        before = switch(input_mask, before, K.zeros_like(before))

    def _selected_sum(selector):
        # Sum of the time steps picked by `selector`; (batch_size, input_dim).
        kept = switch(K.expand_dims(selector), input_tensor,
                      K.zeros_like(input_tensor))
        return K.sum(kept, axis=1)

    post_sum = _selected_sum(after)
    pre_sum = _selected_sum(before)
    values_at_indices = _selected_sum(at)

    # Selected-step counts, with epsilon added; (batch_size, 1).
    post_normalizer = K.expand_dims(K.sum(after, axis=1) + K.epsilon(), dim=1)
    pre_normalizer = K.expand_dims(K.sum(before, axis=1) + K.epsilon(), dim=1)

    pre_average = K.cast(pre_sum / pre_normalizer, 'float32')
    post_average = K.cast(post_sum / post_normalizer, 'float32')
    return pre_average, values_at_indices, post_average
def _gen_local_drops(self, count, p):
    """Sample a local drop pattern that keeps at least one path.

    Args:
        count: number of paths.
        p: value forwarded to ``self._random_arr``.

    Returns:
        The sampled array when any of its entries is set, otherwise the
        result of ``self._arr_with_one(count)``.
    """
    sampled = self._random_arr(count, p)
    fallback = self._arr_with_one(count)
    return K.switch(K.any(sampled), sampled, fallback)
def masked_mse(y_true, y_pred):
    """Mean squared error over the targets that are not NaN.

    Args:
        y_true: targets; NaN marks a missing value.
        y_pred: predictions, same shape as ``y_true``.

    Returns:
        Per-sample sum of squared errors over the valid entries of the last
        axis, divided by the number of valid entries.
    """
    mask = T.isnan(y_true)
    # Count valid entries arithmetically: `~mask` is a bitwise invert when
    # the mask is an integer tensor and would then give wrong counts.
    valid = 1.0 - K.cast(mask, K.floatx())

    # Replace missing targets with the prediction BEFORE taking the
    # difference.  The residual is then exactly 0 there, so no NaN enters
    # the forward pass or the gradient.
    filled_true = K.switch(mask, y_pred, y_true)
    squared = K.square(y_pred - filled_true)

    sum_squared_error = K.sum(squared, axis=-1)
    n_valid_per_sample = K.sum(valid, axis=-1)
    return sum_squared_error / n_valid_per_sample
def masked_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy averaged over the targets that are not NaN.

    Args:
        y_true: targets; NaN marks a missing value.
        y_pred: predictions, same shape as ``y_true``.

    Returns:
        Per-sample sum of cross-entropy values over the valid entries of the
        last axis, divided by the number of valid entries.
    """
    mask = T.isnan(y_true)
    # Count valid entries arithmetically: `~mask` is a bitwise invert when
    # the mask is an integer tensor and would then give wrong counts.
    valid = 1.0 - K.cast(mask, K.floatx())

    # Give missing targets a finite placeholder so the cross-entropy (and
    # its gradient) never sees a NaN; those entries are zeroed out below.
    safe_true = K.switch(mask, K.zeros_like(y_true), y_true)
    cross_entropy_values = K.binary_crossentropy(
        output=y_pred, target=safe_true)

    sum_cross_entropy_values = K.sum(cross_entropy_values * valid, axis=-1)
    n_valid_per_sample = K.sum(valid, axis=-1)
    return sum_cross_entropy_values / n_valid_per_sample
def mse_no_NaN(y_true, y_pred):
    """For each sample, sum the squared error while ignoring NaN targets.

    A NaN target is replaced by the prediction itself, so its squared error
    is zero.
    """
    is_observed = K.logical_not(K.is_nan(y_true))
    target = K.switch(is_observed, y_true, y_pred)
    squared_error = K.square(target - y_pred)
    return K.sum(squared_error, axis = -1)
def binary_crossnetropy_no_NaN(y_true, y_pred):
    """For each sample, sum the binary cross-entropy with NaN targets replaced.

    A NaN target is replaced by the prediction before the cross-entropy is
    evaluated.
    """
    target = K.switch(K.is_nan(y_true), y_pred, y_true)
    per_entry = K.binary_crossentropy(target, y_pred)
    return K.sum(per_entry, axis = -1)
def step(self, x, states):
    """Run one recurrent step; the output only changes on period boundaries.

    Args:
        x: input for this step.
        states: sequence ``[prev_output, time_step, B_U, B_W, period]``.

    Returns:
        ``(output, [output, time_step + 1])``, where ``output`` is the new
        activation when ``time_step % period == 0`` and the previous output
        otherwise.
    """
    last_output = states[0]
    tick = states[1]
    recurrent_scale = states[2]
    input_scale = states[3]
    period = states[4]

    if self.consume_less == 'cpu':
        projected = x
    else:
        projected = K.dot(x * input_scale, self.W) + self.b

    candidate = self.activation(
        projected + K.dot(last_output * recurrent_scale, self.U))
    on_boundary = K.equal(tick % period, 0.)
    output = K.switch(on_boundary, candidate, last_output)
    return output, [output, tick+1]
def get_gradients(self, loss, params):
    """Return the gradients of ``loss`` w.r.t. ``params`` after scaling and clipping.

    Applied in order: multiplication by ``self.scale`` (skipped when it is
    1), norm clipping when ``self.clipnorm`` exists and is positive, and
    value clipping when ``self.clipvalue`` exists and is positive.
    """
    grads = K.gradients(loss, params)

    if self.scale != 1.:
        grads = [grad*K.variable(self.scale) for grad in grads]

    if hasattr(self, 'clipnorm') and self.clipnorm > 0:
        # Joint L2 norm of all gradients.
        squared_sums = [K.sum(K.square(grad)) for grad in grads]
        norm = K.sqrt(sum(squared_sums))
        grads = [
            K.switch(norm >= self.clipnorm, grad * self.clipnorm / norm, grad)
            for grad in grads
        ]

    if hasattr(self, 'clipvalue') and self.clipvalue > 0:
        grads = [
            K.clip(grad, -self.clipvalue, self.clipvalue) for grad in grads
        ]

    return grads
def average_precision(y_true, y_pred):
    """Average precision of a ranking (Theano backend, reached through ``K.T``).

    Args:
        y_true: relevance labels, indexed along axis 0.
        y_pred: predicted scores used to rank the items.

    Returns:
        Scalar tensor; 0 when ``y_true`` sums to 0.
    """
    # Item order by descending predicted score.
    ranking = K.T.argsort(y_pred, axis=0)[::-1]
    ranked_truth = y_true[ranking]

    hits_so_far = K.T.cumsum(ranked_truth)
    rank = K.T.arange(1, y_true.shape[0] + 1)
    n_relevant = K.sum(y_true)

    precision_at_hits = (hits_so_far / rank) * ranked_truth
    score = K.sum(precision_at_hits) * (1 / n_relevant)

    # If no item is relevant just return 0.
    return K.T.switch(K.T.eq(n_relevant, 0), 0, score)
def f1_score_task3(y_true, y_pred):
    """F1 score of class 1, computed with Keras backend ops.

    Args:
        y_true: ground-truth indicator tensor; axis 0 is summed over and
            axis -1 indexes the classes.
        y_pred: predicted scores with the same layout.

    Returns:
        Scalar tensor: the F1 value of the class at index 1.
    """
    # Convert the scores to hard 0/1 predictions.  Backend tensors do not
    # support item assignment, so the arg-max class of every sample is
    # one-hot encoded.  This replaces the earlier unused sparse-tensor
    # construction, which called K.concatenate with a tensor as its axis.
    y_pred_ones = K.one_hot(K.argmax(y_pred, axis=-1), K.shape(y_pred)[-1])
    y_pred_ones = K.cast(y_pred_ones, K.dtype(y_true))

    # Where y_true == 1 and y_pred == 1 -> true positive.
    y_true_pred = K.sum(y_true * y_pred_ones, axis=0)
    # For each class: how many samples were classified as that class.
    pred_cnt = K.sum(y_pred_ones, axis=0)
    # For each class: how many samples truly belong to that class.
    gold_cnt = K.sum(y_true, axis=0)

    # A tensor of zeros (rather than the Python scalar 0) keeps both
    # K.switch branches the same shape, which the TensorFlow backend needs.
    zeros = K.zeros_like(y_true_pred)

    # Precision and recall for each class; 0 where the count is 0.
    precision = K.switch(K.equal(pred_cnt, 0), zeros, y_true_pred / pred_cnt)
    recall = K.switch(K.equal(gold_cnt, 0), zeros, y_true_pred / gold_cnt)

    # F1 for each class; 0 where precision + recall is 0.
    f1_class = K.switch(K.equal(precision + recall, 0),
                        zeros,
                        2 * (precision * recall) / (precision + recall))

    # Only the score of class 1 is reported.
    return f1_class[1]
def call(self, x, mask=None):
    """Score every candidate head and return attachment probabilities.

    Args:
        x: (batch_size, input_length, input_dim) where
            input_length = head_size + 2; the last two steps are the
            preposition and the child, everything before them the heads.
        mask: optional mask over the input steps.

    Returns:
        (batch_size, head_size) tensor: a softmax over the head scores,
        restricted to unmasked heads when a mask is given.
    """
    composition = self.composition_type
    heads = x[:, :-2, :]   # (batch_size, head_size, input_dim)
    prep = x[:, -2, :]     # (batch_size, input_dim)
    child = x[:, -1, :]    # (batch_size, input_dim)

    # Project the heads; result is (batch_size, head_size, proj_dim).
    if composition == 'HPCD':
        # TODO: The following product may not work with TF.
        # (batch_size, head_size, input_dim, 1) * (1, head_size, input_dim, proj_dim)
        product = K.expand_dims(heads) * K.expand_dims(self.dist_proj_head, dim=0)
        head_projection = K.sum(product, axis=2)
    else:
        head_projection = K.dot(heads, self.proj_head)

    # Both (batch_size, 1, proj_dim), so they broadcast over the heads.
    prep_projection = K.expand_dims(K.dot(prep, self.proj_prep), dim=1)
    child_projection = K.expand_dims(K.dot(child, self.proj_child), dim=1)

    # Compose; result is (batch_size, head_size, proj_dim).
    if composition == 'HPCT':
        composed = K.tanh(head_projection + prep_projection + child_projection)
    elif composition in ('HPC', "HPCD"):
        prep_child = K.tanh(prep_projection + child_projection)  # (batch_size, 1, proj_dim)
        composed = K.tanh(head_projection + prep_child)
    else:
        # Composition type is HC.
        composed = K.tanh(head_projection + child_projection)

    for hidden_layer in self.hidden_layers:
        composed = K.tanh(K.dot(composed, hidden_layer))

    # (batch_size, head_size)
    head_word_scores = K.squeeze(K.dot(composed, self.scorer), axis=-1)

    if mask is None:
        return K.softmax(head_word_scores)

    if K.ndim(mask) > 2:
        # This means this layer came after a Bidirectional layer. Keras has
        # this bug which concatenates input masks instead of output masks.
        # TODO: Fix Bidirectional instead.
        mask = K.any(mask, axis=(-2, -1))

    # Masked softmax: exponentiate, zero the masked heads, normalise.
    exp_scores = K.exp(head_word_scores)  # (batch_size, head_size)
    head_mask = mask[:, :-2]              # (batch_size, head_size)
    masked_exp_scores = switch(head_mask, exp_scores, K.zeros_like(heads[:, :, 0]))
    # (batch_size, 1). Epsilon avoids division by 0; it is float64, hence the cast.
    exp_sum = K.cast(K.expand_dims(K.sum(masked_exp_scores, axis=1) + K.epsilon()), 'float32')
    return masked_exp_scores / exp_sum
def call(self, x, mask=None):
    """Predict label probabilities from head, preposition and child encodings.

    Args:
        x: list or tuple with
            x[0]: (batch_size, input_length, input_dim) encoded sentence,
            x[1]: (batch_size, 1) indices of prepositions,
            x[2] (optional): (batch_size, input_length - 2) head
            probabilities, read only when ``self.with_attachment_probs``.
        mask: unused by the active code path.

    Returns:
        (batch_size, num_classes) tensor: softmax over the class scores.
    """
    assert isinstance(x, (list, tuple))
    sentence = x[0]
    prep_indices = K.squeeze(x[1], axis=-1)          # (batch_size,)
    batch_indices = K.arange(K.shape(sentence)[0])   # (batch_size,)

    if self.with_attachment_probs:
        # We're essentially doing K.argmax(x[2]) here, but argmax is not
        # differentiable!
        head_probs = x[2]
        padding = K.zeros_like(x[2])[:, :2]                     # (batch_size, 2)
        padded_probs = K.concatenate([head_probs, padding])     # (batch_size, input_length)
        row_max = K.expand_dims(K.max(padded_probs, axis=1))    # (batch_size, 1)
        # (batch_size, input_length, 1): marks the positions holding the maximum.
        is_max = K.expand_dims(K.equal(padded_probs, row_max))
        # (batch_size, input_length, input_dim)
        selected = K.switch(is_max, sentence, K.zeros_like(sentence))
        head_encoding = K.sum(selected, axis=1)                 # (batch_size, input_dim)
    else:
        # The head is taken to be the step right before the preposition.
        head_indices = prep_indices - 1                                   # (batch_size,)
        head_encoding = sentence[batch_indices, head_indices, :]          # (batch_size, input_dim)

    prep_encoding = sentence[batch_indices, prep_indices, :]       # (batch_size, input_dim)
    child_encoding = sentence[batch_indices, prep_indices+1, :]    # (batch_size, input_dim)

    # Disabled alternative (kept for reference): averaging the encodings on
    # either side of the preposition instead of indexing single steps.
    #     prep_indices = x[1]
    #     sentence_mask = mask[0]
    #     if sentence_mask is not None:
    #         if K.ndim(sentence_mask) > 2:
    #             # This means this layer came after a Bidirectional layer.
    #             # Keras has this bug which concatenates input masks
    #             # instead of output masks.
    #             # TODO: Fix Bidirectional instead.
    #             sentence_mask = K.any(sentence_mask, axis=(-2, -1))
    #     head_encoding, prep_encoding, child_encoding = self.get_split_averages(
    #         encoded_sentence, sentence_mask, prep_indices)

    # All three projections are (batch_size, proj_dim).
    head_projection = K.dot(head_encoding, self.proj_head)
    prep_projection = K.dot(prep_encoding, self.proj_prep)
    child_projection = K.dot(child_encoding, self.proj_child)

    composition = self.composition_type
    if composition == 'HPCT':
        composed = K.tanh(head_projection + prep_projection + child_projection)
    elif composition == 'HPC':
        prep_child = K.tanh(prep_projection + child_projection)  # (batch_size, proj_dim)
        composed = K.tanh(head_projection + prep_child)
    else:
        # Composition type is HC.
        composed = K.tanh(head_projection + child_projection)

    for hidden_layer in self.hidden_layers:
        composed = K.tanh(K.dot(composed, hidden_layer))  # (batch_size, proj_dim)

    class_scores = K.dot(composed, self.scorer)  # (batch_size, num_classes)
    return K.softmax(class_scores)
def get_updates(self, params, constraints, loss):
    """Build the list of update ops for one optimisation step.

    The parameter update has the form of Adam (first/second moment
    estimates ``m``/``v`` with bias-corrected step size ``lr_t``), except
    that the denominator is additionally scaled by ``d_t``, a running
    average of the relative change of a clipped loss estimate ``f``.
    NOTE(review): the names small_k / big_K / beta_3 suggest the Eve
    optimiser -- confirm against the enclosing class.

    Args:
        params: parameters to optimise.
        constraints: mapping from parameter to a callable applied to its
            updated value.
        loss: loss tensor.

    Returns:
        ``self.updates``: the iteration counter increment, the updates of
        ``f`` and ``d``, then for every parameter the updates of its two
        moment accumulators and of the parameter itself.
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]

    lr = self.lr
    # `inital_decay` (sic) is the attribute name this method reads.
    if self.inital_decay > 0:
        lr *= (1. / (1. + self.decay * self.iterations))

    t = self.iterations + 1
    # Step size with the bias corrections for both moment estimates folded in.
    lr_t = lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))

    shapes = [K.get_variable_shape(p) for p in params]
    ms = [K.zeros(shape) for shape in shapes]
    vs = [K.zeros(shape) for shape in shapes]

    # f: loss estimate carried over from the previous step; d: running
    # feedback coefficient.
    f = K.variable(0)
    d = K.variable(1)
    self.weights = [self.iterations] + ms + vs + [f, d]

    # True from the second step on, i.e. once f holds a previous loss.
    cond = K.greater(t, K.variable(1))
    # Lower and upper bound for the ratio loss / f; which pair is used
    # depends on whether the loss went up or down.
    small_delta_t = K.switch(K.greater(loss, f), self.small_k + 1, 1. / (self.big_K + 1))
    big_delta_t = K.switch(K.greater(loss, f), self.big_K + 1, 1. / (self.small_k + 1))

    # Ratio clipped into [small_delta_t, big_delta_t].
    c_t = K.minimum(K.maximum(small_delta_t, loss / (f + self.epsilon)), big_delta_t)
    f_t = c_t * f
    # Relative change of the loss estimate.
    r_t = K.abs(f_t - f) / (K.minimum(f_t, f))
    d_t = self.beta_3 * d + (1 - self.beta_3) * r_t

    # On the first step there is no previous loss: take the loss itself as
    # the estimate and keep the coefficient at 1.
    f_t = K.switch(cond, f_t, loss)
    d_t = K.switch(cond, d_t, K.variable(1.))

    self.updates.append(K.update(f, f_t))
    self.updates.append(K.update(d, d_t))

    for p, g, m, v in zip(params, grads, ms, vs):
        m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
        v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
        # Adam-style step whose denominator is scaled by d_t.
        p_t = p - lr_t * m_t / (d_t * K.sqrt(v_t) + self.epsilon)

        self.updates.append(K.update(m, m_t))
        self.updates.append(K.update(v, v_t))

        new_p = p_t
        # apply constraints
        if p in constraints:
            c = constraints[p]
            new_p = c(new_p)
        self.updates.append(K.update(p, new_p))
    return self.updates
def accumulate(attend_function, inputs, input_length, mask=None, return_probabilities=False):
    '''Get the running attention over a sequence.

    Given a tensor whose second dimension is time, compute the running
    attended sum: at step i the attention function only sees steps 0..i,
    later steps are replaced by zeros.  This is basically a mod on keras'
    rnn implementation.

    Args:
        attend_function: callable taking the (time-major) inputs, plus the
            mask when one is given, and returning a 2D tensor with the time
            dimension marginalised out -- or a pair ``(output, p_vectors)``
            when ``return_probabilities`` is set.
        inputs: Theano tensor with at least 3 dimensions.
        input_length: number of time steps (a Python int).
        mask: optional tensor with ``ndim`` or ``ndim - 1`` dimensions.
        return_probabilities: also collect and return the probability
            vectors produced by ``attend_function``.

    Returns:
        The stacked per-step outputs with the first two axes swapped back,
        or ``[outputs, probability_vectors]`` when ``return_probabilities``
        is set.

    author: bcm
    '''
    ndim = inputs.ndim
    assert ndim >= 3, 'inputs should be at least 3d'

    # Swap the first two axes so that time comes first.
    axes = [1, 0] + list(range(2, ndim))
    inputs = inputs.dimshuffle(axes)

    if mask is not None:
        if mask.ndim == ndim-1:
            mask = K.expand_dims(mask)
        assert mask.ndim == ndim
        mask = mask.dimshuffle(axes)

    successive_outputs = []
    successive_pvecs = []

    uncover_mask = K.zeros_like(inputs)
    # Time indices, expanded to the rank of `inputs` so they broadcast.
    uncover_indices = K.arange(input_length)
    for _ in range(ndim-1):
        uncover_indices = K.expand_dims(uncover_indices)

    def make_subset(i, X):
        # Keep time steps <= i of X, zero out the rest.
        return K.switch(uncover_indices <= i, X, uncover_mask)

    for i in range(input_length):
        inputs_i = make_subset(i, inputs)
        if mask is not None:
            # Only slice the mask when there is one: slicing None would
            # fail inside K.switch.
            mask_i = make_subset(i, mask)
            output = attend_function(inputs_i, mask_i)
        else:
            output = attend_function(inputs_i)

        if return_probabilities:
            output, p_vectors = output
            successive_pvecs.append(p_vectors)
        assert output.ndim == 2, "Your attention function is malfunctioning; the attention accumulator should return 2 dimensional tensors"
        successive_outputs.append(output)

    # Stack the per-step outputs and swap the first two axes back.
    outputs = K.pack(successive_outputs)
    axes = [1, 0] + list(range(2, outputs.ndim))
    outputs = outputs.dimshuffle(axes)

    if return_probabilities:
        out_pvecs = K.pack(successive_pvecs)
        out_pvecs = out_pvecs.dimshuffle(axes)
        outputs = [outputs, out_pvecs]

    return outputs