加载中...

8.11 简化数据结构的初始化


问题

你写了很多仅仅用作数据结构的类,不想写太多烦人的 __init__() 函数

解决方案

可以在一个基类中写一个公用的 __init__() 函数:

  1. import math
  2. class Structure1:
  3. # Class variable that specifies expected fields
  4. _fields = []
  5. def __init__(self, *args):
  6. if len(args) != len(self._fields):
  7. raise TypeError('Expected {} arguments'.format(len(self._fields)))
  8. # Set the arguments
  9. for name, value in zip(self._fields, args):
  10. setattr(self, name, value)

然后使你的类继承自这个基类:

  1. # Example class definitions
  2. class Stock(Structure1):
  3. _fields = ['name', 'shares', 'price']
  4. class Point(Structure1):
  5. _fields = ['x', 'y']
  6. class Circle(Structure1):
  7. _fields = ['radius']
  8. def area(self):
  9. return math.pi * self.radius ** 2

使用这些类的示例:

  1. >>> s = Stock('ACME', 50, 91.1)
  2. >>> p = Point(2, 3)
  3. >>> c = Circle(4.5)
  4. >>> s2 = Stock('ACME', 50)
  5. Traceback (most recent call last):
  6. File "<stdin>", line 1, in <module>
  7. File "structure.py", line 6, in __init__
  8. raise TypeError('Expected {} arguments'.format(len(self._fields)))
  9. TypeError: Expected 3 arguments

如果还想支持关键字参数,可以将关键字参数设置为实例属性:

  1. class Structure2:
  2. _fields = []
  3. def __init__(self, *args, **kwargs):
  4. if len(args) > len(self._fields):
  5. raise TypeError('Expected {} arguments'.format(len(self._fields)))
  6. # Set all of the positional arguments
  7. for name, value in zip(self._fields, args):
  8. setattr(self, name, value)
  9. # Set the remaining keyword arguments
  10. for name in self._fields[len(args):]:
  11. setattr(self, name, kwargs.pop(name))
  12. # Check for any remaining unknown arguments
  13. if kwargs:
  14. raise TypeError('Invalid argument(s): {}'.format(','.join(kwargs)))
  15. # Example use
  16. if __name__ == '__main__':
  17. class Stock(Structure2):
  18. _fields = ['name', 'shares', 'price']
  19. s1 = Stock('ACME', 50, 91.1)
  20. s2 = Stock('ACME', 50, price=91.1)
  21. s3 = Stock('ACME', shares=50, price=91.1)
  22. # s3 = Stock('ACME', shares=50, price=91.1, aa=1)

你还能将不在 _fields 中的名称加入到属性中去:

  1. class Structure3:
  2. # Class variable that specifies expected fields
  3. _fields = []
  4. def __init__(self, *args, **kwargs):
  5. if len(args) != len(self._fields):
  6. raise TypeError('Expected {} arguments'.format(len(self._fields)))
  7. # Set the arguments
  8. for name, value in zip(self._fields, args):
  9. setattr(self, name, value)
  10. # Set the additional arguments (if any)
  11. extra_args = kwargs.keys() - self._fields
  12. for name in extra_args:
  13. setattr(self, name, kwargs.pop(name))
  14. if kwargs:
  15. raise TypeError('Duplicate values for {}'.format(','.join(kwargs)))
  16. # Example use
  17. if __name__ == '__main__':
  18. class Stock(Structure3):
  19. _fields = ['name', 'shares', 'price']
  20. s1 = Stock('ACME', 50, 91.1)
  21. s2 = Stock('ACME', 50, 91.1, date='8/2/2012')

讨论

当你需要使用大量很小的数据结构类的时候,相比手工一个个定义 __init__() 方法而已,使用这种方式可以大大简化代码。

在上面的实现中我们使用了 setattr() 函数类设置属性值,你可能不想用这种方式,而是想直接更新实例字典,就像下面这样:

  1. class Structure:
  2. # Class variable that specifies expected fields
  3. _fields= []
  4. def __init__(self, *args):
  5. if len(args) != len(self._fields):
  6. raise TypeError('Expected {} arguments'.format(len(self._fields)))
  7. # Set the arguments (alternate)
  8. self.__dict__.update(zip(self._fields,args))

尽管这也可以正常工作,但是当定义子类的时候问题就来了。当一个子类定义了 __slots__ 或者通过property(或描述器)来包装某个属性,那么直接访问实例字典就不起作用了。我们上面使用 setattr() 会显得更通用些,因为它也适用于子类情况。

这种方法唯一不好的地方就是对某些IDE而已,在显示帮助函数时可能不太友好。比如:

  1. >>> help(Stock)
  2. Help on class Stock in module __main__:
  3. class Stock(Structure)
  4. ...
  5. | Methods inherited from Structure:
  6. |
  7. | __init__(self, *args, **kwargs)
  8. |
  9. ...
  10. >>>

可以参考9.16小节来强制在 __init__() 方法中指定参数的类型签名。


还没有评论.