Python tensorflow.python.framework.ops module: IndexedSlices() example source code
The following code examples, extracted from open-source Python projects, illustrate how to use tensorflow.python.framework.ops.IndexedSlices().
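An IndexedSlices is TensorFlow's sparse stand-in for a dense tensor: it bundles a values tensor, the row indices those values occupy, and an optional dense_shape. TensorFlow produces it automatically as the gradient of ops such as tf.gather and embedding lookups, which is why almost every example below branches on isinstance(g, ops.IndexedSlices). A minimal standalone sketch (illustrative values, not taken from any of the projects below):

import tensorflow as tf
from tensorflow.python.framework import ops

values = tf.constant([[1.0, 2.0], [3.0, 4.0]])     # the rows that are present
indices = tf.constant([0, 2])                      # where those rows live
dense_shape = tf.constant([4, 2], dtype=tf.int64)  # shape of the full tensor

slices = ops.IndexedSlices(values, indices, dense_shape)
dense = tf.convert_to_tensor(slices)  # materializes the missing rows as zeros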
def grad_variance(self):
  grad_var_ops = []
  tensor_to_avg = []
  for t, g in zip(self._tvars, self._grads):
    if isinstance(g, ops.IndexedSlices):
      # Densify the sparse gradient (summing duplicate indices) so the
      # moving averager only ever sees plain tensors.
      tensor_to_avg.append(
          tf.reshape(tf.unsorted_segment_sum(
              g.values, g.indices, g.dense_shape[0]),
              shape=t.get_shape()))
    else:
      tensor_to_avg.append(g)
  avg_op = self._moving_averager.apply(tensor_to_avg)
  grad_var_ops.append(avg_op)
  with tf.control_dependencies([avg_op]):
    self._grad_avg = [
        self._moving_averager.average(val) for val in tensor_to_avg]
    self._grad_avg_squared = [tf.square(val) for val in self._grad_avg]
    self._grad_var = tf.maximum(
        tf.constant(EPS, dtype=self._grad_norm_squared_avg.dtype),
        self._grad_norm_squared_avg
        - tf.add_n([tf.reduce_sum(val) for val in self._grad_avg_squared]))
  if self._sparsity_debias:
    self._grad_var *= self._sparsity_avg
  return grad_var_ops
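The sparse branch above densifies the gradient with tf.unsorted_segment_sum, and the variance estimate is then max(EPS, E[||g||^2] - ||E[g]||^2), computed from moving averages. A tiny sketch of the densification step alone, with illustrative values:

import tensorflow as tf

g_values = tf.constant([[1., 1.], [2., 2.], [3., 3.]])
g_indices = tf.constant([0, 0, 2])  # index 0 appears twice
dense_grad = tf.unsorted_segment_sum(g_values, g_indices, num_segments=4)
# row 0 becomes [3., 3.] (duplicates summed); rows 1 and 3 are zero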
def clip_gradient_norms(gradients_to_variables, max_norm):
  """Clips the gradients by the given value.

  Args:
    gradients_to_variables: A list of gradient to variable pairs (tuples).
    max_norm: the maximum norm value.

  Returns:
    A list of clipped gradient to variable pairs.
  """
  clipped_grads_and_vars = []
  for grad, var in gradients_to_variables:
    if grad is not None:
      if isinstance(grad, ops.IndexedSlices):
        tmp = clip_ops.clip_by_norm(grad.values, max_norm)
        grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
      else:
        grad = clip_ops.clip_by_norm(grad, max_norm)
    clipped_grads_and_vars.append((grad, var))
  return clipped_grads_and_vars
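A hedged usage sketch for clip_gradient_norms; the optimizer and `loss` are assumed to exist:

opt = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = opt.compute_gradients(loss)
clipped = clip_gradient_norms(grads_and_vars, max_norm=5.0)
train_op = opt.apply_gradients(clipped)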
def add_gradients_summaries(grads_and_vars):
  """Add summaries to gradients.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).

  Returns:
    The list of created summaries.
  """
  summaries = []
  for grad, var in grads_and_vars:
    if grad is not None:
      if isinstance(grad, ops.IndexedSlices):
        grad_values = grad.values
      else:
        grad_values = grad
      summaries.append(logging_ops.histogram_summary(
          var.op.name + ':gradient', grad_values))
      summaries.append(logging_ops.histogram_summary(
          var.op.name + ':gradient_norm', clip_ops.global_norm([grad_values])))
    else:
      logging.info('Var %s has no gradient', var.op.name)
  return summaries
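Note that for sparse gradients only grad.values is summarized, so the histogram and norm cover the rows the gradient actually touches. A minimal call, under the same assumptions as the sketch above:

summaries = add_gradients_summaries(opt.compute_gradients(loss))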
def _multiply_gradients(grads_and_vars, gradient_multipliers):
  """Multiply specified gradients."""
  multiplied_grads_and_vars = []
  for grad, var in grads_and_vars:
    if (grad is not None and
        (var in gradient_multipliers or var.name in gradient_multipliers)):
      key = var if var in gradient_multipliers else var.name
      multiplier = constant_op.constant(
          gradient_multipliers[key], dtype=dtypes.float32)
      if isinstance(grad, ops.IndexedSlices):
        grad_values = grad.values * multiplier
        grad = ops.IndexedSlices(grad_values, grad.indices, grad.dense_shape)
      else:
        grad *= multiplier
    multiplied_grads_and_vars.append((grad, var))
  return multiplied_grads_and_vars
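_multiply_gradients accepts multipliers keyed by the variable object or by var.name (with the ':0' suffix), unlike the public multiply_gradients further down, which keys on var.op.name. A hypothetical call, with an illustrative variable name:

multipliers = {'embedding/weights:0': 10.0}
scaled = _multiply_gradients(opt.compute_gradients(loss), multipliers)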
def add_gradients_summaries(grads_and_vars):
  """Add summaries to gradients.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).

  Returns:
    The list of created summaries.
  """
  summaries = []
  for grad, var in grads_and_vars:
    if grad is not None:
      if isinstance(grad, ops.IndexedSlices):
        grad_values = grad.values
      else:
        grad_values = grad
      summaries.append(summary.histogram_summary(
          var.op.name + ':gradient', grad_values))
      summaries.append(summary.histogram_summary(
          var.op.name + ':gradient_norm', clip_ops.global_norm([grad_values])))
    else:
      logging.info('Var %s has no gradient', var.op.name)
  return summaries
def _clip_sparse(self, grad, var):
  assert isinstance(grad, ops.IndexedSlices)
  clip_dims = self._vars_to_clip_dims[var]
  if 0 in clip_dims:
    logging.warning("Clipping norm across dims %s for %s is inefficient "
                    "when including sparse dimension 0.", clip_dims,
                    var.op.name)
    return self._clip_dense(var)

  with ops.colocate_with(var):
    var_subset = array_ops.gather(var, grad.indices)
  with self._maybe_colocate_with(var):
    normalized_var_subset = clip_ops.clip_by_norm(
        var_subset, self._max_norm, clip_dims)
    delta = ops.IndexedSlices(
        var_subset - normalized_var_subset, grad.indices, grad.dense_shape)
  with ops.colocate_with(var):
    return var.scatter_sub(delta, use_locking=self._use_locking)
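_clip_sparse re-normalizes only the rows the sparse gradient touched: it gathers those rows, clips them, and applies the correction via Variable.scatter_sub, which accepts an IndexedSlices delta. A standalone sketch of that last step, with illustrative values:

v = tf.Variable(tf.ones([4, 2]))
delta = ops.IndexedSlices(tf.constant([[0.5, 0.5]]), tf.constant([1]))
update_op = v.scatter_sub(delta)  # subtracts 0.5 from row 1 only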
def _thin_stack_update_gradient(op, stack_grad, *rest):
  stack = op.inputs[2]
  batch_size = op.inputs[4].get_shape().as_list()[0]
  t = op.get_attr("timestep")

  # We usually slice off the head of the stack output in feedforward and
  # send it off to downstream computation. The Slice feedforward op will
  # generate a sparse gradient in the backward pass. Nix this sparsity
  # at the very start.
  if isinstance(stack_grad, ops.IndexedSlices):
    # Trick: re-use our stack structure to store new gradients.
    # Recover the original stack variable from the lookup/update chain.
    stack = _fetch_stack(stack)

    stack = tf.assign(stack, tf.zeros_like(stack))
    stack = tf.scatter_update(stack, stack_grad.indices, stack_grad.values)
    stack_grad = stack

  with tf.control_dependencies([stack_grad]):
    input_grad = tf.slice(stack_grad, [t * batch_size, 0], [batch_size, -1])

  return input_grad, None, stack_grad, None, None, None
def gradients(opt, loss, vars, step, max_gradient_norm=None, dont_clip=[]):
  gradients = opt.compute_gradients(loss, vars)
  if max_gradient_norm is not None:
    to_clip = [(g, v) for g, v in gradients if v.name not in dont_clip]
    not_clipped = [(g, v) for g, v in gradients if v.name in dont_clip]
    gradients, variables = zip(*to_clip)
    clipped_gradients, _ = clip_ops.clip_by_global_norm(
        gradients,
        max_gradient_norm
    )
    gradients = list(zip(clipped_gradients, variables)) + not_clipped

  # Add histograms for variables, gradients and gradient norms.
  for gradient, variable in gradients:
    if isinstance(gradient, ops.IndexedSlices):
      grad_values = gradient.values
    else:
      grad_values = gradient

    if grad_values is None:
      print('warning: missing gradient: {}'.format(variable.name))
    if grad_values is not None:
      tf.summary.histogram(variable.name, variable)
      tf.summary.histogram(variable.name + '/gradients', grad_values)
      tf.summary.histogram(
          variable.name + '/gradient_norm',
          clip_ops.global_norm([grad_values])
      )

  return opt.apply_gradients(gradients, global_step=step)
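A hedged usage sketch for this gradients helper; `loss` and `global_step` are assumed to exist:

opt = tf.train.AdamOptimizer(1e-3)
train_op = gradients(opt, loss, tf.trainable_variables(), global_step,
                     max_gradient_norm=5.0)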
def multiply_gradients(grads_and_vars, gradient_multipliers):
  """Multiply specified gradients.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    gradient_multipliers: A map from either `Variables` or `Variable` op names
      to the coefficient by which the associated gradient should be scaled.

  Returns:
    The updated list of gradient to variable pairs.

  Raises:
    ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers`
      is empty or None or if `gradient_multipliers` is not a dictionary.
  """
  if not isinstance(grads_and_vars, list):
    raise ValueError('`grads_and_vars` must be a list.')
  if not gradient_multipliers:
    raise ValueError('`gradient_multipliers` is empty.')
  if not isinstance(gradient_multipliers, dict):
    raise ValueError('`gradient_multipliers` must be a dict.')

  multiplied_grads_and_vars = []
  for grad, var in grads_and_vars:
    if var in gradient_multipliers or var.op.name in gradient_multipliers:
      key = var if var in gradient_multipliers else var.op.name
      if grad is None:
        raise ValueError('Requested multiple of `None` gradient.')
      if isinstance(grad, ops.IndexedSlices):
        tmp = grad.values * constant_op.constant(gradient_multipliers[key],
                                                 dtype=grad.dtype)
        grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
      else:
        grad *= constant_op.constant(gradient_multipliers[key],
                                     dtype=grad.dtype)
    multiplied_grads_and_vars.append((grad, var))
  return multiplied_grads_and_vars
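Usage sketch for the public multiply_gradients, keyed by op name (no ':0' suffix); the variable name is illustrative:

multipliers = {'conv1/weights': 0.1}
scaled = multiply_gradients(opt.compute_gradients(loss), multipliers)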
def _add_scaled_noise_to_gradients(grads_and_vars, gradient_noise_scale):
  """Adds scaled noise from a 0-mean normal distribution to gradients."""
  gradients, variables = zip(*grads_and_vars)
  noisy_gradients = []
  for gradient in gradients:
    if gradient is None:
      noisy_gradients.append(None)
      continue
    if isinstance(gradient, ops.IndexedSlices):
      gradient_shape = gradient.dense_shape
    else:
      gradient_shape = gradient.get_shape()
    noise = random_ops.truncated_normal(gradient_shape) * gradient_noise_scale
    noisy_gradients.append(gradient + noise)
  return list(zip(noisy_gradients, variables))
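One consequence of the code above: adding dense truncated-normal noise to an IndexedSlices gradient densifies it, since `gradient + noise` converts the sparse gradient to a tensor. A hedged call sketch:

noisy = _add_scaled_noise_to_gradients(opt.compute_gradients(loss),
                                       gradient_noise_scale=0.01)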
def patch_dynamic_stitch_grad():
  """Tensorflow's current gradient implementation for `tf.dynamic_stitch` is
  incorrect. This monkey-patches Tensorflow to fix the bug."""

  def DynamicStitchGrads(op, grad):
    num_values = len(op.inputs) // 2
    indices_grad = [None] * num_values

    def AsInt32(x):
      return (x if op.inputs[0].dtype == dtypes.int32 else
              math_ops.cast(x, dtypes.int32))

    idxs = [AsInt32(array_ops.reshape(op.inputs[i], (-1,)))
            for i in range(num_values)]
    if isinstance(grad, ops.IndexedSlices):
      output_shape = array_ops.shape(op.outputs[0])
      output_rows = output_shape[0]
      grad = math_ops.unsorted_segment_sum(grad.values, grad.indices,
                                           output_rows)

    values_grad = []
    zeros = array_ops.zeros_like(grad)
    idx_zeros = [zeros[:array_ops.shape(x)[0]] for x in idxs]
    grad_range = math_ops.range(array_ops.shape(grad)[0])
    for i in range(num_values):
      if i == num_values - 1:
        v_grad = grad
      else:
        v_grad = data_flow_ops.dynamic_stitch(
            [grad_range] + idxs[i + 1:], [grad] + idx_zeros[i + 1:])
      v_grad = array_ops.gather(v_grad, AsInt32(op.inputs[i]))
      values_grad += [v_grad]

    return indices_grad + values_grad

  # Need to stick in the registry manually, to override the already
  # registered implementation.
  ops._gradient_registry._registry["DynamicStitch"] = {
      "type": DynamicStitchGrads, "location": traceback.extract_stack()}
def grad_variance(self):
  grad_var_ops = []
  tensor_to_avg = []
  for t, g in zip(self._tvars, self._grads):
    if isinstance(g, ops.IndexedSlices):
      tensor_to_avg.append(
          tf.reshape(tf.unsorted_segment_sum(g.values, g.indices,
                                             g.dense_shape[0]),
                     shape=t.get_shape()))
    else:
      tensor_to_avg.append(g)
  avg_op = self._moving_averager.apply(tensor_to_avg)
  grad_var_ops.append(avg_op)
  with tf.control_dependencies([avg_op]):
    self._grad_avg = [self._moving_averager.average(val)
                      for val in tensor_to_avg]
    self._grad_avg_squared = [tf.square(val) for val in self._grad_avg]
    self._grad_var = self._grad_norm_squared_avg - tf.add_n(
        [tf.reduce_sum(val) for val in self._grad_avg_squared])
  return grad_var_ops
def _FloatyGatherGrad(op, grad):
  if op.inputs[0].get_shape().is_fully_defined():
    dense_shape = constant_op.constant(op.inputs[0].get_shape().as_list())
    values_shape = [-1] + op.inputs[0].get_shape()[1:].as_list()
  else:
    # op.inputs[0] can be large, so colocate the shape calculation with it.
    with ops.colocate_with(op.inputs[0]):
      dense_shape = array_ops.shape(op.inputs[0])
      values_shape = array_ops.concat(0, [[-1], dense_shape[1:]])

  values = array_ops.reshape(grad, values_shape)
  indices = math_ops.to_int32(array_ops.reshape(op.inputs[1], [-1]))
  return [ops.IndexedSlices(values, indices, dense_shape), None]
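_FloatyGatherGrad works in the opposite direction from most examples here: the gradient function returns an IndexedSlices, keeping the gather gradient sparse rather than materializing a dense tensor. The stock tf.gather gradient behaves the same way, as this small sketch shows:

params = tf.Variable(tf.ones([10, 3]))
out = tf.gather(params, [2, 5])
grad = tf.gradients(tf.reduce_sum(out), params)[0]
# `grad` is an ops.IndexedSlices with indices [2, 5], not a dense tensor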
def add_gradients_summary(grads):
  """Add histogram summary for given gradients and scalar summary for clipped gradients.

  Args:
    grads: A list of `Tensor`. The gradients to summarize.

  Returns:
    The list of created gradient summaries.
  """
  # Add histograms for gradients.
  summary = []
  for gradient, var in grads:
    if isinstance(gradient, ops.IndexedSlices):
      grad_values = gradient.values
    else:
      grad_values = gradient

    if grad_values is not None:
      summary_name = var.op.name + '/Gradients'
      summary.append(get_summary(SummaryTypes.HISTOGRAM, summary_name, grad_values))

      summary_norm_name = var.op.name + '/GradientsNorm'
      summary.append(get_summary(SummaryTypes.SCALAR, summary_norm_name,
                                 clip_ops.global_norm([grad_values])))

  summary.append(get_summary(SummaryTypes.SCALAR, 'ClippedGradientNorm',
                             clip_ops.global_norm(list(zip(*grads))[0])))
  return summary
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """"""
  # Error checking
  grads_and_vars = tuple(grads_and_vars)
  for g_t, x_tm1 in grads_and_vars:
    if not isinstance(g_t, (ops.Tensor, ops.IndexedSlices, type(None))):
      raise TypeError(
          "Gradient must be a Tensor, IndexedSlices, or None: %s" % g_t)
    if not isinstance(x_tm1, variables.Variable):
      raise TypeError(
          "Variable must be a tf.Variable: %s" % x_tm1)
    if g_t is not None:
      self._assert_valid_dtypes([g_t, x_tm1])
  var_list = [x_tm1 for g_t, x_tm1 in grads_and_vars if g_t is not None]
  if not var_list:
    raise ValueError("No gradients provided for any variable: %s" %
                     (grads_and_vars,))

  # The actual stuff
  with ops.control_dependencies(None):
    self._create_slots(grads_and_vars)
  update_ops = []
  with ops.op_scope([], name, self._name) as name:
    prepare = self._prepare(grads_and_vars)
    for g_t, x_tm1 in grads_and_vars:
      if g_t is None:
        continue
      with ops.name_scope("update_" + x_tm1.op.name), ops.device(x_tm1.device):
        if isinstance(g_t, ops.Tensor):
          update_ops.append(self._apply_dense(g_t, x_tm1, prepare))
        else:
          update_ops.append(self._apply_sparse(g_t, x_tm1, prepare))
    if global_step is None:
      return self._finish(update_ops, name)
    else:
      with ops.control_dependencies([self._finish(update_ops, "update")]):
        with ops.device(global_step.device):
          return state_ops.assign_add(global_step, 1, name=name).op

#=============================================================
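The isinstance branch above is the standard TF-1.x optimizer dispatch: dense Tensor gradients go to _apply_dense, IndexedSlices gradients to _apply_sparse. A standalone illustration of the same routing (an illustrative helper, not part of the class):

def route_gradient(g_t):
  if isinstance(g_t, ops.Tensor):
    return 'dense'
  if isinstance(g_t, ops.IndexedSlices):
    return 'sparse'
  raise TypeError('Gradient must be a Tensor or IndexedSlices: %s' % g_t)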
def patch_state_grads():
  """Tensorflow doesn't have a gradient implementation for state ops (e.g.,
  scatter_add/update). This adds them in."""

  def ScatterUpdateGrads(op, grad):
    var, indices, updates = op.inputs

    updates_grad = array_ops.gather(grad, indices)

    # TODO: the dynamic_stitch approach might be faster if there were
    # a GPU dynamic_stitch implementation. should be available in tf 1.4
    # grad_range = math_ops.range(grad.get_shape()[0].value)
    # var_grad = data_flow_ops.dynamic_stitch(
    #     [grad_range, indices],
    #     [grad, array_ops.zeros(updates.get_shape())])

    if isinstance(grad, ops.IndexedSlices):
      # note: we could use this approach for everything, but the
      # temporary variable approach seems to be slightly faster (but we
      # can't use that on indexedslices)
      var_grad = grad - array_ops.scatter_nd(
          array_ops.expand_dims(indices, 1), updates_grad,
          var.get_shape())
    else:
      var_grad = gen_state_ops._temporary_variable(
          grad.get_shape(), grad.dtype)
      var_name = var_grad.op.name
      var_grad = state_ops.assign(var_grad, grad)
      var_grad = state_ops.scatter_update(
          var_grad, indices, array_ops.zeros_like(updates))
      var_grad = gen_state_ops._destroy_temporary_variable(var_grad, var_name)

    return var_grad, None, updates_grad

  def ScatterAddGrads(op, grad):
    _, indices, _ = op.inputs
    updates_grad = array_ops.gather(grad, indices)
    return grad, None, updates_grad

  def AssignGrads(op, grad):
    return array_ops.zeros_like(grad), grad

  def AssignAddGrads(op, grad):
    return grad, grad

  ops._gradient_registry._registry["ScatterUpdate"] = {
      "type": ScatterUpdateGrads, "location": traceback.extract_stack()}
  ops._gradient_registry._registry["ScatterAdd"] = {
      "type": ScatterAddGrads, "location": traceback.extract_stack()}
  ops._gradient_registry._registry["Assign"] = {
      "type": AssignGrads, "location": traceback.extract_stack()}
  ops._gradient_registry._registry["AssignAdd"] = {
      "type": AssignAddGrads, "location": traceback.extract_stack()}