Python中的运算符重载

今天的任务很简单,就是熟悉一下Python中的运算符重载。一般,我们想让自定义的类支持一些计算操作,比如会添加如下方法以期达到计算的目的:

class Vector:

    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y

    def __repr__(self):
        return 'Vector(%r, %r)' % (self.x, self.y)

    def __abs__(self):
        return hypot(self.x, self.y)

    def __bool__(self):
        return bool(abs(self))

    def __add__(self, other):
        x = self.x + other.x
        y = self.y + other.y
        return Vector(x, y)

    def __mul__(self, scalar):
        return Vector(self.x * scalar, self.y * scalar)



v1 = Vector(1, 2)
v2 = Vector(3, 4)
v3 = v1 + v2  # Vector(4, 6)

这种简单的方式没问题,但是python 对此也有一定的约束。

  1. 不能重载内置类型的运算符
  2. 不能新建运算符,只能重载现有的
  3. is , and,or,not不可以重载
    因为在其他语言中,程序员已经把重载运算符给滥用了。
一元运算符
  • 取负数 实现 __neg__(self)方法
  • 取正 __pos__(self)
  • 按位取反 __invert__(self)

    这些需要遵循 : 始终返回一个新对象,不能修改self

举例:

# 这里的Vector 兼容迭代器
    def __add__(self, other):
        paris = itertools.zip_longest(self, other, fillvalue=0.0)
        return Vector(a+b for a, b in Paris)

这里有个问题是 不支持左操作数是 非Vector的对象,但支持但操作数
针对中缀运算符(a+b),Python提供了特殊的分派机制

《Python中的运算符重载》 使用__add__ 和__radd__计算a+b 流程图

def  __radd__(self, other):
    return self + other #直接委托给 __add__

其实这里面还涉及到一个问题,就是操作数是不可迭代对象或者迭代元素不支持该操作符,比如: vector + 1 或者将一个strint 相加,这时候我们就得做出处理

    def __add__(self, other):
        try:
            pairs = itertools.zip_longest(self, other, fillvalue=0.0)
            return Vector(a + b for a, b in pairs)
        except TypeError:
            return NotImplemented

    def __radd__(self, other):
        return self + other

给出一个具体的示例:


from array import array
import reprlib
import math
import functools
import operator
import itertools
import numbers


class Vector:
    typecode = 'd'

    def __init__(self, components):
        self._components = array(self.typecode, components)

    def __iter__(self):
        return iter(self._components)

    def __repr__(self):
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)

    def __str__(self):
        return str(tuple(self))

    def __bytes__(self):
        return (bytes([ord(self.typecode)]) +
                bytes(self._components))

    def __eq__(self, other):
        if isinstance(other, Vector):
            return (len(self) == len(other) and
                    all(a == b for a, b in zip(self, other)))
        else:
            return NotImplemented

    def __hash__(self):
        hashes = (hash(x) for x in self)
        return functools.reduce(operator.xor, hashes, 0)

    def __abs__(self):
        return math.sqrt(sum(x * x for x in self))

    def __bool__(self):
        return bool(abs(self))

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, int):
            return self._components[index]
        else:
            msg = '{.__name__} indices must be integers'
            raise TypeError(msg.format(cls))

    shortcut_names = 'xyzt'

    def __getattr__(self, name):
        cls = type(self)
        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))

    def angle(self, n):
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n-1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a

    def angles(self):
        return (self.angle(n) for n in range(1, len(self)))

    def __format__(self, fmt_spec=''):
        if fmt_spec.endswith('h'):  # hyperspherical coordinates
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)],
                                     self.angles())
            outer_fmt = '<{}>'
        else:
            coords = self
            outer_fmt = '({})'
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(', '.join(components))

    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)

    def __add__(self, other):
        try:
            pairs = itertools.zip_longest(self, other, fillvalue=0.0)
            return Vector(a + b for a, b in pairs)
        except TypeError:
            return NotImplemented

    def __radd__(self, other):
        return self + other

    def __mul__(self, scalar):
        if isinstance(scalar, numbers.Real):# 这里不使用具体的类,而是使用抽象基类,它涵盖了所需的全部类型
            return Vector(n * scalar for n in self)
        else:
            return NotImplemented

    def __rmul__(self, scalar):
        return self * scalar

    def __matmul__(self, other):
        try:
            return sum(a * b for a, b in zip(self, other))
        except TypeError:
            return NotImplemented

    def __rmatmul__(self, other):
        return self @ other  # this only works in Python 3.5

接下来再看一下中缀运算符方法名

《Python中的运算符重载》 中缀运算符方法名

何时会调用就地运算方法呢?在使用增量赋值运算符中(a+=b ; a*=b)
如果没有实现就地运算方法,a+=b 其实是 a = a+b 创建新的实例,而不是就地修改左操作数

    原文作者:楼上那位
    原文地址: https://www.jianshu.com/p/cd26b184d06e
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞