Python torch module: gather() example source code
The following code examples, extracted from open-source Python projects, illustrate how to use torch.gather().
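Before the project snippets, a minimal sketch of the call itself may help. Along dim=1, out[i][j] = src[i][index[i][j]]: the index tensor selects one source column per output position. This sketch uses the modern tensor API; most snippets below predate it and still use Variable.
import torch
# out[i][j] = src[i][index[i][j]] when gathering along dim=1
src = torch.tensor([[10., 11., 12.],
                    [20., 21., 22.]])
index = torch.tensor([[2, 0],
                      [1, 1]])
out = torch.gather(src, 1, index)
# out == tensor([[12., 10.],
#                [21., 21.]])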
def forward(self, x, lengths):
"""Handles variable size captions
"""
# Embed word ids to vectors
x = self.embed(x)
packed = pack_padded_sequence(x, lengths, batch_first=True)
# Forward propagate RNN
out, _ = self.rnn(packed)
# Reshape *final* output to (batch_size, hidden_size)
padded = pad_packed_sequence(out, batch_first=True)
I = torch.LongTensor(lengths).view(-1, 1, 1)
I = Variable(I.expand(x.size(0), 1, self.embed_size)-1).cuda()
out = torch.gather(padded[0], 1, I).squeeze(1)
# normalization in the joint embedding space
out = l2norm(out)
# take absolute value, used by order embeddings
if self.use_abs:
out = torch.abs(out)
return out
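The index arithmetic above, expanding lengths - 1 over the hidden dimension and gathering along time, is a common way to pull out the last valid RNN output of each padded sequence. A standalone sketch of the pattern with made-up shapes, using the modern tensor API:
import torch
batch, seq_len, hidden = 3, 5, 4
out = torch.randn(batch, seq_len, hidden)    # padded RNN outputs
lengths = torch.tensor([5, 3, 1])            # valid steps per example
idx = (lengths - 1).view(-1, 1, 1).expand(batch, 1, hidden)
last = torch.gather(out, 1, idx).squeeze(1)  # (batch, hidden): output at step lengths[b] - 1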
def reverse_sequence(self, x, x_lens):
batch_size, seq_len, word_dim = x.size()
inv_idx = Variable(torch.arange(seq_len - 1, -1, -1).long())
shift_idx = Variable(torch.arange(0, seq_len).long())
if x.is_cuda:
inv_idx = inv_idx.cuda(x.get_device())
shift_idx = shift_idx.cuda(x.get_device())
inv_idx = inv_idx.unsqueeze(0).unsqueeze(-1).expand_as(x)
shift_idx = shift_idx.unsqueeze(0).unsqueeze(-1).expand_as(x)
shift = (seq_len + (-1 * x_lens)).unsqueeze(-1).unsqueeze(-1).expand_as(x)
shift_idx = shift_idx + shift
shift_idx = shift_idx.clamp(0, seq_len - 1)
x = x.gather(1, inv_idx)
x = x.gather(1, shift_idx)
return x
def forward(self, logits, target):
"""
:param logits: tensor with shape of [batch_size, seq_len, input_size]
:param target: tensor with shape of [batch_size, seq_len] of Long type filled with indexes to gather from logits
:return: tensor with shape of [batch_size] with perplexity evaluation
"""
[batch_size, seq_len, input_size] = logits.size()
logits = logits.view(-1, input_size)
log_probs = F.log_softmax(logits)
del logits
log_probs = log_probs.view(batch_size, seq_len, input_size)
target = target.unsqueeze(2)
out = t.gather(log_probs, dim=2, index=target).squeeze(2).neg()
ppl = out.mean(1).exp()
return ppl
def eliminate_rows(self, prob_sc, ind, phis):
""" eliminate rows of phis and prob_matrix scale """
length = prob_sc.size()[1]
mask = (prob_sc[:, :, 0] > 0.85).type(dtype)
rang = (Variable(torch.range(0, length - 1).unsqueeze(0)
.expand_as(mask)).
type(dtype))
ind_sc = torch.sort(rang * (1-mask) + length * mask, 1)[1]
# permute prob_sc
m = mask.unsqueeze(2).expand_as(prob_sc)
mm = m.clone()
mm[:, :, 1:] = 0
prob_sc = (torch.gather(prob_sc * (1 - m) + mm, 1,
ind_sc.unsqueeze(2).expand_as(prob_sc)))
# compose permutations
ind = torch.gather(ind, 1, ind_sc)
active = torch.gather(1-mask, 1, ind_sc)
# permute phis
active1 = active.unsqueeze(2).expand_as(phis)
ind1 = ind.unsqueeze(2).expand_as(phis)
active2 = active.unsqueeze(1).expand_as(phis)
ind2 = ind.unsqueeze(1).expand_as(phis)
phis_out = torch.gather(phis, 1, ind1) * active1
phis_out = torch.gather(phis_out, 2, ind2) * active2
return prob_sc, ind, phis_out, active
def get_ranking(predictions, labels, num_guesses=5):
"""
Given a matrix of predictions and the correct labels, get the number of
guesses required to get each prediction right, per example.
:param predictions: [batch_size, range_size] predictions
:param labels: [batch_size] array of labels
:param num_guesses: Number of guesses to return
:return:
"""
assert labels.size(0) == predictions.size(0)
assert labels.dim() == 1
assert predictions.dim() == 2
values, full_guesses = predictions.topk(predictions.size(1), dim=1)
_, ranking = full_guesses.topk(full_guesses.size(1), dim=1, largest=False)
gt_ranks = torch.gather(ranking.data, 1, labels[:, None]).squeeze()
guesses = full_guesses[:, :num_guesses]
return gt_ranks, guesses
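The double topk above first orders the class ids by score (full_guesses), then inverts that permutation, so ranking[b, c] is the 0-based rank of class c for example b; the final gather reads off the rank of each ground-truth label. A toy check of the idea with hypothetical values:
import torch
predictions = torch.tensor([[0.1, 0.7, 0.2]])     # one example, 3 classes
labels = torch.tensor([2])
_, order = predictions.topk(3, dim=1)             # class ids by descending score
_, ranking = order.topk(3, dim=1, largest=False)  # invert the permutation
gt_rank = torch.gather(ranking, 1, labels[:, None])
# class 2 has the second-highest score, so gt_rank == tensor([[1]])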
def compute_loss(self, batch, output, target):
""" See base class for args description. """
scores = self.generator(self.bottle(output))
gtruth = target.view(-1)
if self.confidence < 1:
tdata = gtruth.data
mask = torch.nonzero(tdata.eq(self.padding_idx)).squeeze()
likelihood = torch.gather(scores.data, 1, tdata.unsqueeze(1))
tmp_ = self.one_hot.repeat(gtruth.size(0), 1)
tmp_.scatter_(1, tdata.unsqueeze(1), self.confidence)
if mask.dim() > 0:
likelihood.index_fill_(0, mask, 0)
tmp_.index_fill_(0, mask, 0)
gtruth = Variable(tmp_, requires_grad=False)
loss = self.criterion(scores, gtruth)
if self.confidence < 1:
loss_data = - likelihood.sum(0)
else:
loss_data = loss.data.clone()
stats = self.stats(loss_data, scores.data, target.view(-1).data)
return loss, stats
def prepare_batch(xs, lens, gpu=True):
lens, idx = torch.sort(lens, 0, True)
_, ridx = torch.sort(idx, 0)
idx_exp = idx.unsqueeze(0).unsqueeze(-1).expand_as(xs)
xs = torch.gather(xs, 1, idx_exp)
xs = Variable(xs, volatile=True)
lens = Variable(lens, volatile=True)
ridx = Variable(ridx, volatile=True)
if gpu:
xs = xs.cuda()
lens = lens.cuda()
ridx = ridx.cuda()
return xs, lens, ridx
def test_gather(self):
m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
elems_per_row = random.randint(1, 10)
dim = random.randrange(3)
src = torch.randn(m, n, o)
idx_size = [m, n, o]
idx_size[dim] = elems_per_row
idx = torch.LongTensor().resize_(*idx_size)
self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)
actual = torch.gather(src, dim, idx)
expected = torch.Tensor().resize_(*idx_size)
for i in range(idx_size[0]):
for j in range(idx_size[1]):
for k in range(idx_size[2]):
ii = [i, j, k]
ii[dim] = idx[i, j, k]
expected[i, j, k] = src[tuple(ii)]
self.assertEqual(actual, expected, 0)
idx[0][0][0] = 23
self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx))
src = torch.randn(3, 4, 5)
expected, idx = src.max(2)
actual = torch.gather(src, 2, idx)
self.assertEqual(actual, expected, 0)
def gather_index(input, index):
assert input.dim() == 2 and index.dim() == 1
index = index.unsqueeze(1).expand_as(input)
output = torch.gather(input, 1, index)
return output[:, 0]
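For 2-D input and 1-D index, the helper above returns input[i, index[i]] for each row i. It expands the index across every column and keeps column 0; gathering a single column, as in this hedged sketch, gives the same result as advanced indexing:
import torch
input = torch.tensor([[1., 2.],
                      [3., 4.]])
index = torch.tensor([1, 0])
out = torch.gather(input, 1, index.unsqueeze(1)).squeeze(1)
assert torch.equal(out, input[torch.arange(2), index])  # tensor([2., 3.])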
def compute_loss(logits, y, lens):
batch_size, seq_len, vocab_size = logits.size()
logits = logits.view(batch_size * seq_len, vocab_size)
y = y.view(-1)
logprobs = F.log_softmax(logits)
losses = -torch.gather(logprobs, 1, y.unsqueeze(-1))
losses = losses.view(batch_size, seq_len)
mask = sequence_mask(lens, seq_len).float()
losses = losses * mask
loss_batch = losses.sum() / len(lens)
loss_step = losses.sum() / lens.sum().float()
return loss_batch, loss_step
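Gathering log-probabilities at the target indices, as in compute_loss above, is the manual form of per-token cross entropy. A minimal, unmasked sanity check against F.cross_entropy, written for the modern API with dim passed to log_softmax explicitly:
import torch
import torch.nn.functional as F
logits = torch.randn(6, 10)            # (tokens, vocab)
y = torch.randint(0, 10, (6,))
logprobs = F.log_softmax(logits, dim=-1)
nll = -torch.gather(logprobs, 1, y.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(nll, F.cross_entropy(logits, y, reduction='none'))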
def prepare_batch(self, batch_data, volatile=False):
x, x_lens, ys, ys_lens = batch_data
batch_dim = 0 if self.batch_first else 1
context_dim = 1 if self.batch_first else 0
x_lens, x_idx = torch.sort(x_lens, 0, True)
_, x_ridx = torch.sort(x_idx)
ys_lens, ys_idx = torch.sort(ys_lens, batch_dim, True)
x_ridx_exp = x_ridx.unsqueeze(context_dim).expand_as(ys_idx)
xys_idx = torch.gather(x_ridx_exp, batch_dim, ys_idx)
x = x[x_idx]
ys = torch.gather(ys, batch_dim, ys_idx.unsqueeze(-1).expand_as(ys))
x = Variable(x, volatile=volatile)
x_lens = Variable(x_lens, volatile=volatile)
ys_i = Variable(ys[..., :-1], volatile=volatile).contiguous()
ys_t = Variable(ys[..., 1:], volatile=volatile).contiguous()
ys_lens = Variable(ys_lens - 1, volatile=volatile)
xys_idx = Variable(xys_idx, volatile=volatile)
if self.is_cuda:
x = x.cuda(async=True)
x_lens = x_lens.cuda(async=True)
ys_i = ys_i.cuda(async=True)
ys_t = ys_t.cuda(async=True)
ys_lens = ys_lens.cuda(async=True)
xys_idx = xys_idx.cuda(async=True)
return x, x_lens, ys_i, ys_t, ys_lens, xys_idx
def enforce_angle(ang, xnorm, target, margin=0, linearized=False):
""" Enforce _real_ angular margin"""
m = margin + 1 # !! Just to keep parameters consistent w/ enforce_angle
tmp = torch.gather(ang, 1, target.view(-1, 1)).mul(m)
ang = ang.scatter(1, target.view(-1, 1), tmp)
ang = psi(ang, linearized)
ang = ang.mul(xnorm.view(-1, 1).expand_as(ang))
return ang
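The gather/scatter pair above reads the target entry of each row, scales it by the margin factor, and writes the scaled value back in place. A toy version of the same column-update pattern, with hypothetical numbers:
import torch
ang = torch.tensor([[0.1, 0.2],
                    [0.3, 0.4]])
target = torch.tensor([1, 0])
m = 3.0
col = torch.gather(ang, 1, target.view(-1, 1)).mul(m)  # read the target entries
ang = ang.scatter(1, target.view(-1, 1), col)          # write the scaled values back
# ang == tensor([[0.1, 0.6],
#                [0.9, 0.4]])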
def _choose(self, lang_hs=None, words=None, sample=False):
# get all the possible choices
choices = self.domain.generate_choices(self.context)
# concatenate the list of the hidden states into one tensor
lang_hs = lang_hs if lang_hs is not None else torch.cat(self.lang_hs)
# concatenate all the words into one tensor
words = words if words is not None else torch.cat(self.words)
# logits for each of the item
logits = self.model.generate_choice_logits(words, lang_hs, self.ctx_h)
# construct probability distribution over only the valid choices
choices_logits = []
for i in range(self.domain.selection_length()):
idxs = [self.model.item_dict.get_idx(c[i]) for c in choices]
idxs = Variable(torch.from_numpy(np.array(idxs)))
idxs = self.model.to_device(idxs)
choices_logits.append(torch.gather(logits[i], 0, idxs).unsqueeze(1))
choice_logit = torch.sum(torch.cat(choices_logits, 1), 1, keepdim=False)
# subtract the max to make the softmax numerically stable
choice_logit = choice_logit.sub(choice_logit.max().data[0])
prob = F.softmax(choice_logit)
if sample:
# sample a choice
idx = prob.multinomial().detach()
logprob = F.log_softmax(choice_logit).gather(0, idx)
else:
# take the most probable choice
_, idx = prob.max(0, keepdim=True)
logprob = None
p_agree = prob[idx.data[0]]
# Pick only your choice
return choices[idx.data[0]][:self.domain.selection_length()], logprob, p_agree.data[0]
def _test_gather(self, cast, test_bounds=True):
m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
elems_per_row = random.randint(1, 10)
dim = random.randrange(3)
src = torch.randn(m, n, o)
idx_size = [m, n, o]
idx_size[dim] = elems_per_row
idx = torch.LongTensor().resize_(*idx_size)
TestTorch._fill_indices(self, idx, dim, src.size(dim), elems_per_row, m, n, o)
src = cast(src)
idx = cast(idx)
actual = torch.gather(src, dim, idx)
expected = cast(torch.Tensor().resize_(*idx_size))
for i in range(idx_size[0]):
for j in range(idx_size[1]):
for k in range(idx_size[2]):
ii = [i, j, k]
ii[dim] = idx[i, j, k]
expected[i, j, k] = src[tuple(ii)]
self.assertEqual(actual, expected, 0)
if test_bounds:
idx[0][0][0] = 23
self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx))
src = cast(torch.randn(3, 4, 5))
expected, idx = src.max(2)
expected = cast(expected)
idx = cast(idx)
actual = torch.gather(src, 2, idx)
self.assertEqual(actual, expected, 0)
def reverse_padded_sequence(inputs, lengths, batch_first=False):
"""Reverses sequences according to their lengths.
Inputs should have size ``T x B x *`` if ``batch_first`` is False, or
``B x T x *`` if True. T is the length of the longest sequence (or larger),
B is the batch size, and * is any number of dimensions (including 0).
Arguments:
inputs (Variable): padded batch of variable length sequences.
lengths (list[int]): list of sequence lengths
batch_first (bool, optional): if True, inputs should be B x T x *.
Returns:
A Variable with the same size as inputs, but with each sequence
reversed according to its length.
"""
if not batch_first:
inputs = inputs.transpose(0, 1)
if inputs.size(0) != len(lengths):
raise ValueError('inputs incompatible with lengths.')
reversed_indices = [list(range(inputs.size(1)))
for _ in range(inputs.size(0))]
for i, length in enumerate(lengths):
if length > 0:
reversed_indices[i][:length] = reversed_indices[i][length-1::-1]
reversed_indices = (torch.LongTensor(reversed_indices).unsqueeze(2)
.expand_as(inputs))
reversed_indices = Variable(reversed_indices)
if inputs.is_cuda:
device = inputs.get_device()
reversed_indices = reversed_indices.cuda(device)
reversed_inputs = torch.gather(inputs, 1, reversed_indices)
if not batch_first:
reversed_inputs = reversed_inputs.transpose(0, 1)
return reversed_inputs
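A hedged usage sketch, assuming the function above is in scope and run batch-first on toy tensors (on the old API the snippet targets, the input would be wrapped in Variable first):
import torch
inputs = torch.tensor([[[1.], [2.], [3.]],
                       [[4.], [5.], [0.]]])  # (B=2, T=3, *)
out = reverse_padded_sequence(inputs, [3, 2], batch_first=True)
# out[0] -> [[3.], [2.], [1.]]  all 3 valid steps reversed
# out[1] -> [[5.], [4.], [0.]]  only the first 2 valid steps reversed; padding stays put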
def _test_gather(self, cast, test_bounds=True):
m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
elems_per_row = random.randint(1, 10)
dim = random.randrange(3)
src = torch.randn(m, n, o)
idx_size = [m, n, o]
idx_size[dim] = elems_per_row
idx = torch.LongTensor().resize_(*idx_size)
TestTorch._fill_indices(self, idx, dim, src.size(dim), elems_per_row, m, n, o)
src = cast(src)
idx = cast(idx)
actual = torch.gather(src, dim, idx)
expected = cast(torch.Tensor().resize_(*idx_size))
for i in range(idx_size[0]):
for j in range(idx_size[1]):
for k in range(idx_size[2]):
ii = [i, j, k]
ii[dim] = idx[i, j, k]
expected[i, j, k] = src[tuple(ii)]
self.assertEqual(actual, expected, 0)
if test_bounds:
idx[0][0][0] = 23
self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx))
src = cast(torch.randn(3, 4, 5))
expected, idx = src.max(2, True)
expected = cast(expected)
idx = cast(idx)
actual = torch.gather(src, 2, idx)
self.assertEqual(actual, expected, 0)
def mdn_loss(gmm_params, mu, stddev, batchsize):
gmm_mu, gmm_pi = get_gmm_coeffs(gmm_params)
eps = Variable(torch.randn(stddev.size()).normal_()).cuda()
z = torch.add(mu, torch.mul(eps, stddev))
z_flat = z.repeat(1, args.nmix)
z_flat = z_flat.view(batchsize*args.nmix, args.hiddensize)
gmm_mu_flat = gmm_mu.view(batchsize*args.nmix, args.hiddensize)
dist_all = torch.sqrt(torch.sum(torch.add(z_flat, gmm_mu_flat.mul(-1)).pow(2).mul(50), 1))
dist_all = dist_all.view(batchsize, args.nmix)
dist_min, selectids = torch.min(dist_all, 1)
gmm_pi_min = torch.gather(gmm_pi, 1, selectids.view(-1, 1))
gmm_loss = torch.mean(torch.add(-1*torch.log(gmm_pi_min+1e-30), dist_min))
gmm_loss_l2 = torch.mean(dist_min)
return gmm_loss, gmm_loss_l2
def maskedCE(logits, target, length):
"""
Args:
logits: A Variable containing a FloatTensor of size
(batch, max_len, num_classes) which contains the
unnormalized probability for each class.
target: A Variable containing a LongTensor of size
(batch, max_len) which contains the index of the true
class for each corresponding step.
length: A Variable containing a LongTensor of size (batch,)
which contains the length of each data in a batch.
Returns:
loss: An average loss value masked by the length.
"""
# logits_flat: (batch * max_len, num_classes)
logits_flat = logits.view(-1, logits.size(-1))
# log_probs_flat: (batch * max_len, num_classes)
log_probs_flat = F.log_softmax(logits_flat)
# target_flat: (batch * max_len, 1)
target_flat = target.view(-1, 1)
# losses_flat: (batch * max_len, 1)
losses_flat = -t.gather(log_probs_flat, dim=1, index=target_flat)
# losses: (batch, max_len)
losses = losses_flat.view(*target.size())
# mask: (batch, max_len)
mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
losses = losses * mask.float()
loss = losses.sum() / length.float().sum()
return loss
def masked_cross_entropy(logits, target, length):
"""
Args:
logits: A Variable containing a FloatTensor of size
(batch, max_len, num_classes) which contains the
unnormalized probability for each class.
target: A Variable containing a LongTensor of size
(batch, max_len) which contains the index of the true
class for each corresponding step.
length: A Variable containing a LongTensor of size (batch,)
which contains the length of each data in a batch.
Returns:
loss: An average loss value masked by the length.
"""
length = Variable(torch.LongTensor(length)).cuda()
# logits_flat: (batch * max_len, num_classes)
logits_flat = logits.view(-1, logits.size(-1))
# log_probs_flat: (batch * max_len, num_classes)
log_probs_flat = functional.log_softmax(logits_flat)
# target_flat: (batch * max_len, 1)
target_flat = target.view(-1, 1)
# losses_flat: (batch * max_len, 1)
losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
# losses: (batch, max_len)
losses = losses_flat.view(*target.size())
# mask: (batch, max_len)
mask = sequence_mask(sequence_length=length, max_len=target.size(1))
losses = losses * mask.float()
loss = losses.sum() / length.float().sum()
return loss
def forward(self, lstm_out, lengths):
"""
Args:
lstm_out: A Variable containing a 3D tensor of dimension
(seq_len, batch_size, hidden_x_dirs)
lengths: A Variable containing 1D LongTensor of dimension
(batch_size)
Return:
A Variable containing a 2D tensor of the same type as lstm_out of
dim (batch_size, hidden_x_dirs) corresponding to the concatenated
last hidden states of the forward and backward parts of the input.
"""
seq_len = lstm_out.size(0)
batch_size = lstm_out.size(1)
hidden_x_dirs = lstm_out.size(2)
single_dir_hidden = hidden_x_dirs // 2  # integer division so repeat() gets an int
lengths_fw = lengths
lengths_bw = seq_len - lengths_fw
rep_lengths_fw = lengths_fw.view(1, batch_size, 1)
rep_lengths_fw = rep_lengths_fw.repeat(1, 1, single_dir_hidden)
rep_lengths_bw = lengths_bw.view(1, batch_size, 1)
rep_lengths_bw = rep_lengths_bw.repeat(1, 1, single_dir_hidden)
# we want 2 chunks in the last dimension
out_fw, out_bw = torch.chunk(lstm_out, 2, 2)
h_t_fw = torch.gather(out_fw, 0, rep_lengths_fw-1)
h_t_bw = torch.gather(out_bw, 0, rep_lengths_bw)
# -> (batch_size, hidden_x_dirs)
last_hidden_out = torch.cat([h_t_fw, h_t_bw], 2).squeeze()
return last_hidden_out
def sort_by_embeddings(self, Phis, Inputs_N, e):
ind = torch.sort(e, 1)[1].squeeze()
for i, phis in enumerate(Phis):
# rearange phis
phis_out = (torch.gather(Phis[i], 1, ind.unsqueeze(2)
.expand_as(phis)))
Phis[i] = (torch.gather(phis_out, 2, ind.unsqueeze(1)
.expand_as(phis)))
# rearange inputs
Inputs_N[i] = torch.gather(Inputs_N[i], 1,
ind.unsqueeze(2).expand_as(Inputs_N[i]))
return Phis, Inputs_N
def combine_matrices(self, prob_matrix, prob_matrix_scale, perm):
# argmax
new_perm = self.discretize(prob_matrix_scale)
perm = torch.gather(perm, 1, new_perm)
prob_matrix = torch.bmm(prob_matrix_scale, prob_matrix)
return prob_matrix, perm
def outputs(self, input, prob_matrix, perm):
hard_output = (torch.gather(input, 1, perm.unsqueeze(2)
.expand_as(input)))
# soft argmax
soft_output = torch.bmm(prob_matrix, input)
return hard_output, soft_output
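Expanding a per-example row permutation over the feature dimension and gathering along dim=1, as in hard_output above, reorders the rows of each batch element. A toy sketch:
import torch
input = torch.tensor([[[1., 10.],
                       [2., 20.],
                       [3., 30.]]])  # (batch=1, n=3, features=2)
perm = torch.tensor([[2, 0, 1]])     # desired row order per batch element
hard = torch.gather(input, 1, perm.unsqueeze(2).expand_as(input))
# hard[0] == tensor([[3., 30.], [1., 10.], [2., 20.]])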
def combine_matrices(self, prob_matrix, prob_matrix_scale,
perm, last=False):
# prob_matrix shape is bs x length x length + 1. Add extra column.
length = prob_matrix_scale.size()[2]
first = Variable(torch.zeros([self.batch_size, 1, length])).type(dtype)
first[:, 0, 0] = 1.0
prob_matrix_scale = torch.cat((first, prob_matrix_scale), 1)
# argmax
new_perm = self.discretize(prob_matrix_scale)
perm = torch.gather(perm, 1, new_perm)
# combine
prob_matrix = torch.bmm(prob_matrix_scale, prob_matrix)
return prob_matrix, prob_matrix_scale, perm
def deploy(x, labels):
pred = m(x)
loss = crit(pred, labels)
values, bests = pred.topk(pred.size(1), dim=1)
_, ranking = bests.topk(bests.size(1), dim=1, largest=False) # [batch_size, dict_size]
rank = torch.gather(ranking.data, 1, labels.data[:, None]).cpu().numpy().squeeze()
top5_preds = bests[:, :5].cpu().data.numpy()
top1_acc = np.mean(rank==0)
top5_acc = np.mean(rank<5)
return loss.data[0], top1_acc, top5_acc
def devise_train(m, x, labels, data, att_crit=None, optimizers=None):
"""
Train the direct attribute prediction model
:param m: Model we're using
:param x: [batch_size, 3, 224, 224] Image input
:param labels: [batch_size] variable with indices of the right verbs
:param embeds: [vocab_size, 300] Variables with embeddings of all of the verbs
:param atts_matrix: [vocab_size, att_dim] matrix with GT attributes of the verbs
:param att_crit: AttributeLoss module that computes the loss
:param optimizers: the decorator will use these to update parameters
:return:
"""
# Make embed unit normed
embed_normed = _normalize(data.attributes.embeds)
mv_image = m(x).embed_pred
tmv_image = mv_image @ embed_normed.t()
# Use a random label from the same batch
correct_contrib = torch.gather(tmv_image, 1, labels[:,None])
# Should be fine to ignore where the correct contrib intersects because the gradient
# wrt input is 0
losses = (0.1 + tmv_image - correct_contrib.expand_as(tmv_image)).clamp(min=0.0)
# losses.scatter_(1, labels[:, None], 0.0)
loss = m.l2_penalty + losses.sum(1).squeeze().mean()
return loss
def transition_score(self, labels, lens):
"""
Arguments:
labels: [batch_size, seq_len] LongTensor
lens: [batch_size] LongTensor
"""
batch_size, seq_len = labels.size()
# pad labels with <start> and <stop> indices
labels_ext = Variable(labels.data.new(batch_size, seq_len + 2))
labels_ext[:, 0] = self.start_idx
labels_ext[:, 1:-1] = labels
mask = sequence_mask(lens + 1, max_len=seq_len + 2).long()
pad_stop = Variable(labels.data.new(1).fill_(self.stop_idx))
pad_stop = pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2)
labels_ext = (1 - mask) * pad_stop + mask * labels_ext
labels = labels_ext
trn = self.transitions
# obtain transition vector for each label in batch and timestep
# (except the last ones)
trn_exp = trn.unsqueeze(0).expand(batch_size, *trn.size())
lbl_r = labels[:, 1:]
lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), trn.size(0))
trn_row = torch.gather(trn_exp, 1, lbl_rexp)
# obtain transition score from the transition vector for each label
# in batch and timestep (except the first ones)
lbl_lexp = labels[:, :-1].unsqueeze(-1)
trn_scr = torch.gather(trn_row, 2, lbl_lexp)
trn_scr = trn_scr.squeeze(-1)
mask = sequence_mask(lens + 1).float()
trn_scr = trn_scr * mask
score = trn_scr.sum(1).squeeze(-1)
return score
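The two chained gathers above perform a batched lookup of transition scores: the first picks the row trn[labels[:, t+1]] at every timestep, the second picks column labels[:, t] within that row, so each step contributes trn[next_label, current_label]. A small standalone version of the lookup, with hypothetical sizes:
import torch
n_labels = 3
trn = torch.randn(n_labels, n_labels)  # trn[to, from] transition scores
labels = torch.tensor([[0, 2, 1]])     # (batch=1, seq_len=3)
trn_exp = trn.unsqueeze(0).expand(1, n_labels, n_labels)
rows = torch.gather(trn_exp, 1,
                    labels[:, 1:].unsqueeze(-1).expand(1, 2, n_labels))
score = torch.gather(rows, 2, labels[:, :-1].unsqueeze(-1)).squeeze(-1)
# score[0, t] == trn[labels[0, t + 1], labels[0, t]]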
def _bilstm_score(self, logits, y, lens):
y_exp = y.unsqueeze(-1)
scores = torch.gather(logits, 2, y_exp).squeeze(-1)
mask = sequence_mask(lens).float()
scores = scores * mask
score = scores.sum(1).squeeze(-1)
return score
def updateGradInput(self, input, gradOutput):
assert input.dim() == 2
assert gradOutput.dim() == 2
input_size = input.size()
n = input.size(0) # batch size
d = input.size(1) # dimensionality of vectors
self._gradInput = self._gradInput or input.new()
self.cross = self.cross or input.new()
# compute diagonal term with gradOutput
self._gradInput.resize_(n, d)
if self.p == float('inf'):
# specialization for the inf case
torch.mul(self._gradInput, self.norm.view(n, 1, 1).expand(n, d, 1), gradOutput)
self.buffer.resize_as_(input).zero_()
self.cross.resize_(n, 1)
torch.gather(self.cross, input, 1, self._indices)
self.cross.div_(self.norm)
self.buffer.scatter_(1, self._indices, self.cross)
else:
torch.mul(self._gradInput, self.normp.view(n, 1).expand(n, d), gradOutput)
# small optimizations for different p
# buffer = input*|input|^(p-2)
# for non-even p, need to add absolute value
if self.p % 2 != 0:
if self.p < 2:
# add eps to avoid possible division by 0
torch.abs(self.buffer, input).add_(self.eps).pow_(self.p-2).mul_(input)
else:
torch.abs(self.buffer, input).pow_(self.p-2).mul_(input)
# special case for p == 2, pow(x, 0) = 1
elif self.p == 2:
self.buffer.copy_(input)
else:
# p is even and > 2, pow(x, p) is always positive
torch.pow(self.buffer, input, self.p-2).mul_(input)
# compute cross term in two steps
self.cross.resize_(n, 1)
# instead of having a huge temporary matrix (b1*b2),
# do the computations as b1*(b2*gradOutput). This avoids redundant
# computation and also a huge buffer of size n*d^2
self.buffer2 = self.buffer2 or input.new() # nxd
torch.mul(self.buffer2, input, gradOutput)
torch.sum(self.cross, self.buffer2, 1)
self.buffer.mul_(self.cross.expand_as(self.buffer))
self._gradInput.add_(-1, self.buffer)
# reuse cross buffer for normalization
if self.p == float('inf'):
torch.mul(self.cross, self.norm, self.norm)
else:
torch.mul(self.cross, self.normp, self.norm)
self._gradInput.div_(self.cross.expand(n, d))
self.gradInput = self._gradInput.view(input_size)
return self.gradInput
def test_topk(self):
def topKViaSort(t, k, dim, dir):
sorted, indices = t.sort(dim, dir)
return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k)
def compareTensors(t, res1, ind1, res2, ind2, dim):
# Values should be exactly equivalent
self.assertEqual(res1, res2, 0)
# Indices might differ based on the implementation, since there is
# no guarantee of the relative order of selection
if not ind1.eq(ind2).all():
# To verify that the indices represent equivalent elements,
# gather from the input using the topk indices and compare against
# the sort indices
vals = t.gather(dim, ind2)
self.assertEqual(res1, vals, 0)
def compare(t, k, dim, dir):
topKVal, topKInd = t.topk(k, dim, dir, True)
sortKVal, sortKInd = topKViaSort(t, k, dim, dir)
compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim)
t = torch.rand(random.randint(1, SIZE),
random.randint(1, SIZE),
random.randint(1, SIZE))
for kTries in range(3):
for dimTries in range(3):
for transpose in (True, False):
for dir in (True, False):
testTensor = t
if transpose:
dim1 = random.randrange(t.ndimension())
dim2 = dim1
while dim1 == dim2:
dim2 = random.randrange(t.ndimension())
testTensor = t.transpose(dim1, dim2)
dim = random.randrange(testTensor.ndimension())
k = random.randint(1, testTensor.size(dim))
compare(testTensor, k, dim, dir)
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
targets: torch.LongTensor,
weights: torch.FloatTensor,
batch_average: bool = True) -> torch.FloatTensor:
"""
Computes the cross entropy loss of a sequence, weighted with respect to
some user provided weights. Note that the weighting here is not the same as
in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
classes; here we are weighting the loss contribution from particular elements
in the sequence. This allows loss computations for models which use padding.
Parameters
----------
logits : ``torch.FloatTensor``, required.
A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
which contains the unnormalized probability for each class.
targets : ``torch.LongTensor``, required.
A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
index of the true class for each corresponding step.
weights : ``torch.FloatTensor``, required.
A ``torch.FloatTensor`` of size (batch, sequence_length)
batch_average : bool, optional, (default = True).
A bool indicating whether the loss should be averaged across the batch,
or returned as a vector of losses per batch element.
Returns
-------
A torch.FloatTensor representing the cross entropy loss.
If ``batch_average == True``, the returned loss is a scalar.
If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).
"""
# shape : (batch * sequence_length, num_classes)
logits_flat = logits.view(-1, logits.size(-1))
# shape : (batch * sequence_length, num_classes)
log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
# shape : (batch * max_len, 1)
targets_flat = targets.view(-1, 1).long()
# Contribution to the negative log likelihood only comes from the exact indices
# of the targets, as the target distributions are one-hot. Here we use torch.gather
# to extract the indices of the num_classes dimension which contribute to the loss.
# shape : (batch * sequence_length, 1)
negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
# shape : (batch, sequence_length)
negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
# shape : (batch, sequence_length)
negative_log_likelihood = negative_log_likelihood * weights.float()
# shape : (batch_size,)
per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)
if batch_average:
num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
return per_batch_loss.sum() / num_non_empty_sequences
return per_batch_loss
def val_sents(self, data, dec_logits):
vocab, previews = self.model.vocab, self.previews
x, x_lens, ys_i, ys_t, ys_lens, xys_idx = data
if self.batch_first:
cdata = [ys_i, ys_t, ys_lens, xys_idx, dec_logits]
cdata = [d.transpose(1, 0).contiguous() for d in cdata]
ys_i, ys_t, ys_lens, xys_idx, dec_logits = cdata
_, xys_ridx = torch.sort(xys_idx, 1)
xys_ridx_exp = xys_ridx.unsqueeze(-1).expand_as(ys_i)
ys_i = torch.gather(ys_i, 1, xys_ridx_exp)
ys_t = torch.gather(ys_t, 1, xys_ridx_exp)
dec_logits = [torch.index_select(logits, 0, xy_ridx)
for logits, xy_ridx in zip(dec_logits, xys_ridx)]
ys_lens = torch.gather(ys_lens, 1, xys_ridx)
x, x_lens = x[:previews], x_lens[:previews]
ys_i, ys_t = ys_i[:, :previews], ys_t[:, :previews]
dec_logits = torch.cat(
[logits[:previews].max(2)[1].squeeze(-1).unsqueeze(0)
for logits in dec_logits], 0)
ys_lens = ys_lens[:, :previews]
ys_i, ys_t = ys_i.transpose(1, 0), ys_t.transpose(1, 0)
dec_logits, ys_lens = dec_logits.transpose(1, 0), ys_lens.transpose(1, 0)
x, x_lens = x.data.tolist(), x_lens.data.tolist()
ys_i, ys_t = ys_i.data.tolist(), ys_t.data.tolist()
dec_logits, ys_lens = dec_logits.data.tolist(), ys_lens.data.tolist()
def to_sent(data, length, vocab):
return " ".join(vocab.i2f[data[i]] for i in range(length))
def to_sents(data, lens, vocab):
return [to_sent(d, l, vocab) for d, l in zip(data, lens)]
x_sents = to_sents(x, x_lens, vocab)
yi_sents = [to_sents(yi, y_lens, vocab) for yi, y_lens in
zip(ys_i, ys_lens)]
yt_sents = [to_sents(yt, y_lens, vocab) for yt, y_lens in
zip(ys_t, ys_lens)]
o_sents = [to_sents(dec_logit, y_lens, vocab)
for dec_logit, y_lens in zip(dec_logits, ys_lens)]
return x_sents, yi_sents, yt_sents, o_sents
def test_topk(self):
def topKViaSort(t, k, dim, dir):
sorted, indices = t.sort(dim, dir)
return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k)
def compareTensors(t, res1, ind1, res2, ind2, dim):
# Values should be exactly equivalent
self.assertEqual(res1, res2, 0)
# Indices might differ based on the implementation, since there is
# no guarantee of the relative order of selection
if not ind1.eq(ind2).all():
# To verify that the indices represent equivalent elements,
# gather from the input using the topk indices and compare against
# the sort indices
vals = t.gather(dim, ind2)
self.assertEqual(res1, vals, 0)
def compare(t, k, dim, dir):
topKVal, topKInd = t.topk(k, dim, dir, True)
sortKVal, sortKInd = topKViaSort(t, k, dim, dir)
compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim)
t = torch.rand(random.randint(1, SIZE),
random.randint(1, SIZE),
random.randint(1, SIZE))
for _kTries in range(3):
for _dimTries in range(3):
for transpose in (True, False):
for dir in (True, False):
testTensor = t
if transpose:
dim1 = random.randrange(t.ndimension())
dim2 = dim1
while dim1 == dim2:
dim2 = random.randrange(t.ndimension())
testTensor = t.transpose(dim1, dim2)
dim = random.randrange(testTensor.ndimension())
k = random.randint(1, testTensor.size(dim))
compare(testTensor, k, dim, dir)
def evaluate(data_source, batch_size=10, window=args.window):
# Turn on evaluation mode which disables dropout.
if args.model == 'QRNN': model.reset()
model.eval()
total_loss = 0
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(batch_size)
next_word_history = None
pointer_history = None
for i in range(0, data_source.size(0) - 1, args.bptt):
if i > 0: print(i, len(data_source), math.exp(total_loss / i))
data, targets = get_batch(data_source, i, evaluation=True, args=args)
output, hidden, rnn_outs, _ = model(data, hidden, return_h=True)
rnn_out = rnn_outs[-1].squeeze()
output_flat = output.view(-1, ntokens)
###
# Fill pointer history
start_idx = len(next_word_history) if next_word_history is not None else 0
next_word_history = torch.cat([one_hot(t.data[0], ntokens) for t in targets]) if next_word_history is None else torch.cat([next_word_history, torch.cat([one_hot(t.data[0], ntokens) for t in targets])])
#print(next_word_history)
pointer_history = Variable(rnn_out.data) if pointer_history is None else torch.cat([pointer_history, Variable(rnn_out.data)], dim=0)
#print(pointer_history)
###
# Built-in cross entropy
# total_loss += len(data) * criterion(output_flat, targets).data[0]
###
# Manual cross entropy
# softmax_output_flat = torch.nn.functional.softmax(output_flat)
# soft = torch.gather(softmax_output_flat, dim=1, index=targets.view(-1, 1))
# entropy = -torch.log(soft)
# total_loss += len(data) * entropy.mean().data[0]
###
# Pointer manual cross entropy
loss = 0
softmax_output_flat = torch.nn.functional.softmax(output_flat)
for idx, vocab_loss in enumerate(softmax_output_flat):
p = vocab_loss
if start_idx + idx > window:
valid_next_word = next_word_history[start_idx + idx - window:start_idx + idx]
valid_pointer_history = pointer_history[start_idx + idx - window:start_idx + idx]
logits = torch.mv(valid_pointer_history, rnn_out[idx])
theta = args.theta
ptr_attn = torch.nn.functional.softmax(theta * logits).view(-1, 1)
ptr_dist = (ptr_attn.expand_as(valid_next_word) * valid_next_word).sum(0).squeeze()
lambdah = args.lambdasm
p = lambdah * ptr_dist + (1 - lambdah) * vocab_loss
###
target_loss = p[targets[idx].data]
loss += (-torch.log(target_loss)).data[0]
total_loss += loss / batch_size
###
hidden = repackage_hidden(hidden)
next_word_history = next_word_history[-window:]
pointer_history = pointer_history[-window:]
return total_loss / len(data_source)
def Decoder(self, input, hidden_encoder, phis,
input_target=None, target=None):
feed_target = False
if target is not None:
feed_target = True
# N_n is the number of elements of the scope of the n-th element
N = phis.sum(2).squeeze().unsqueeze(2).expand(self.batch_size, self.n,
self.hidden_size)
output = (Variable(torch.ones(self.batch_size, self.n, self.n))
.type(dtype))
index = ((N[:, 0] - 1) % (self.n)).type(dtype_l).unsqueeze(1)
hidden = (torch.gather(hidden_encoder, 1, index)).squeeze()
# W1xe size: (batch_size, n + 1, hidden_size)
W1xe = (torch.bmm(hidden_encoder, self.W1.unsqueeze(0).expand(
self.batch_size, self.hidden_size, self.hidden_size)))
# init token
start = (self.init_token.unsqueeze(0).expand(self.batch_size,
self.input_size))
input_step = start
for n in xrange(self.n):
# decouple interaction between different scopes by looking at
# subdiagonal elements of Phi
if n > 0:
t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
self.batch_size, self.hidden_size))
index = (((N[:, n] + n - 1) % (self.n)).type(dtype_l)
.unsqueeze(1))
init_hidden = (torch.gather(hidden_encoder, 1, index)
.squeeze())
hidden = t * hidden + (1 - t) * init_hidden
t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
self.batch_size, self.input_size))
input_step = t * input_step + (1 - t) * start
# Compute next state
hidden = self.decoder_cell(input_step, hidden)
# Compute pairwise interactions
u = self.attention(hidden, W1xe, hidden_encoder, tanh=True)
# Normalize interactions by taking the masked softmax by phi
attn = self.softmax_m(phis[:, n].squeeze(), u)
if feed_target:
# feed next step with target
next = (target[:, n].unsqueeze(1).unsqueeze(2)
.expand(self.batch_size, 1, self.input_size)
.type(dtype_l))
input_step = torch.gather(input_target, 1, next).squeeze()
else:
# blend inputs
input_step = (torch.sum(attn.unsqueeze(2).expand(
self.batch_size, self. n,
self.input_size) * input, 1)).squeeze()
# Update output
output[:, n] = attn
return output
def Decoder(self, input, hidden_encoder, phis,
input_target=None, target=None):
feed_target = False
if target is not None:
feed_target = True
# N[:, n] is the number of elements of the scope of the n-th element
N = phis.sum(2).squeeze().unsqueeze(2).expand(self.batch_size, self.n,
self.hidden_size)
output = (Variable(torch.ones(self.batch_size, self.n, self.n + 1))
.type(dtype))
index = ((N[:, 0] - 1) % (self.n)).type(dtype_l).unsqueeze(1).detach()
hidden = (torch.gather(hidden_encoder, 1, index + 1)).squeeze()
# W1xe size: (batch_size, n + 1, hidden_size)
W1xe = (torch.bmm(hidden_encoder, self.W1.unsqueeze(0).expand(
self.batch_size, self.hidden_size, self.hidden_size)))
# init token
start = (self.init_token.unsqueeze(0).expand(self.batch_size,
self.input_size))
input_step = start
for n in xrange(self.n):
# decouple interaction between different scopes by looking at
# subdiagonal elements of Phi
if n > 0:
t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
self.batch_size, self.hidden_size))
index = (((N[:, n] + n - 1) % (self.n)).type(dtype_l)
.unsqueeze(1)).detach()
init_hidden = (torch.gather(hidden_encoder, 1, index + 1)
.squeeze())
hidden = t * hidden + (1 - t) * init_hidden
t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
self.batch_size, self.input_size))
input_step = t * input_step + (1 - t) * start
# Compute next state
hidden = self.decoder_cell(input_step, hidden)
# Compute pairwise interactions
u = self.attention(hidden, W1xe, hidden_encoder)
# Normalize interactions by taking the masked softmax by phi
pad = Variable(torch.ones(self.batch_size, 1)).type(dtype)
mask = torch.cat((pad, phis[:, n].squeeze()), 1)
attn = self.softmax_m(mask, u)
if feed_target:
# feed next step with target
next = (target[:, n].unsqueeze(1).unsqueeze(2)
.expand(self.batch_size, 1, self.input_size)
.type(dtype_l))
input_step = torch.gather(input_target, 1, next).squeeze()
else:
# do not blend: feed the argmax input instead
index = attn.max(1)[1].squeeze()
next = (index.unsqueeze(1).unsqueeze(2)
.expand(self.batch_size, 1, self.input_size)
.type(dtype_l))
input_step = torch.gather(input, 1, next).squeeze()
# blend inputs
# input_step = (torch.sum(attn.unsqueeze(2).expand(
# self.batch_size, self. n + 1,
# self.input_size) * input, 1)).squeeze()
# Update output
output[:, n] = attn
return output
def viterbi_decode(self, logits, lens):
"""Borrowed from pytorch tutorial
Arguments:
logits: [batch_size, seq_len, n_labels] FloatTensor
lens: [batch_size] LongTensor
"""
batch_size, seq_len, n_labels = logits.size()
vit = logits.data.new(batch_size, self.n_labels).fill_(-10000)
vit[:, self.start_idx] = 0
vit = Variable(vit)
c_lens = lens.clone()
logits_t = logits.transpose(1, 0)
pointers = []
for logit in logits_t:
vit_exp = vit.unsqueeze(1).expand(batch_size, n_labels, n_labels)
trn_exp = self.transitions.unsqueeze(0).expand_as(vit_exp)
vit_trn_sum = vit_exp + trn_exp
vt_max, vt_argmax = vit_trn_sum.max(2)
vt_max = vt_max.squeeze(-1)
vit_nxt = vt_max + logit
pointers.append(vt_argmax.squeeze(-1).unsqueeze(0))
mask = (c_lens > 0).float().unsqueeze(-1).expand_as(vit_nxt)
vit = mask * vit_nxt + (1 - mask) * vit
mask = (c_lens == 1).float().unsqueeze(-1).expand_as(vit_nxt)
vit += mask * self.transitions[self.stop_idx].unsqueeze(0).expand_as(vit_nxt)
c_lens = c_lens - 1
pointers = torch.cat(pointers)
scores, idx = vit.max(1)
idx = idx.squeeze(-1)
paths = [idx.unsqueeze(1)]
for argmax in reversed(pointers):
idx_exp = idx.unsqueeze(-1)
idx = torch.gather(argmax, 1, idx_exp)
idx = idx.squeeze(-1)
paths.insert(0, idx.unsqueeze(1))
paths = torch.cat(paths[1:], 1)
scores = scores.squeeze(-1)
return scores, paths
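Each iteration of the backtrace loop above gathers, per example, the backpointer stored at the current best label. In isolation the step looks like this, with hypothetical shapes:
import torch
batch, n_labels = 2, 4
backpointers = torch.randint(0, n_labels, (batch, n_labels))  # one decoding step
idx = torch.tensor([3, 1])                                    # current best label per example
prev = torch.gather(backpointers, 1, idx.unsqueeze(-1)).squeeze(-1)
# prev[b] == backpointers[b, idx[b]]: the label to step back to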