小能豆

在 Python 类中支持等价性(“平等”)的方法

javascript

编写自定义类时,通过==和运算符实现等价性通常很重要。在 Python 中,这可以通过分别实现和特殊方法!=来实现。我发现最简单的方法是使用以下方法:__eq__``__ne__

class Foo:
    def __init__(self, item):
        self.item = item

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        else:
            return False

    def __ne__(self, other):
        return not self.__eq__(other)

您知道有更优雅的方法吗?您知道使用上述方法比较__dict__s 有什么特别的缺点吗?

注意:需要澄清一点——当__eq____ne__未定义时,您会发现这种行为:

>>> a = Foo(1)
>>> b = Foo(1)
>>> a is b
False
>>> a == b
False

也就是说,因为它确实运行了,所以a == b计算为,即身份测试(即“与是同一个对象吗?”)。False``a is b``a``b

__eq____ne__被定义时,你会发现这种行为(这是我们所追求的):

>>> a = Foo(1)
>>> b = Foo(1)
>>> a is b
False
>>> a == b
True

阅读 45

收藏
2024-07-30

共1个答案

小能豆

考虑这个简单的问题:

class Number:

    def __init__(self, number):
        self.number = number


n1 = Number(1)
n2 = Number(1)

n1 == n2 # False -- oops

因此,Python 默认使用对象标识符进行比较操作:

id(n1) # 140400634555856
id(n2) # 140400634555920

覆盖该__eq__函数似乎可以解决问题:

def __eq__(self, other):
    """Overrides the default implementation"""
    if isinstance(other, Number):
        return self.number == other.number
    return False


n1 == n2 # True
n1 != n2 # True in Python 2 -- oops, False in Python 3

Python 2中,请务必记住覆盖该__ne__函数,如文档所述:

比较运算符之间没有隐含关系。 的真值x==y并不意味着 的x!=y假值。因此,在定义 时__eq__(),还应定义__ne__()以使运算符的行为符合预期。

def __ne__(self, other):
    """Overrides the default implementation (unnecessary in Python 3)"""
    return not self.__eq__(other)


n1 == n2 # True
n1 != n2 # False

Python 3中,这不再是必要的,正如文档所述:

默认情况下,__ne__()委托给__eq__()并反转结果,除非它是NotImplemented。比较运算符之间没有其他隐含关系,例如 的真值(x<y or x==y)并不意味着x<=y

但这并不能解决我们所有的问题。让我们添加一个子类:

class SubNumber(Number):
    pass


n3 = SubNumber(1)

n1 == n3 # False for classic-style classes -- oops, True for new-style classes
n3 == n1 # True
n1 != n3 # True for classic-style classes -- oops, False for new-style classes
n3 != n1 # False

注意: Python 2 有两种类:

  • 经典风格(或旧式)的类,不继承object且声明为class A:class A():class A(B):B经典风格类;
  • 新式类,它们继承自object并且被声明为class A(object)class A(B):其中B是新式类。Python 3 只有被声明为class A:class A(object):或 的class A(B):

对于经典风格的类,比较操作总是调用第一个操作数的方法,而对于新风格的类,它总是调用子类操作数的方法,而不管操作数的顺序如何

因此,这里 ifNumber是一个经典风格的类:

  • n1 == n3呼叫n1.__eq__
  • n3 == n1呼叫n3.__eq__
  • n1 != n3呼叫n1.__ne__
  • n3 != n1呼叫n3.__ne__

并且如果Number是新式类:

  • 并呼叫;n1 == n3``n3 == n1``n3.__eq__
  • 并呼叫。n1 != n3``n3 != n1``n3.__ne__

为了修复Python 2 经典样式类的==and运算符的非交换性问题,当操作数类型不受支持时,and方法应该返回值。文档将该值定义为:!=``__eq__``__ne__``NotImplemented``NotImplemented

如果数值方法和丰富的比较方法未实现所提供操作数的运算,则它们可能会返回此值。(然后,解释器将根据运算符尝试反射运算或其他一些后备运算。)其真值为真。

在这种情况下,运算符将比较操作委托给另一个操作数的反射方法文档将反射方法定义为:

这些方法没有交换参数的版本(当左参数不支持操作但右参数支持时使用);相反,__lt__()__gt__()是彼此的反射,__le__()__ge__()是彼此的反射,和 __eq__()__ne__()它们自己的反射。

结果如下:

def __eq__(self, other):
    """Overrides the default implementation"""
    if isinstance(other, Number):
        return self.number == other.number
    return NotImplemented

def __ne__(self, other):
    """Overrides the default implementation (unnecessary in Python 3)"""
    x = self.__eq__(other)
    if x is NotImplemented:
        return NotImplemented
    return not x

即使对于新式类,当操作数属于无关类型(没有继承)时需要and运算符的交换性,则返回NotImplemented值而不是返回False也是正确的做法。==``!=

我们做到了吗?还没。我们有多少个唯一数字?

len(set([n1, n2, n3])) # 3 -- oops

集合使用对象的哈希值,默认情况下 Python 返回对象标识符的哈希值。让我们尝试覆盖它:

def __hash__(self):
    """Overrides the default implementation"""
    return hash(tuple(sorted(self.__dict__.items())))

len(set([n1, n2, n3])) # 1

最终结果如下所示(我在最后添加了一些断言以便验证):

class Number:

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        """Overrides the default implementation"""
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented

    def __ne__(self, other):
        """Overrides the default implementation (unnecessary in Python 3)"""
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self):
        """Overrides the default implementation"""
        return hash(tuple(sorted(self.__dict__.items())))


class SubNumber(Number):
    pass


n1 = Number(1)
n2 = Number(1)
n3 = SubNumber(1)
n4 = SubNumber(4)

assert n1 == n2
assert n2 == n1
assert not n1 != n2
assert not n2 != n1

assert n1 == n3
assert n3 == n1
assert not n1 != n3
assert not n3 != n1

assert not n1 == n4
assert not n4 == n1
assert n1 != n4
assert n4 != n1

assert len(set([n1, n2, n3, ])) == 1
assert len(set([n1, n2, n3, n4])) == 2
2024-07-30