>>> a = torch.arange(12).reshape(2, 6) >>> a tensor([[ 0, 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10, 11]]) >>> b = a[1:, :] >>> b.storage() is a.storage() False
但
>>> b[0, 0] = 999 >>> b, a # both tensors are changed (tensor([[999, 7, 8, 9, 10, 11]]), tensor([[ 0, 1, 2, 3, 4, 5], [999, 7, 8, 9, 10, 11]]))
存储张量数据的对象到底是什么?如何检查两个张量是否共享内存?
torch.Tensor.storage()每次调用都会返回一个新的实例torch.Storage。您可以在以下示例中看到这一点
torch.Tensor.storage()
torch.Storage
a.storage() is a.storage() # False
要比较指向底层数据的指针,可以使用以下命令:
a.storage().data_ptr() == b.storage().data_ptr() # True
这个pytorch 论坛帖子讨论了如何确定 pytorch 张量是否共享内存。
a.data_ptr()注意和之间的区别。第一个返回指向张量第一个元素a.storage().data_ptr()的指针,而第二个似乎指向底层数据(而不是切片视图)的内存地址,尽管没有记录。
a.data_ptr()
a.storage().data_ptr()
了解了上述内容,我们就可以理解为什么a.data_ptr()不同于b.data_ptr()。考虑以下代码:
b.data_ptr()
import torch a = torch.arange(4, dtype=torch.int64) b = a[1:] b.data_ptr() - a.data_ptr() # 8
的第一个元素的地址b比的第一个元素多 8 个,a因为我们切片删除了第一个元素,并且每个元素是 8 个字节(dtype 是 64 位整数)。
b
a
如果我们使用与上面相同的代码,但使用 8 位整数数据类型,则内存地址将会相差一。