Python torch module: split() example source code
We extracted the following 37 code examples from open-source Python projects to illustrate how to use torch.split().
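Before the project snippets, here is a minimal, self-contained sketch of the two ways torch.split can be called: with an integer chunk size, or with a list of section sizes (the tensor shapes below are illustrative only). Note that the keyword name differs across releases: older versions of PyTorch accepted split_size= (as some snippets below do), while current versions name the argument split_size_or_sections.

import torch

x = torch.arange(12).reshape(3, 4)

# integer chunk size: two tensors of shape (3, 2) along dim=1
a, b = torch.split(x, 2, dim=1)

# explicit section sizes: shapes (3, 1) and (3, 3) along dim=1
head, tail = torch.split(x, [1, 3], dim=1)

print(a.shape, b.shape)        # torch.Size([3, 2]) torch.Size([3, 2])
print(head.shape, tail.shape)  # torch.Size([3, 1]) torch.Size([3, 3])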
def forward(self, x):
x_shape = x.size() # (b, c, h, w)
offset = self.offset_filter(x) # (b, 2*c, h, w)
offset_w, offset_h = torch.split(offset, self.regular_filter.in_channels, 1) # (b, c, h, w)
offset_w = offset_w.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])) # (b*c, h, w)
offset_h = offset_h.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])) # (b*c, h, w)
if not self.input_shape or self.input_shape != x_shape:
self.input_shape = x_shape
grid_w, grid_h = np.meshgrid(np.linspace(-1, 1, x_shape[3]), np.linspace(-1, 1, x_shape[2])) # (h, w)
grid_w = torch.Tensor(grid_w)
grid_h = torch.Tensor(grid_h)
if self.cuda:
grid_w = grid_w.cuda()
grid_h = grid_h.cuda()
self.grid_w = nn.Parameter(grid_w)
self.grid_h = nn.Parameter(grid_h)
offset_w = offset_w + self.grid_w # (b*c, h, w)
offset_h = offset_h + self.grid_h # (b*c, h, w)
x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])).unsqueeze(1) # (b*c, 1, h, w)
x = F.grid_sample(x, torch.stack((offset_h, offset_w), 3)) # (b*c, 1, h, w)
x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3])) # (b, c, h, w)
x = self.regular_filter(x)
return x
def node_forward(self, inputs, child_c, child_h):
child_h_sum = torch.sum(child_h, dim=0, keepdim=True)
iou = self.ioux(inputs) + self.iouh(child_h_sum)
i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
i, o, u = F.sigmoid(i), F.sigmoid(o), F.tanh(u)
f = F.sigmoid(
self.fh(child_h) +
self.fx(inputs).repeat(len(child_h), 1)
)
fc = torch.mul(f, child_c)
c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
h = torch.mul(o, F.tanh(c))
return c, h
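The node_forward example above relies on a common pattern: one fused linear projection is split into the i, o, u gate activations. A standalone sketch of that pattern, with made-up layer sizes (the names fused, hidden, and x are illustrative, not from the project above):

import torch
import torch.nn as nn

hidden = 5
fused = nn.Linear(10, 3 * hidden)   # one matmul produces all three gates at once
x = torch.randn(4, 10)              # batch of 4 inputs

iou = fused(x)                               # (4, 3 * hidden)
i, o, u = torch.split(iou, hidden, dim=1)    # three tensors of shape (4, hidden)
i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)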
def get_distance_losses(self, A, AB, A_to_AB=True):
As = torch.split(A, 1)
ABs = torch.split(AB, 1)
loss_distance_A = 0.0
num_pairs = 0
min_length = len(As)
for i in xrange(min_length - 1):
for j in xrange(i + 1, min_length):
num_pairs += 1
loss_distance_A_ij = \
self.get_individual_distance_loss(As[i], As[j],
ABs[i], ABs[j], A_to_AB)
loss_distance_A += loss_distance_A_ij
loss_distance_A = loss_distance_A / num_pairs
return loss_distance_A
def forward(self, inputs, batch_size, hidden_cell=None):
if hidden_cell is None:
# then must init with zeros
if use_cuda:
hidden = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size).cuda())
cell = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size).cuda())
else:
hidden = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size))
cell = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size))
hidden_cell = (hidden, cell)
_, (hidden,cell) = self.lstm(inputs.float(), hidden_cell)
# hidden is (2, batch_size, hidden_size), we want (batch_size, 2*hidden_size):
hidden_forward, hidden_backward = torch.split(hidden,1,0)
hidden_cat = torch.cat([hidden_forward.squeeze(0), hidden_backward.squeeze(0)],1)
# mu and sigma:
mu = self.fc_mu(hidden_cat)
sigma_hat = self.fc_sigma(hidden_cat)
sigma = torch.exp(sigma_hat/2.)
# N ~ N(0,1)
z_size = mu.size()
if use_cuda:
N = Variable(torch.normal(torch.zeros(z_size),torch.ones(z_size)).cuda())
else:
N = Variable(torch.normal(torch.zeros(z_size),torch.ones(z_size)))
z = mu + sigma*N
# mu and sigma_hat are needed for LKL loss
return z, mu, sigma_hat
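The encoder above uses torch.split along dim 0 to separate the forward and backward states of a bidirectional LSTM, then concatenates them into a single (batch, 2*hidden) vector. The same reshuffle in isolation, with illustrative sizes:

import torch

batch, hidden = 8, 16
h = torch.randn(2, batch, hidden)        # (num_directions, batch, hidden)

h_fwd, h_bwd = torch.split(h, 1, dim=0)  # two (1, batch, hidden) tensors
h_cat = torch.cat([h_fwd.squeeze(0), h_bwd.squeeze(0)], dim=1)
assert h_cat.shape == (batch, 2 * hidden)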
def backward(self, outputs, targets, weights, normalizer, criterion, regression=False):
outputs_split = torch.split(outputs, self.batch_size, self.dim)
targets_split = torch.split(targets, self.batch_size, self.dim)
weights_split = torch.split(weights, self.batch_size, self.dim)
grad_output = []
loss = 0
for out_t, targ_t, w_t in zip(outputs_split, targets_split, weights_split):
grad_output_t, loss_t = super(MemEfficientGenerator, self).backward(
out_t, targ_t, w_t, normalizer, criterion, regression)
grad_output.append(grad_output_t)
loss += loss_t
grad_output = torch.cat(grad_output, self.dim)
return grad_output, loss
def forward(self, inp):
#if inp.dim() > 2:
# inp = inp.permute(0, 2, 1)
#inp = inp.contiguous().view(-1, self.L)
if not (type(inp) == Variable):
inp = Variable(inp[0])
if hasattr(self.arguments, 'pack_num'):
N = inp.size(0)
Ncut = int(N/self.arguments.pack_num)
split = torch.split(inp, Ncut, dim=0)
inp = torch.cat(split, dim=1)
h1 = F.tanh((self.l1(inp)))
#h2 = F.tanh(self.l2_bn(self.l2(h1)))
if self.arguments.tr_method == 'adversarial_wasserstein':
output = (self.l3(h1))
else:
output = F.sigmoid(self.l3(h1))
return output, h1
def __init__(self, root, single_spkr=False):
self.root = root
self.npzs = self.make_dataset(self.root)
if len(self.npzs) == 0:
raise(RuntimeError("Found 0 npz in subfolders of: " + root + "\n"
"Supported image extensions are: " +
self.NPZ_EXTENSION))
if single_spkr:
self.speakers = defaultdict(lambda: 0)
else:
self.speakers = []
for fname in self.npzs:
self.speakers += [os.path.basename(fname).split('_')[0]]
self.speakers = list(set(self.speakers))
self.speakers.sort()
self.speakers = {v: i for i, v in enumerate(self.speakers)}
code2phone = np.load(self.npzs[0])['code2phone']
self.dict = {v: k for k, v in enumerate(code2phone)}
def __init__(self, src, trgt, spkr, seq_len):
self.seq_len = seq_len
self.start = True
self.speakers = spkr
self.srcBatch = src[0]
self.srcLenths = src[1]
# split batch
self.tgtBatch = list(torch.split(trgt[0], self.seq_len, 0))
self.tgtBatch.reverse()
self.len = len(self.tgtBatch)
# split length list
batch_seq_len = len(self.tgtBatch)
self.tgtLenths = [self.split_length(l, batch_seq_len) for l in trgt[1]]
self.tgtLenths = torch.stack(self.tgtLenths)
self.tgtLenths = list(torch.split(self.tgtLenths, 1, 1))
self.tgtLenths = [x.squeeze() for x in self.tgtLenths]
self.tgtLenths.reverse()
assert len(self.tgtLenths) == len(self.tgtBatch)
def get_distance_losses(self):
As = torch.split(self.real_A, 1)
Bs = torch.split(self.real_B, 1)
ABs = torch.split(self.fake_B, 1)
BAs = torch.split(self.fake_A, 1)
loss_distance_A = 0.0
loss_distance_B = 0.0
num_pairs = 0
min_length = min(len(As), len(Bs))
for i in xrange(min_length - 1):
for j in xrange(i + 1, min_length):
num_pairs += 1
loss_distance_A_ij, loss_distance_B_ij = \
self.get_individual_distance_loss(As[i], As[j],
ABs[i], ABs[j],
Bs[i], Bs[j],
BAs[i], BAs[j])
loss_distance_A += loss_distance_A_ij
loss_distance_B += loss_distance_B_ij
loss_distance_A = loss_distance_A / num_pairs
loss_distance_B = loss_distance_B / num_pairs
return loss_distance_A, loss_distance_B
def forward(self, inputs, z, hidden_cell=None):
if hidden_cell is None:
# then we must init from z
hidden,cell = torch.split(F.tanh(self.fc_hc(z)),hp.dec_hidden_size,1)
hidden_cell = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())
outputs,(hidden,cell) = self.lstm(inputs, hidden_cell)
# in training we feed the lstm with the whole input in one shot
# and use all outputs contained in 'outputs', while in generate
# mode we just feed with the last generated sample:
if self.training:
y = self.fc_params(outputs.view(-1, hp.dec_hidden_size))
else:
y = self.fc_params(hidden.view(-1, hp.dec_hidden_size))
# separate pen and mixture params:
params = torch.split(y,6,1)
params_mixture = torch.stack(params[:-1]) # trajectory
params_pen = params[-1] # pen up/down
# identify mixture params:
pi,mu_x,mu_y,sigma_x,sigma_y,rho_xy = torch.split(params_mixture,1,2)
# preprocess params:
if self.training:
len_out = Nmax+1
else:
len_out = 1
pi = F.softmax(pi.t().squeeze()).view(len_out,-1,hp.M)
sigma_x = torch.exp(sigma_x.t().squeeze()).view(len_out,-1,hp.M)
sigma_y = torch.exp(sigma_y.t().squeeze()).view(len_out,-1,hp.M)
rho_xy = torch.tanh(rho_xy.t().squeeze()).view(len_out,-1,hp.M)
mu_x = mu_x.t().squeeze().contiguous().view(len_out,-1,hp.M)
mu_y = mu_y.t().squeeze().contiguous().view(len_out,-1,hp.M)
q = F.softmax(params_pen).view(len_out,-1,3)
return pi,mu_x,mu_y,sigma_x,sigma_y,rho_xy,q,hidden,cell
def make_image(sequence, epoch, name='_output_'):
"""plot drawing with separated strokes"""
strokes = np.split(sequence, np.where(sequence[:,2]>0)[0]+1)
fig = plt.figure()
ax1 = fig.add_subplot(111)
for s in strokes:
plt.plot(s[:,0],-s[:,1])
canvas = plt.get_current_fig_manager().canvas
canvas.draw()
pil_image = PIL.Image.frombytes('RGB', canvas.get_width_height(),
canvas.tostring_rgb())
name = str(epoch)+name+'.jpg'
pil_image.save(name,"JPEG")
plt.close("all")
def unbundle(state):
if state is None:
return itertools.repeat(None)
return torch.split(torch.cat(state, 1), 1, 0)
def predict(self, outputs, targets, weights, criterion):
outputs_split = torch.split(outputs, self.batch_size, self.dim)
targets_split = torch.split(targets, self.batch_size, self.dim)
weights_split = torch.split(weights, self.batch_size, self.dim)
preds = []
loss = 0
for out_t, targ_t, w_t in zip(outputs_split, targets_split, weights_split):
preds_t, loss_t = super(MemEfficientGenerator, self).predict(
out_t, targ_t, w_t, criterion)
preds.append(preds_t)
loss += loss_t
preds = torch.cat(preds, self.dim)
return preds, loss
def forward(self, input_, hx):
"""
Args:
input_: A (batch, input_size) tensor containing input
features.
hx: A tuple (h_0, c_0), which contains the initial hidden
and cell state, where the size of both states is
(batch, hidden_size).
Returns:
h_1, c_1: Tensors containing the next hidden and cell state.
"""
h_0, c_0 = hx
batch_size = h_0.size(0)
bias_batch = (self.bias.unsqueeze(0)
.expand(batch_size, *self.bias.size()))
wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
wi = torch.mm(input_, self.weight_ih)
f, i, o, g = torch.split(wh_b + wi,
split_size=self.hidden_size, dim=1)
c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)
h_1 = torch.sigmoid(o) * torch.tanh(c_1)
return h_1, c_1
def forward(self, input_, hx, time):
"""
Args:
input_: A (batch, input_size) tensor containing input
features.
hx: A tuple (h_0, c_0), which contains the initial hidden
and cell state, where the size of both states is
(batch, hidden_size).
time: The current timestep value, which is used to
get appropriate running statistics.
Returns:
h_1, c_1: Tensors containing the next hidden and cell state.
"""
h_0, c_0 = hx
batch_size = h_0.size(0)
bias_batch = (self.bias.unsqueeze(0)
.expand(batch_size, *self.bias.size()))
wh = torch.mm(h_0, self.weight_hh)
wi = torch.mm(input_, self.weight_ih)
bn_wh = self.bn_hh(wh, time=time)
bn_wi = self.bn_ih(wi, time=time)
f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch,
split_size=self.hidden_size, dim=1)
c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)
h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time))
return h_1, c_1
def memoryEfficientLoss(outputs, targets, generator, crit, max_generator_batches, eval=False):
"""Memory efficient loss.
:param outputs: seq_len x batch_size x logits_size
:param targets: seq_len x batch_size
:param generator:
:param crit:
:param max_generator_batches:
:param eval:
:return:
"""
# compute generations one piece at a time
num_correct, loss = 0, 0
outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval) # seq_len x batch_size x logits_size
batch_size = outputs.size(1)
outputs_split = torch.split(outputs, max_generator_batches)
targets_split = torch.split(targets, max_generator_batches)
for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
# out_t = seq_len x batch_size x logits_size
# targ_t = seq_len x batch_size
out_t = out_t.view(-1, out_t.size(2)) # seq_len * batch_size x logits_size
scores_t = generator(out_t) # seq_len * batch_size x voc_size
loss_t = crit(scores_t, targ_t.view(-1)) # scalar (1-d)
pred_t = scores_t.max(1)[1] # seq_len * batch_size x 1
num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(Constants.PAD).data).sum()
num_correct += num_correct_t
loss += loss_t.data[0]
if not eval:
loss_t.div(batch_size).backward()
grad_output = None if outputs.grad is None else outputs.grad.data
return loss, grad_output, num_correct
def set_proposal_params(self, tensor_of_proposal_means_stds_coeffs):
n_components = int(tensor_of_proposal_means_stds_coeffs.size(0) / 3)
self.proposal_means, self.proposal_stds, self.proposal_coeffs = torch.split(tensor_of_proposal_means_stds_coeffs, n_components)
def split(self, split_size, dim=0):
"""Splits this tensor into a tuple of tensors.
See :func:`torch.split`.
"""
return torch.split(self, split_size, dim)
def forward(self, input_, hx):
"""
Args:
input_: A (batch, input_size) tensor containing input
features.
hx: A tuple (h_0, c_0), which contains the initial hidden
and cell state, where the size of both states is
(batch, hidden_size).
Returns:
h_1, c_1: Tensors containing the next hidden and cell state.
"""
h_0, c_0 = hx
batch_size = h_0.size(0)
bias_batch = (self.bias.unsqueeze(0)
.expand(batch_size, *self.bias.size()))
wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
wi = torch.mm(input_, self.weight_ih)
f, i, o, g = torch.split(wh_b + wi,
split_size=self.hidden_size, dim=1)
c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)
h_1 = torch.sigmoid(o) * torch.tanh(c_1)
return h_1, c_1
def forward(self, input_, hx, time):
"""
Args:
input_: A (batch, input_size) tensor containing input
features.
hx: A tuple (h_0, c_0), which contains the initial hidden
and cell state, where the size of both states is
(batch, hidden_size).
time: The current timestep value, which is used to
get appropriate running statistics.
Returns:
h_1, c_1: Tensors containing the next hidden and cell state.
"""
h_0, c_0 = hx
batch_size = h_0.size(0)
bias_batch = (self.bias.unsqueeze(0)
.expand(batch_size, *self.bias.size()))
wh = torch.mm(h_0, self.weight_hh)
wi = torch.mm(input_, self.weight_ih)
bn_wh = self.bn_hh(wh, time=time)
bn_wi = self.bn_ih(wi, time=time)
f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch,
split_size=self.hidden_size, dim=1)
c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)
h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time))
return h_1, c_1
def split(self, split_size, dim=0):
"""Splits this tensor into a tuple of tensors.
See :func:`torch.split`.
"""
return torch.split(self, split_size, dim)
def split(self, split_size, dim=0):
return torch.split(self, split_size, dim)
def execute(self):
maxLen = max([len(e) for e in self.progs])
for s in range(maxLen):
nodes = []
for i in range(len(self.progs)):
prog = self.progs[i]
if len(prog) <= s:
continue
nodes += [prog[s]]
groupedNodes = {}
for node in nodes:
groupedNodes.setdefault(node.cellInd, []).append(node)
for cellInd, nodes in groupedNodes.items():
arity = nodes[0].arity
cell = self.cells[cellInd]
outData = [node.inpData[0] for node in nodes]
if arity==1:
arg = t.cat(outData, 0)
outData = cell(arg)
outData = t.split(outData, 1, 0)
elif arity==2:
arg1 = t.cat(outData, 0)
arg2 = t.cat([node.inpData[1] for node in nodes], 0)
outData = cell(arg1, arg2)
outData = t.split(outData, 1, 0)
for node, outDat in zip(nodes, outData):
if node.prev is None:
node.outData = outDat
else:
node.prev.inpData += [outDat]
outData = [prog[-1].outData for prog in self.progs]
return t.cat(outData, 0)
def split(self, split_size, dim=0):
"""Splits this tensor into a tuple of tensors.
See :func:`torch.split`.
"""
return torch.split(self, split_size, dim)
def split(self, split_size, dim=0):
return torch.split(self, split_size, dim)
def sample_outputs(generator, Nsamples, arguments):
inp = torch.randn(Nsamples, arguments.L1)
if arguments.cuda:
inp = inp.cuda()
out = generator.forward(Variable(inp))
if arguments.task == 'images':
out = out.contiguous().view(-1, arguments.nfts, arguments.T)
return torch.split(out.data, split_size=1, dim=0)
def forward(self, x, ident, context, start=True):
out, attns = [], []
o_t = x[0]
self.init_buffer(ident, start)
for o_tm1 in torch.split(x, 1):
if not self.training:
o_tm1 = o_t.unsqueeze(0)
# predict weighted context based on S
c_t, mu_t, alpha_t = self.attn(self.S_t,
context.transpose(0, 1),
self.mu_t)
# advance mu and update buffer
self.S_t = self.update_buffer(self.S_t, c_t, o_tm1, ident)
self.mu_t = mu_t
# predict next time step based on buffer content
ot_out = self.N_o(self.S_t.view(self.S_t.size(0), -1))
sp_out = self.F_o(ident)
o_t = self.output(ot_out + sp_out)
out += [o_t]
attns += [alpha_t.squeeze()]
out_seq = torch.stack(out)
attns_seq = torch.stack(attns)
return out_seq, attns_seq
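The decoder above steps through time by iterating over torch.split(x, 1), which yields length-1 slices along the sequence dimension. A minimal sketch of that iteration pattern with made-up shapes (the per-step processing is left as a comment):

import torch

seq_len, batch, feat = 5, 3, 4
x = torch.randn(seq_len, batch, feat)

for x_t in torch.split(x, 1, dim=0):  # x_t has shape (1, batch, feat)
    x_t = x_t.squeeze(0)              # (batch, feat)
    # per-timestep processing would go here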
def loader(self, path):
feat = np.load(path)
txt = feat['phonemes'].astype('int64')
txt = torch.from_numpy(txt)
audio = feat['audio_features']
audio = torch.from_numpy(audio)
spkr = os.path.basename(path).split('_')[0]
return txt, audio, spkr
def forward(self, x, lstm_hidden_vb=None):
p = x.view(x.size(0), self.input_dims[0] * self.input_dims[1])
p = self.rl1(self.fc1(p))
p = self.rl2(self.fc2(p))
p = self.rl3(self.fc3(p))
p = self.rl4(self.fc4(p))
p = p.view(-1, self.hidden_dim)
if self.enable_lstm:
p_, v_ = torch.split(lstm_hidden_vb[0],1)
c_p, c_v = torch.split(lstm_hidden_vb[1],1)
p, c_p = self.lstm(p, (p_, c_p))
p_out = self.policy_5(p)
sig = self.policy_sig(p)
sig = self.softplus(sig)
v = x.view(x.size(0), self.input_dims[0] * self.input_dims[1])
v = self.rl1_v(self.fc1_v(v))
v = self.rl2_v(self.fc2_v(v))
v = self.rl3_v(self.fc3_v(v))
v = self.rl4_v(self.fc4_v(v))
v = v.view(-1, self.hidden_dim)
if self.enable_lstm:
v, c_v = self.lstm_v(v, (v_, c_v))
v_out = self.value_5(v)
if self.enable_lstm:
return p_out, sig, v_out, (torch.cat((p,v),0), torch.cat((c_p, c_v),0))
else:
return p_out, sig, v_out
def forward(self, input_, hx):
"""
Args:
input_: A (batch, input_size) tensor containing input
features.
hx: initial hidden, where the size of the state is
(batch, hidden_size).
Returns:
newh: Tensors containing the next hidden state.
"""
batch_size = hx.size(0)
bias_batch = (self.gate_bias.unsqueeze(0)
.expand(batch_size, *self.gate_bias.size()))
gate_Wh = torch.addmm(bias_batch, hx, self.gate_W)
gate_Ux = torch.mm(input_, self.gate_U)
r, z = torch.split(gate_Ux + gate_Wh,
split_size=self.hidden_size, dim=1)
Ux = torch.mm(input_, self.U)
unitary = self._EUNN(hx=hx, thetaA=self.thetaA, thetaB=self.thetaB)
unitary = unitary * r
newh = Ux + unitary
newh = self._modReLU(newh, self.bias)
newh = hx * z + (1-z) * newh
return newh
def reader(self):
with open(self.filepath, 'r') as f:
if self.has_header:
next(f)
for line in f:
w, *vec = line.split()
yield w, vec
def shards(data, size=25, test=False):
"""
Generator over variables that will be involved in a costly loss computation
such as the softmax. It yields dictionaries of the same form as the input,
where the variables have been split into smaller shards and detached from
the graph. It expects the consumer to back-propagate through them in shards
of a given size. After all shards are consumed, the generator will take
care of further backprop from the input using the accumulated gradients.
"""
# Inspired by www.github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Loss.py
if test:
yield data
return
detached = dict(detach_vars(data))
splits = ((key, torch.split(v, size)) for key, v in detached.items())
keys, splits = zip(*splits)
for split in zip(*splits):
yield dict(zip(keys, split)) # go and accumulate some loss
inputs, grads = [], []
for key, var in detached.items():
if var.grad is not None:
inputs.append(data[key]), grads.append(var.grad.data)
torch.autograd.backward(inputs, grads, retain_graph=True)
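The shards generator above detaches the inputs, splits them with torch.split, lets the caller back-propagate through each shard, and finally pushes the accumulated gradients back into the original graph. A self-contained sketch of the same idea under simplified assumptions (a toy quadratic loss stands in for the real criterion, and the dictionary plumbing and detach_vars helper are omitted):

import torch

hidden = torch.randn(100, 10, requires_grad=True)
output = hidden * 2                              # stands in for an expensive projection
detached = output.detach().requires_grad_(True)  # cut the graph, keep a grad slot

for piece in torch.split(detached, 25, dim=0):
    loss = (piece ** 2).mean()
    loss.backward()                              # gradients accumulate on `detached`

# one backward pass through the real graph using the accumulated gradient
output.backward(detached.grad)
print(hidden.grad.shape)                         # torch.Size([100, 10])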
# Initializers
def split(self, split_size, dim=0):
r"""Splits this tensor into tensor chunks of :attr:`split_size` size.
See :func:`torch.split`.
"""
return torch.split(self, split_size, dim)
def forward(self, tensors: List[torch.Tensor], # pylint: disable=arguments-differ
mask: torch.Tensor = None) -> torch.Tensor:
"""
Compute a weighted average of the ``tensors``. The input tensors can be any shape
with at least two dimensions, but must all be the same shape.
When ``do_layer_norm=True``, the ``mask`` is required input. If the ``tensors`` are
dimensioned ``(dim_0, ..., dim_{n-1}, dim_n)``, then the ``mask`` is dimensioned
``(dim_0, ..., dim_{n-1})``, as in the typical case with ``tensors`` of shape
``(batch_size, timesteps, dim)`` and ``mask`` of shape ``(batch_size, timesteps)``.
When ``do_layer_norm=False`` the ``mask`` is ignored.
"""
if len(tensors) != self.mixture_size:
raise ConfigurationError("{} tensors were passed, but the module was initialized to "
"mix {} tensors.".format(len(tensors), self.mixture_size))
def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
tensor_masked = tensor * broadcast_mask
mean = torch.sum(tensor_masked) / num_elements_not_masked
variance = torch.sum(((tensor_masked - mean) * broadcast_mask)**2) / num_elements_not_masked
return (tensor - mean) / torch.sqrt(variance + 1E-12)
normed_weights = torch.nn.functional.softmax(torch.cat([parameter for parameter
in self.scalar_parameters]), dim=0)
normed_weights = torch.split(normed_weights, split_size=1)
if not self.do_layer_norm:
pieces = []
for weight, tensor in zip(normed_weights, tensors):
pieces.append(weight * tensor)
return self.gamma * sum(pieces)
else:
mask_float = mask.float()
broadcast_mask = mask_float.unsqueeze(-1)
input_dim = tensors[0].size(-1)
num_elements_not_masked = torch.sum(mask_float) * input_dim
pieces = []
for weight, tensor in zip(normed_weights, tensors):
pieces.append(weight * _do_layer_norm(tensor,
broadcast_mask, num_elements_not_masked))
return self.gamma * sum(pieces)
def forward(self, buffers, transitions):
buffers = [list(torch.split(b.squeeze(1), 1, 0))
for b in torch.split(buffers, 1, 1)]
stacks = [[buf[0], buf[0]] for buf in buffers]
if hasattr(self, 'tracker'):
self.tracker.reset_state()
else:
assert transitions is not None
if transitions is not None:
num_transitions = transitions.size(0)
# trans_loss, trans_acc = 0, 0
else:
num_transitions = len(buffers[0]) * 2 - 3
for i in range(num_transitions):
if transitions is not None:
trans = transitions[i]
if hasattr(self, 'tracker'):
tracker_states, trans_hyp = self.tracker(buffers, stacks)
if trans_hyp is not None:
trans = trans_hyp.max(1)[1]
# if transitions is not None:
# trans_loss += F.cross_entropy(trans_hyp, trans)
# trans_acc += (trans_preds.data == trans.data).mean()
# else:
# trans = trans_preds
else:
tracker_states = itertools.repeat(None)
lefts, rights, trackings = [], [], []
batch = zip(trans.data, buffers, stacks, tracker_states)
for transition, buf, stack, tracking in batch:
if transition == 3: # shift
stack.append(buf.pop())
elif transition == 2: # reduce
rights.append(stack.pop())
lefts.append(stack.pop())
trackings.append(tracking)
if rights:
reduced = iter(self.reduce(lefts, rights, trackings))
for transition, stack in zip(trans.data, stacks):
if transition == 2:
stack.append(next(reduced))
# if trans_loss is not 0:
return bundle([stack.pop() for stack in stacks])[0]
def alpha_loss(outputs, targets, generator, crit, max_generator_batches, rewards, proposed_weights, tau, alpha, eval=False):
"""Loss function of proposed method.
:param outputs: seq_len x batch_size x logits_size
:param targets: seq_len x batch_size
:param generator:
:param crit:
:param max_generator_batches:
:param eval:
:return:
"""
# compute generations one piece at a time
num_correct, loss = 0, 0
outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval) # seq_len x batch_size x logits_size
batch_size = outputs.size(1)
outputs_split = torch.split(outputs, max_generator_batches)
targets_split = torch.split(targets, max_generator_batches)
# TODO(sotetsuk): fix to calculate at once
importance_list = []
p_sample_efficiency_list = []
q_sample_efficiency_list = []
pq_sample_efficiency_list = []
for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
out_t = out_t.view(-1, out_t.size(2)) # seq_len * batch_size x logits_size
scores_t = generator(out_t) # seq_len * batch_size x voc_size
proposed_weights = torch.FloatTensor(proposed_weights)
log_q_weights = torch.FloatTensor(rewards) / tau
loss_t, importance_t, p_sample_efficiency_t, q_sample_efficiency_t, pq_sample_efficiency_t = crit(scores_t, targ_t.view(-1), proposed_weights, log_q_weights, alpha, rewards) # scalar (1-d)
pred_t = scores_t.max(1)[1] # seq_len * batch_size x 1
num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(Constants.PAD).data).sum()
num_correct += num_correct_t
loss += loss_t.data[0]
importance_list += importance_t
p_sample_efficiency_list += p_sample_efficiency_t
q_sample_efficiency_list += q_sample_efficiency_t
pq_sample_efficiency_list += pq_sample_efficiency_t
if not eval:
loss_t.div(batch_size).backward()
grad_output = None if outputs.grad is None else outputs.grad.data
return loss, grad_output, num_correct, importance_list, p_sample_efficiency_list, q_sample_efficiency_list, pq_sample_efficiency_list
def shards(state, shard_size, eval=False):
"""
Args:
state: A dictionary which corresponds to the output of
*LossCompute.make_shard_state(). The values for
those keys are Tensor-like or None.
shard_size: The maximum size of the shards yielded by the model.
eval: If True, only yield the state, nothing else.
Otherwise, yield shards.
Yields:
Each yielded shard is a dict.
Side effect:
After the last shard, this function does back-propagation.
"""
if eval:
yield state
else:
# non_none: the subdict of the state dictionary where the values
# are not None.
non_none = dict(filter_shard_state(state))
# Now, the iteration:
# state is a dictionary of sequences of tensor-like but we
# want a sequence of dictionaries of tensors.
# First, unzip the dictionary into a sequence of keys and a
# sequence of tensor-like sequences.
keys, values = zip(*((k, torch.split(v, shard_size))
for k, v in non_none.items()))
# Now, yield a dictionary for each shard. The keys are always
# the same. values is a sequence of length #keys where each
# element is a sequence of length #shards. We want to iterate
# over the shards, not over the keys: therefore, the values need
# to be re-zipped by shard and then each shard can be paired
# with the keys.
for shard_tensors in zip(*values):
yield dict(zip(keys, shard_tensors))
# Assumed backprop'd
variables = ((state[k], v.grad.data) for k, v in non_none.items()
if isinstance(v, Variable) and v.grad is not None)
inputs, grads = zip(*variables)
torch.autograd.backward(inputs, grads)