小能豆

为什么当 a 和 b 引用相同的数据时,a.storage() 和 b.storage() 会返回 false?

py

>>> a = torch.arange(12).reshape(2, 6)
>>> a
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])
>>> b = a[1:, :]
>>> b.storage() is a.storage()
False

>>> b[0, 0] = 999
>>> b, a # both tensors are changed
(tensor([[999,   7,   8,   9,  10,  11]]),
 tensor([[  0,   1,   2,   3,   4,   5],
         [999,   7,   8,   9,  10,  11]]))

存储张量数据的对象到底是什么?如何检查两个张量是否共享内存?


阅读 10

收藏
2024-11-03

共1个答案

小能豆

torch.Tensor.storage()每次调用都会返回一个新的实例torch.Storage。您可以在以下示例中看到这一点

a.storage() is a.storage()
# False

要比较指向底层数据的指针,可以使用以下命令:

a.storage().data_ptr() == b.storage().data_ptr()
# True

这个pytorch 论坛帖子讨论了如何确定 pytorch 张量是否共享内存。


a.data_ptr()注意和之间的区别。第一个返回指向张量第一个元素a.storage().data_ptr()的指针,而第二个似乎指向底层数据(而不是切片视图)的内存地址,尽管没有记录

了解了上述内容,我们就可以理解为什么a.data_ptr()不同于b.data_ptr()。考虑以下代码:

import torch

a = torch.arange(4, dtype=torch.int64)
b = a[1:]
b.data_ptr() - a.data_ptr()
# 8

的第一个元素的地址b比的第一个元素多 8 个,a因为我们切片删除了第一个元素,并且每个元素是 8 个字节(dtype 是 64 位整数)。

如果我们使用与上面相同的代码,但使用 8 位整数数据类型,则内存地址将会相差一。

2024-11-03