是否可以通过 Tensor 操作获得以下循环?
a = torch.Tensor([1, 0, 0, 0]) b = torch.Tensor([1, 2, 3, 4]) for i in range(1, a.shape[0]): a[i] = b[i] + a[i-1] print(a) # [1, 3, 6, 10]
该操作取决于先前的值a和沿途计算的值(以动态规划方式)。
a
是否可以通过张量运算获得这种类型的顺序计算?
在 PyTorch 中,可以通过 torch.cumsum 获得类似的累积求和操作,而无需显式使用循环。对于你的示例:
torch.cumsum
a = torch.Tensor([1, 0, 0, 0]) b = torch.Tensor([1, 2, 3, 4])
可以将其转化为以下无显式循环的实现:
# Step 1: Create a cumulative sum tensor from b cumsum_b = torch.cumsum(b, dim=0) # Step 2: Adjust the first element to match a[0] = 1 a = cumsum_b a[0] = 1 print(a) # Output: [1.0, 3.0, 6.0, 10.0]
torch.cumsum(b, dim=0)
dim=0
b = [1, 2, 3, 4]
[1, 3, 6, 10]
a[0] = 1
对于更加复杂的动态依赖关系(如非简单加法),torch.scan 是未来可能适用的解决方案,但目前在 PyTorch 中未直接提供(在 JAX 中可以使用 jax.lax.scan)。
torch.scan
jax.lax.scan
对于 PyTorch,现在可以通过编写递归计算的方式,结合 torch.jit 提升性能。这种实现适合更复杂的动态规划问题。
torch.jit
虽然可以通过类似 torch.cumsum 的张量操作消除循环,但这种方法只适用于可简化为扫描或累积的操作。更复杂的动态规划问题可能仍需显式使用循环或编写自定义 CUDA 内核来提升效率。