我们从 Python 开源项目中提取了以下 50 个代码示例，用于说明如何使用 keras.backend.ndim()。
def __init__(self, output_dim, num_senses, num_hyps, use_attention=False, return_attention=False, **kwargs):
    """Store the layer configuration and adjust the settings inherited from the parent.

    ``output_dim`` is forwarded to the parent initializer through ``kwargs``; the other
    named arguments are kept as instance attributes.
    """
    self.return_attention = return_attention
    self.use_attention = use_attention
    self.num_hyps = num_hyps
    self.num_senses = num_senses
    # The parent initializer picks up output_dim from the keyword arguments.
    kwargs['output_dim'] = output_dim
    super(OntoAttentionLSTM, self).__init__(**kwargs)
    # The recurrent parent would have set a 3-d input spec; this layer consumes 5-d input.
    self.input_spec = [InputSpec(ndim=5)]
    cpu_mode_requested = self.consume_less == "cpu"
    if cpu_mode_requested:
        # With consume_less = cpu Keras precomputes the inputs of all gates and keeps them in
        # memory. Here the gate inputs depend on the previous timestep's output, so that
        # cannot work.
        warnings.warn("OntoLSTM does not support consume_less = cpu. Changing it to mem.")
        self.consume_less = "mem"
    # TODO: Remove this dependency.
    must_force_unroll = K.backend() == "tensorflow" and not self.unroll
    if must_force_unroll:
        warnings.warn("OntoLSTM does not work with unroll=False when backend is TF. Changing it to True.")
        self.unroll = True
def get_initial_states(self, onto_nse_input, input_mask=None):
    """Return the reader's initial states followed by the flattened initial memory.

    The initial memory is the input (without the last feature) reduced over axes 2 and 3:
    a plain mean when no mask is given, otherwise a sum weighted by the normalised mask.
    """
    # Per the original annotations the input is
    # (batch_size, num_words, num_senses, num_hyps, output_dim + 1); drop the last feature.
    memory_input = onto_nse_input[:, :, :, :, :-1]
    if input_mask is not None:
        memory_mask = input_mask
        if K.ndim(onto_nse_input) != K.ndim(input_mask):
            # Give the mask a trailing axis so it broadcasts against the input.
            memory_mask = K.expand_dims(input_mask)
        # Normalise by the total mask mass (epsilon avoids division by zero).
        mask_total = K.sum(memory_mask) + K.epsilon()
        memory_mask = K.cast(memory_mask / mask_total, 'float32')
        mem_0 = K.sum(memory_input * memory_mask, axis=(2, 3))
    else:
        mem_0 = K.mean(memory_input, axis=(2, 3))
    flattened_mem_0 = K.batch_flatten(mem_0)
    initial_states = self.reader.get_initial_states(onto_nse_input)
    initial_states += [flattened_mem_0]
    return initial_states
def call(self, x, mask=None):
    """Attend over axis 1 of ``x`` using the parent layer's output as the query.

    Per the original annotations: x is (batch_size, input_length, input_dim), the parent
    output is (batch_size, input_dim) and the result is (batch_size, input_dim).
    """
    mean = super(IntraAttention, self).call(x, mask)
    # A row of ones, one per position of axis 1: (1, input_length) per the original notes.
    ones = K.expand_dims(K.mean(K.ones_like(x), axis=(0, 2)), dim=0)
    # Repeat the mean along axis 1 so it lines up with x.
    repeated = K.dot(K.expand_dims(mean), ones)
    tiled_mean = K.permute_dimensions(repeated, (0, 2, 1))
    if mask is not None:
        if K.ndim(mask) > K.ndim(x):
            # Assuming this is because of the bug in Bidirectional. Temporary fix follows.
            # TODO: Fix Bidirectional.
            mask = K.any(mask, axis=(-2, -1))
        if K.ndim(mask) < K.ndim(x):
            mask = K.expand_dims(mask)
        # Zero out the masked positions.
        x = switch(mask, x, K.zeros_like(x))
    projected_x = K.dot(x, self.vector_projector)
    projected_mean = K.dot(tiled_mean, self.mean_projector)
    projected_combination = K.tanh(projected_x + projected_mean)
    scores = K.dot(projected_combination, self.scorer)
    weights = K.softmax(scores)
    # Weighted sum over axis 1.
    return K.sum(K.expand_dims(weights) * x, axis=1)
def __init__(self, output_dim, init='glorot_uniform', activation='relu',weights=None,
             W_regularizer=None, b_regularizer=None, activity_regularizer=None,
             W_constraint=None, b_constraint=None,
             input_dim=None, **kwargs):
    """Resolve the configuration objects and hand the remaining kwargs to the parent.

    Initializer, activation, regularizer and constraint identifiers are resolved through
    their respective ``get`` helpers; the bias initializer is always ``'zeros'``.
    """
    # Plain values first.
    self.output_dim = output_dim
    self.input_dim = input_dim
    self.initial_weights = weights
    # Identifiers resolved into objects.
    self.W_initializer = initializers.get(init)
    self.b_initializer = initializers.get('zeros')
    self.activation = activations.get(activation)
    self.W_regularizer = regularizers.get(W_regularizer)
    self.b_regularizer = regularizers.get(b_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.W_constraint = constraints.get(W_constraint)
    self.b_constraint = constraints.get(b_constraint)
    self.input_spec = InputSpec(ndim=2)
    if self.input_dim:
        # A truthy input_dim is turned into an input_shape for the parent.
        kwargs['input_shape'] = (self.input_dim,)
    super(SparseFullyConnectedLayer, self).__init__(**kwargs)
def build(self, input_shape):
    """Create the kernel ``W`` and bias ``b`` for a 2-d input shape.

    Initial weights supplied at construction time are applied once and then discarded.
    """
    assert len(input_shape) == 2
    input_dim = input_shape[1]
    # Pin axis 1 of future inputs to the size seen here.
    self.input_spec = InputSpec(ndim=2, axes={1: input_dim})
    kernel_shape = (input_dim, self.output_dim)
    self.W = self.add_weight(shape=kernel_shape,
                             initializer=self.W_initializer,
                             name='SparseFullyConnected_W',
                             regularizer=self.W_regularizer,
                             constraint=self.W_constraint)
    bias_shape = (self.output_dim,)
    self.b = self.add_weight(shape=bias_shape,
                             initializer=self.b_initializer,
                             name='SparseFullyConnected_b',
                             regularizer=self.b_regularizer,
                             constraint=self.b_constraint)
    if self.initial_weights is not None:
        self.set_weights(self.initial_weights)
        del self.initial_weights
def _softmax(x, dim):
    """Compute softmax of ``x`` along axis ``dim`` for the active Keras backend.

    Args:
        x: Backend tensor.
        dim: Axis along which to normalise.

    Returns:
        A tensor with the same axis order as ``x``.

    Raises:
        ValueError: If the active backend is not tensorflow, cntk or theano.
    """
    backend = K.backend()
    if backend == 'tensorflow':
        import tensorflow as tf
        return tf.nn.softmax(x, dim)
    # Strings must be compared by value; the identity check used before is not reliable.
    if backend == 'cntk':
        import cntk
        return cntk.softmax(x, dim)
    if backend == 'theano':
        # Theano cannot softmax along an arbitrary dim, so swap `dim` with the last
        # axis, softmax there, then swap back.
        perm = list(range(K.ndim(x)))
        perm[dim], perm[-1] = perm[-1], perm[dim]
        perm = tuple(perm)
        x_perm = K.permute_dimensions(x, perm)
        output = K.softmax(x_perm)
        # A two-axis swap is its own inverse, so the same pattern restores the original
        # layout. It must be applied to the softmax output, not to the input.
        return K.permute_dimensions(output, perm)
    raise ValueError("Backend '{}' not supported".format(backend))
def style_loss(style_image, target_image, style_masks, target_masks):
    '''Calculate style loss between style_image and target_image, in all regions.

    Sums ``region_style_loss`` over the ``nb_labels`` mask channels. The channel
    axis of the masks is the first one for 'th' dim ordering and the last otherwise.
    All four arguments must be 3-d tensors.
    '''
    assert 3 == K.ndim(style_image) == K.ndim(target_image)
    assert 3 == K.ndim(style_masks) == K.ndim(target_masks)
    loss = K.variable(0)
    # `range` instead of `xrange`, so the loop also runs on Python 3.
    for i in range(nb_labels):
        if K.image_dim_ordering() == 'th':
            style_mask = style_masks[i, :, :]
            target_mask = target_masks[i, :, :]
        else:
            style_mask = style_masks[:, :, i]
            target_mask = target_masks[:, :, i]
        loss += region_style_loss(style_image, target_image, style_mask, target_mask)
    return loss
def total_variation_loss(x):
    """Sum of ``(dr**2 + dc**2) ** 1.25`` over the image, where ``dr`` and ``dc`` are the
    differences between each pixel and its neighbour along the row and column axes.

    ``x`` must be 4-d; the spatial axes are 2 and 3 for 'th' dim ordering, 1 and 2 otherwise.
    """
    assert 4 == K.ndim(x)
    last_row = img_nrows - 1
    last_col = img_ncols - 1
    if K.image_dim_ordering() == 'th':
        anchor = x[:, :, :last_row, :last_col]
        row_shifted = x[:, :, 1:, :last_col]
        col_shifted = x[:, :, :last_row, 1:]
    else:
        anchor = x[:, :last_row, :last_col, :]
        row_shifted = x[:, 1:, :last_col, :]
        col_shifted = x[:, :last_row, 1:, :]
    a = K.square(anchor - row_shifted)
    b = K.square(anchor - col_shifted)
    return K.sum(K.pow(a + b, 1.25))

# Overall loss is the weighted sum of content_loss, style_loss and tv_loss
# Each individual loss uses features from image/mask models.
def style_loss(style_image, target_image, style_masks, target_masks):
    '''Calculate style loss between style_image and target_image, in all regions.

    Sums ``region_style_loss`` over the ``num_labels`` mask channels. The channel
    axis of the masks is the first one for 'channels_first' data format and the last
    otherwise. All four arguments must be 3-d tensors.
    '''
    assert 3 == K.ndim(style_image) == K.ndim(target_image)
    assert 3 == K.ndim(style_masks) == K.ndim(target_masks)
    loss = K.variable(0)
    # `range` instead of `xrange`, so the loop also runs on Python 3.
    for i in range(num_labels):
        if K.image_data_format() == 'channels_first':
            style_mask = style_masks[i, :, :]
            target_mask = target_masks[i, :, :]
        else:
            style_mask = style_masks[:, :, i]
            target_mask = target_masks[:, :, i]
        loss += region_style_loss(style_image, target_image, style_mask, target_mask)
    return loss
def total_variation_loss(x):
    """Sum of ``(dr**2 + dc**2) ** 1.25`` over the image, where ``dr`` and ``dc`` are the
    differences between each pixel and its neighbour along the row and column axes.

    ``x`` must be 4-d; the spatial axes are 2 and 3 for 'channels_first', 1 and 2 otherwise.
    """
    assert 4 == K.ndim(x)
    last_row = img_nrows - 1
    last_col = img_ncols - 1
    if K.image_data_format() == 'channels_first':
        anchor = x[:, :, :last_row, :last_col]
        row_shifted = x[:, :, 1:, :last_col]
        col_shifted = x[:, :, :last_row, 1:]
    else:
        anchor = x[:, :last_row, :last_col, :]
        row_shifted = x[:, 1:, :last_col, :]
        col_shifted = x[:, :last_row, 1:, :]
    a = K.square(anchor - row_shifted)
    b = K.square(anchor - col_shifted)
    return K.sum(K.pow(a + b, 1.25))

# Overall loss is the weighted sum of content_loss, style_loss and tv_loss
# Each individual loss uses features from image/mask models.
def compute_output_shape(self, input_shape):
    """Return the output shape as a tuple, derived from ``input_shape``.

    With more than one filter, the first axis in ``self.sum_axes`` is replaced by
    ``self.filters`` and the second summed axis (if any) is dropped. Otherwise every
    axis listed in ``self.sum_axes`` is dropped, and a trailing 1 is appended if only
    a single axis remains.
    """
    if self.filters > 1:
        ndim = len(input_shape)
        output_shape = [input_shape[0]] + [1] * (ndim - 1)
        # Axes that are not summed keep their input size.
        for i in set(range(1, ndim)) - set(self.sum_axes):
            output_shape[i] = input_shape[i]
        output_shape.append(self.filters)
        # Swap the filter axis (currently last) with the first summed axis.
        permute_dims = list(range(ndim + 1))
        permute_dims[self.sum_axes[0]] = ndim
        permute_dims[ndim] = self.sum_axes[0]
        output_shape = [output_shape[i] for i in permute_dims]
        output_shape.pop(ndim)
        if len(self.sum_axes) > 1:
            output_shape.pop(self.sum_axes[1])
    else:
        # Iterate the kept axes in sorted order: the result must preserve axis order,
        # which plain set iteration does not guarantee.
        kept_axes = sorted(set(range(len(input_shape))) - set(self.sum_axes))
        output_shape = [input_shape[i] for i in kept_axes]
        if len(output_shape) == 1:
            output_shape.append(1)
    return tuple(output_shape)
def compute_output_shape(self, input_shape):
    """Return the output shape as a tuple.

    The first axis in ``self.sum_axes`` is replaced by the total filter count
    (``filters_complex + filters_simple``); a second summed axis, if present, is dropped.
    """
    ndim = len(input_shape)
    summed = set(self.sum_axes)
    # Summed axes collapse to 1, every other axis keeps its input size.
    shape = [input_shape[0]]
    for axis in range(1, ndim):
        shape.append(1 if axis in summed else input_shape[axis])
    shape.append(self.filters_complex + self.filters_simple)
    # Exchange the filter axis (appended last) with the first summed axis.
    first_summed = self.sum_axes[0]
    order = list(range(ndim + 1))
    order[first_summed] = ndim
    order[ndim] = first_summed
    shape = [shape[idx] for idx in order]
    del shape[ndim]
    if len(self.sum_axes) > 1:
        shape.pop(self.sum_axes[1])
    return tuple(shape)
def build(self, input_shape):
    """Create the ``(input_dim, output_dim)`` kernel and the optional bias.

    ``input_dim`` is the last entry of ``input_shape``. Initial weights supplied at
    construction time are applied once and then discarded.
    """
    assert len(input_shape) >= 2
    self.input_dim = input_dim = input_shape[-1]
    self.input_spec = [InputSpec(dtype=K.floatx(), ndim='2+')]

    kernel_name = '{}_W'.format(self.name)
    self.W = self.add_weight((input_dim, self.output_dim),
                             initializer=self.init,
                             name=kernel_name,
                             regularizer=self.W_regularizer,
                             constraint=self.W_constraint)
    # The bias only exists when requested; otherwise the attribute is None.
    self.b = self.add_weight((self.output_dim,),
                             initializer='zero',
                             name='{}_b'.format(self.name),
                             regularizer=self.b_regularizer,
                             constraint=self.b_constraint) if self.bias else None

    if self.initial_weights is not None:
        self.set_weights(self.initial_weights)
        del self.initial_weights
    self.built = True
def __init__(self, init='glorot_uniform', activation=None, weights=None,
             W_regularizer=None, b_regularizer=None, activity_regularizer=None,
             W_constraint=None, b_constraint=None,
             bias=True, input_dim=None, **kwargs):
    """Resolve the configuration objects and hand the remaining kwargs to the parent.

    Initializer, activation, regularizer and constraint identifiers are resolved through
    their respective ``get`` helpers.
    """
    # Plain values first.
    self.input_dim = input_dim
    self.bias = bias
    self.initial_weights = weights
    # Identifiers resolved into objects.
    self.init = initializations.get(init)
    self.activation = activations.get(activation)
    self.W_regularizer = regularizers.get(W_regularizer)
    self.b_regularizer = regularizers.get(b_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.W_constraint = constraints.get(W_constraint)
    self.b_constraint = constraints.get(b_constraint)
    self.input_spec = [InputSpec(ndim='2+')]
    if self.input_dim:
        # A truthy input_dim is turned into an input_shape for the parent.
        kwargs['input_shape'] = (self.input_dim,)
    super(Feedback, self).__init__(**kwargs)
def build(self, input_shape):
    """Create the square ``(input_dim, input_dim)`` kernel and the optional bias.

    ``input_dim`` is the last entry of ``input_shape``. Initial weights supplied at
    construction time are applied once and then discarded.
    """
    assert len(input_shape) >= 2
    self.input_dim = input_dim = input_shape[-1]
    self.input_spec = [InputSpec(dtype=K.floatx(), ndim='2+')]

    kernel_name = '{}_W'.format(self.name)
    self.W = self.add_weight((input_dim, input_dim),
                             initializer=self.init,
                             name=kernel_name,
                             regularizer=self.W_regularizer,
                             constraint=self.W_constraint)
    # The bias only exists when requested; otherwise the attribute is None.
    self.b = self.add_weight((input_dim,),
                             initializer='zero',
                             name='{}_b'.format(self.name),
                             regularizer=self.b_regularizer,
                             constraint=self.b_constraint) if self.bias else None

    if self.initial_weights is not None:
        self.set_weights(self.initial_weights)
        del self.initial_weights
    self.built = True
def __init__(self, init='glorot_uniform', activation=None, weights=None,
             W_regularizer=None, b_regularizer=None, activity_regularizer=None,
             W_constraint=None, b_constraint=None,
             bias=True, input_dim=None, **kwargs):
    """Resolve the configuration objects and hand the remaining kwargs to the parent.

    Initializer, activation, regularizer and constraint identifiers are resolved through
    their respective ``get`` helpers.
    """
    # Plain values first.
    self.input_dim = input_dim
    self.bias = bias
    self.initial_weights = weights
    # Identifiers resolved into objects.
    self.init = initializations.get(init)
    self.activation = activations.get(activation)
    self.W_regularizer = regularizers.get(W_regularizer)
    self.b_regularizer = regularizers.get(b_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.W_constraint = constraints.get(W_constraint)
    self.b_constraint = constraints.get(b_constraint)
    self.input_spec = [InputSpec(ndim='2+')]
    if self.input_dim:
        # A truthy input_dim is turned into an input_shape for the parent.
        kwargs['input_shape'] = (self.input_dim,)
    super(DivisiveNormalization, self).__init__(**kwargs)
def style_loss(style_image, target_image, style_masks, target_masks):
    '''Accumulate the weighted per-region style loss over all ``nb_labels`` regions.

    Each region's loss is scaled by ``region_style_weight``. The mask channel axis is
    the first one for 'th' dim ordering and the last otherwise.
    '''
    assert 3 == K.ndim(style_image) == K.ndim(target_image)
    assert 3 == K.ndim(style_masks) == K.ndim(target_masks)
    loss = K.variable(0)
    channels_first = K.image_dim_ordering() == 'th'
    for label in range(nb_labels):
        if channels_first:
            style_mask, target_mask = style_masks[label, :, :], target_masks[label, :, :]
        else:
            style_mask, target_mask = style_masks[:, :, label], target_masks[:, :, label]
        region_loss = region_style_loss(style_image, target_image, style_mask, target_mask)
        loss += region_style_weight * region_loss
    return loss
def region_style_loss(style_image, target_image, style_mask, target_mask):
    '''Style loss between two images restricted to one region given by 2-d masks.

    Both images are masked, their gram matrices are normalised by the mask mean and
    the channel count, and the mean squared difference is returned.
    '''
    assert 3 == K.ndim(style_image) == K.ndim(target_image)
    assert 2 == K.ndim(style_mask) == K.ndim(target_mask)
    if K.image_dim_ordering() != 'th':
        # Move the channel axis to the front so the 2-d masks broadcast over it.
        style_channels_first = K.permute_dimensions(style_image, (2, 0, 1))
        target_channels_first = K.permute_dimensions(target_image, (2, 0, 1))
        masked_style = style_channels_first * style_mask
        masked_target = target_channels_first * target_mask
        nb_channels = K.shape(style_image)[-1]
    else:
        masked_style = style_image * style_mask
        masked_target = target_image * target_mask
        nb_channels = K.shape(style_image)[0]
    style_gram = gram_matrix(masked_style) / K.mean(style_mask) / nb_channels
    target_gram = gram_matrix(masked_target) / K.mean(target_mask) / nb_channels
    return K.mean(K.square(style_gram - target_gram))
def gram_matrix(x):
    """
    Computes the gram matrix ``F . F^T`` of the input tensor.

    Input
    -----
    - x: 3-d input tensor. For 'th' dim ordering it is flattened as-is
      (described as C x H x W in the original notes); otherwise axis 2 is
      moved to the front first.

    Returns
    -------
    - F . F^T, where F is x batch-flattened to 2-d.
    """
    # assert K.ndim(x) == 3
    channels_first = x if K.image_dim_ordering() == 'th' else K.permute_dimensions(x, (2, 0, 1))
    features = K.batch_flatten(channels_first)
    return K.dot(features, K.transpose(features))
def style_loss(style, combination, mask_path=None, nb_channels=None):
    """Squared gram-matrix distance between ``style`` and ``combination``.

    When ``mask_path`` is given, both tensors are first multiplied by the mask
    returned by ``load_mask(mask_path, nb_channels)``. The sum of squared differences
    is divided by ``4 * channels**2 * size**2`` with ``channels`` fixed at 3 and
    ``size = img_width * img_height``.
    """
    assert K.ndim(style) == 3
    assert K.ndim(combination) == 3
    if mask_path is not None:
        style_mask = load_mask(mask_path, nb_channels)
        style, combination = style * style_mask, combination * style_mask
        del style_mask
    gram_difference = gram_matrix(style) - gram_matrix(combination)
    channels = 3
    size = img_width * img_height
    normaliser = 4. * (channels ** 2) * (size ** 2)
    return K.sum(K.square(gram_difference)) / normaliser

# an auxiliary loss function
# designed to maintain the "content" of the
# base image in the generated image
def call(self, x, mask=False):
    """Center-crop the spatial axes of ``x`` to the spatial size of ``self.img_in``.

    Args:
        x: 4-d or 5-d tensor. The spatial axes are the last two for 'th' dim
            ordering, and the two before the channel axis otherwise.
        mask: Unused.

    Returns:
        ``x`` sliced on its two spatial axes; all other axes are kept whole.
    """
    input_shape = K.shape(x)
    cs = K.shape(self.img_in)
    if self.dim_ordering == 'th':
        input_shape = input_shape[-2:]
        cs = cs[-2:]
    else:
        # Count from the end so the two axes before the channel axis are picked for
        # both 4-d and 5-d tensors ([1:3] is only correct for 4-d input).
        input_shape = input_shape[-3:-1]
        cs = cs[-3:-1]
    # Floor division keeps the offsets integral; `/` yields floats on Python 3, which
    # cannot be used as slice bounds.
    dif = (input_shape - cs) // 2
    rows = slice(dif[0], dif[0] + cs[0])
    cols = slice(dif[1], dif[1] + cs[1])
    if self.dim_ordering == 'th':
        if K.ndim(x) == 5:
            return x[:, :, :, rows, cols]
        return x[:, :, rows, cols]
    if K.ndim(x) == 5:
        return x[:, :, rows, cols, :]
    return x[:, rows, cols, :]
def call(self, X, mask=None):
    """Produce the layer output, choosing between train-time and test-time values.

    The test-time output is the viterbi decoding when ``test_mode`` is 'viterbi', and
    the marginal probabilities otherwise. In 'join' learn mode the train-time output is
    all zeros; otherwise marginals are used at train time only when the test mode is
    'viterbi'.
    """
    if mask is not None:
        assert K.ndim(mask) == 2, 'Input mask to CRF must have dim 2 if not None'
    decode_with_viterbi = self.test_mode == 'viterbi'
    if decode_with_viterbi:
        test_output = self.viterbi_decoding(X, mask)
    else:
        test_output = self.get_marginal_prob(X, mask)

    self.uses_learning_phase = True
    if self.learn_mode == 'join':
        train_output = K.zeros_like(K.dot(X, self.kernel))
        return K.in_train_phase(train_output, test_output)
    if decode_with_viterbi:
        train_output = self.get_marginal_prob(X, mask)
        return K.in_train_phase(train_output, test_output)
    # Marginals are already the test output; there is nothing to switch on.
    return test_output
def build(self, input_shape):
    """Build the wrapped layer (if needed) for 3-d input, then run the parent build.

    Raises:
        Exception: On the TensorFlow backend when ``input_shape[1]`` is not set.
    """
    self.input_spec = [InputSpec(ndim=3)]
    timesteps_missing = K._BACKEND == 'tensorflow' and not input_shape[1]
    if timesteps_missing:
        raise Exception('When using TensorFlow, you should define '
                        'explicitly the number of timesteps of '
                        'your sequences.\n'
                        'If your first layer is an Embedding, '
                        'make sure to pass it an "input_length" '
                        'argument. Otherwise, make sure '
                        'the first layer has '
                        'an "input_shape" or "batch_input_shape" '
                        'argument, including the time axis.')
    wrapped = self.layer
    if not wrapped.built:
        wrapped.build(input_shape)
        wrapped.built = True
    super(ProbabilityTensor, self).build()
def build(self, input_shape):
    """Build the parent LSTM weights and, when attention is enabled, the attention weights.

    The parent build is called with a 3-d shape made of the first two input axes and
    ``input_shape[4] - 1``. Initial weights are stashed before that call and applied
    at the very end, after all weights exist.
    """
    self.input_spec = [InputSpec(shape=input_shape)]
    input_dim = input_shape[4] - 1  # ignore sense prior parameter
    self.input_dim = input_dim
    # Saving onto-lstm weights to set them later. This way, LSTM's build method won't
    # delete them.
    initial_ontolstm_weights = self.initial_weights
    self.initial_weights = None
    lstm_input_shape = input_shape[:2] + (input_dim,)  # removing senses and hyps
    # Now calling LSTM's build to initialize the LSTM weights
    super(OntoAttentionLSTM, self).build(lstm_input_shape)
    # This would have changed the input shape and ndim. Reset it again.
    self.input_spec = [InputSpec(shape=input_shape)]
    if self.use_attention:
        # Following are the attention parameters
        self.input_hyp_projector = self.inner_init((input_dim, self.output_dim), name='{}_input_hyp_projector'.format(self.name))  # Projection operator for synsets
        self.context_hyp_projector = self.inner_init((self.output_dim, self.output_dim), name='{}_context_hyp_projector'.format(self.name))  # Projection operator for hidden state (context)
        self.hyp_projector2 = self.inner_init((self.output_dim, self.output_dim), name='{}_hyp_projector2'.format(self.name))  # Projection operator for hidden state (context)
        self.hyp_scorer = self.init((self.output_dim,), name='{}_hyp_scorer'.format(self.name))
        # LSTM's build method would have initialized trainable_weights. Add to it.
        self.trainable_weights.extend([self.input_hyp_projector, self.context_hyp_projector, self.hyp_projector2, self.hyp_scorer])
    # Apply the stashed weights now that every weight of this layer has been created.
    if initial_ontolstm_weights is not None:
        self.set_weights(initial_ontolstm_weights)
        del initial_ontolstm_weights
def get_initial_states(self, x):
    """Return the parent's initial states, computed from a 3-d view of the 5-d input.

    Per the original notes x is (samples, timesteps, num_senses, num_hyps, embedding_dim).
    """
    # Keep only the first sense/hyp entry and drop the last feature, which gives a
    # tensor shaped like a regular LSTM input.
    lstm_like_input = x[:, :, 0, 0, :-1]
    parent = super(OntoAttentionLSTM, self)
    return parent.get_initial_states(lstm_like_input)
def get_constants(self, x):
    """Return the parent's constants, computed from a 3-d view of the input.

    Per the original notes x is (samples, timesteps, num_senses, num_hyps, input_dim);
    a 4-d input first gets a trailing axis appended.
    """
    needs_extra_axis = K.ndim(x) == 4
    if needs_extra_axis:
        x = K.expand_dims(x)
    # Keep only the first sense/hyp entry and drop the last feature, which gives a
    # tensor shaped like a regular LSTM input.
    lstm_like_input = x[:, :, 0, 0, :-1]
    return super(OntoAttentionLSTM, self).get_constants(lstm_like_input)
def compute_mask(self, input, mask):
    """Collapse the last two axes of the input mask, or return None.

    Overridden because the input ndim differs from the output ndim. A mask is only
    produced when the layer returns sequences and an input mask is given.
    """
    if not self.return_sequences or mask is None:
        return None
    # Per the original notes: the input mask is (batch_size, num_words, num_hyps, num_senses)
    # and the output mask is (batch_size, num_words).
    return K.any(mask, axis=(-2, -1))
def __init__(self, num_senses, num_hyps, use_attention=False, return_attention=False, **kwargs):
    """Initialise the parent with ``output_dim`` (required in ``kwargs``) and create the reader.

    The reader is an ``OntoAttentionLSTM`` with the same output dimension; it never
    returns attention itself.
    """
    assert "output_dim" in kwargs
    output_dim = kwargs.pop("output_dim")
    super(OntoAttentionNSE, self).__init__(output_dim, **kwargs)
    self.input_spec = [InputSpec(ndim=5)]
    # TODO: Define an attention output method that rebuilds the reader.
    self.return_attention = return_attention
    reader_options = dict(use_attention=use_attention,
                          consume_less='gpu',
                          return_attention=False)
    self.reader = OntoAttentionLSTM(self.output_dim, num_senses, num_hyps, **reader_options)
def compute_mask(self, input, mask):
    """Compute the output mask from the reader's output mask rather than the input mask.

    The original note says the input mask has ndim 5, so the reader's mask is what
    gets passed on to the parent implementation.
    """
    parent = super(OntoAttentionNSE, self)
    return parent.compute_mask(input, self.reader.compute_mask(input, mask))
def compute_mask(self, inputs, mask=None):
    """Build a mask from the positions of ``inputs`` that differ from ``self.mask_value``.

    A position counts as unmasked when any value along the last axis differs from the
    mask value. The result is the logical AND of the per-axis "any" reductions
    (keepdims) over axes 1 .. ndim-2. The incoming ``mask`` is not used.
    """
    rank = K.ndim(inputs)
    differs = K.not_equal(inputs, self.mask_value)
    mask_tensor = K.any(differs, axis=-1)
    combined = K.any(mask_tensor, axis=1, keepdims=True)
    axis = 2
    while axis < rank - 1:
        along_axis = K.any(mask_tensor, axis=axis, keepdims=True)
        combined = tf.logical_and(combined, along_axis)
        axis += 1
    return combined
def compute_mask(self, inputs, mask=None):
    """Pool the input mask with the wrapped layer's settings to get the output mask.

    Args:
        inputs: Layer input; only its number of dimensions is used.
        mask: Input mask, or None.

    Returns:
        None when there is no input mask; otherwise a boolean tensor that is True
        where the pooled mask, summed over the channel axis, is non-zero.
    """
    if mask is None:
        # No incoming mask means there is nothing to propagate.
        return None
    channel_axis = K.ndim(inputs) - 1
    mask_tensor = K.cast(mask, K.floatx())
    # Give the mask a trailing axis so it can go through the pooling function.
    mask_tensor = K.expand_dims(mask_tensor)
    mask_output = self.layer._pooling_function(
        mask_tensor,
        self.layer.pool_size,
        self.layer.strides,
        self.layer.padding,
        self.layer.data_format,
    )
    mask_output = K.sum(mask_output, axis=channel_axis)
    next_mask_tensor = K.not_equal(mask_output, 0.0)
    return next_mask_tensor
def compute_mask(self, inputs, mask):
    """Derive the output mask by passing the input mask through ``_compute_mask_output``.

    Args:
        inputs: Layer input; only its number of dimensions is used.
        mask: Input mask, or None.

    Returns:
        None when there is no input mask; otherwise a boolean tensor that is True
        where the transformed mask, summed over the channel axis, is non-zero.
    """
    if mask is None:
        # No incoming mask means there is nothing to propagate.
        return None
    channel_axis = K.ndim(inputs) - 1
    mask_tensor = K.cast(mask, K.floatx())
    # Give the mask a trailing axis before handing it to the helper.
    mask_tensor = K.expand_dims(mask_tensor)
    mask_output = self._compute_mask_output(mask_tensor)
    mask_output = K.sum(mask_output, axis=channel_axis)
    next_mask_tensor = K.not_equal(mask_output, 0.0)
    return next_mask_tensor
def call(self, inputs, mask=None):
    """Run the wrapped layer and multiply its output by the transformed mask.

    Args:
        inputs: Input tensor for the wrapped layer.
        mask: Input mask, or None.

    Returns:
        The wrapped layer's output, unchanged when ``mask`` is None; otherwise
        multiplied by the mask after it has been passed through
        ``_compute_mask_output`` and repeated ``self.layer.filters`` times along
        the channel axis.
    """
    outputs = self.layer.call(inputs)
    if mask is None:
        # Without a mask there is nothing to suppress.
        return outputs
    channel_axis = K.ndim(inputs) - 1
    mask_tensor = K.cast(mask, K.floatx())
    # Give the mask a trailing axis before handing it to the helper.
    mask_tensor = K.expand_dims(mask_tensor)
    mask_output = self._compute_mask_output(mask_tensor)
    mask_output = K.repeat_elements(
        mask_output,
        self.layer.filters,
        channel_axis,
    )
    return outputs * mask_output
def softmax(x, axis, mask=None):
    """Masked softmax of ``x`` along ``axis``.

    Args:
        x: Score tensor.
        axis: Axis to normalise over.
        mask: Optional mask; when omitted a constant True is used. If it has one
            dimension fewer than ``x`` a trailing axis is appended.

    Returns:
        ``exp(x - max) * mask`` divided by its sum along ``axis``. Sums below
        epsilon are bumped by epsilon to avoid division by zero.
    """
    if mask is None:
        mask = K.constant(True)
    mask = K.cast(mask, K.floatx())
    # Compare the ranks by value; `is` on integers only works by accident.
    if K.ndim(x) == K.ndim(mask) + 1:
        mask = K.expand_dims(mask)

    # Subtract the maximum before exponentiating.
    m = K.max(x, axis=axis, keepdims=True)
    e = K.exp(x - m) * mask
    s = K.sum(e, axis=axis, keepdims=True)
    s = s + K.cast(K.cast(s < K.epsilon(), K.floatx()) * K.epsilon(), K.floatx())
    return e / s
def _collect_attention(x, a, mask):
    """Attention-weighted sum of ``x`` over axis 1.

    Shapes per the original notes:
    x is (B, T, D)
    a is (B, T, 1) or (B, T)
    mask is (B, T)
    The result is (B, D).
    """
    scores = K.expand_dims(a) if K.ndim(a) == 2 else a
    # Normalise the scores over axis 1, honouring the mask.
    attention = softmax(scores, axis=1, mask=mask)
    weighted = x * attention
    return K.sum(weighted, axis=1)
def _time_distributed_multiply(self, x, w):
    """Element-wise multiply vector and weights.

    # Arguments
        x: sequence of hidden states, (batch_size, ?, embedding_size)
        w: weights of one matching strategy of one direction,
           (mp_dim, embedding_size)

    # Output shape
        (?, mp_dim, embedding_size)
    """
    rank = K.ndim(x)
    static_shape = K.int_shape(x)
    embedding_size = static_shape[-1]
    timesteps = static_shape[1]
    if timesteps is None:
        # Fall back to the dynamic size when it is not known statically.
        timesteps = K.shape(x)[1]

    # Merge every leading axis into one, then add a singleton axis: (?, 1, embedding_size).
    flat = K.expand_dims(K.reshape(x, (-1, embedding_size)), axis=1)
    # Weights become (1, mp_dim, embedding_size) so the product broadcasts.
    product = flat * K.expand_dims(w, axis=0)

    # Restore the leading axes of the original input.
    if rank == 3:
        product = K.reshape(product, K.stack([-1, timesteps, self.mp_dim, embedding_size]))
        product.set_shape([None, None, None, embedding_size])
    elif rank == 2:
        product = K.reshape(product, K.stack([-1, self.mp_dim, embedding_size]))
        product.set_shape([None, None, embedding_size])
    return product
def call(self, x, mask=None):
    """Apply the activation to ``x . W + b``.

    Args:
        x: Input passed as the first argument of ``tf.sparse_tensor_dense_matmul``
            (presumably a sparse tensor - confirm against callers).
        mask: Unused.

    Returns:
        ``self.activation(x . W + b)``.
    """
    # The leftover debug print of the input rank was removed; it wrote to stdout on
    # every call of this method.
    return self.activation(tf.sparse_tensor_dense_matmul(x, self.W) + self.b)
def call(self, x, mask=None):
    """Normalise ``x`` with the stored running statistics, beta and gamma.

    Args:
        x: Input tensor.
        mask: Unused.

    Returns:
        The result of ``K.batch_normalization``. When ``self.axis`` is the last axis
        the parameters are used as they are; otherwise they are first reshaped so that
        they broadcast along ``self.axis``.
    """
    assert self.built, 'Layer must be built before being called'
    input_shape = K.int_shape(x)

    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    # Build a list before comparing: on Python 3 a list never equals a range object,
    # which would force the broadcasting branch every time.
    if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
        x_normed = K.batch_normalization(
            x, self.running_mean, self.running_std,
            self.beta, self.gamma,
            epsilon=self.epsilon)
    else:
        # need broadcasting
        broadcast_running_mean = K.reshape(self.running_mean, broadcast_shape)
        broadcast_running_std = K.reshape(self.running_std, broadcast_shape)
        broadcast_beta = K.reshape(self.beta, broadcast_shape)
        broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
        x_normed = K.batch_normalization(
            x, broadcast_running_mean, broadcast_running_std,
            broadcast_beta, broadcast_gamma,
            epsilon=self.epsilon)
    return x_normed
def flatten(x):
    """Apply a ``Flatten`` layer to tensors of rank 3 or more; return others unchanged."""
    return Flatten()(x) if K.ndim(x) >= 3 else x