Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

Optionally use __slots__ for payload members #259

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
66 changes: 66 additions & 0 deletions tests/test_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,69 @@ def test_tpayload_pickle():
person_2 = pickle.loads(PICKLED_BYTES)

assert person == person_2


def test_load_slots():
thrift = thriftpy.load(
'addressbook.thrift',
use_slots=True,
module_name='addressbook_thrift'
)

# normal structs will have slots
assert thrift.PhoneNumber.__slots__ == ['type', 'number', 'mix_item']
assert thrift.Person.__slots__ == ['name', 'phones', 'created_at']
assert thrift.AddressBook.__slots__ == ['people']

# get/set undefined attributes
person = thrift.Person()
with pytest.raises(AttributeError):
person.attr_not_exist = "Does not work"

with pytest.raises(AttributeError):
person.attr_not_exist

pn = thrift.PhoneNumber()
with pytest.raises(AttributeError):
pn.attr_not_exist = "Does not work"

with pytest.raises(AttributeError):
pn.attr_not_exist

ab = thrift.AddressBook()
with pytest.raises(AttributeError):
ab.attr_not_exist = "Does not work"

with pytest.raises(AttributeError):
ab.attr_not_exist
# eo: get/set

# exceptions will not have slots
assert not hasattr(thrift.PersonNotExistsError, '__slots__')

# enums will not have slots
assert not hasattr(thrift.PhoneType, '__slots__')

# service itself will not be created with slots
assert not hasattr(thrift.AddressBookService, '__slots__')

# service args will have slots
args_slots = thrift.AddressBookService.get_phonenumbers_args.__slots__
assert args_slots == ['name', 'count']

result_slots = thrift.AddressBookService.get_phonenumbers_result.__slots__
assert result_slots == ['success']

# should be able to pickle slotted objects - if load with module_name
bob = thrift.Person(name="Bob")
p_str = pickle.dumps(bob)

assert pickle.loads(p_str) == bob

# works for recursive types too
rec = thriftpy.load('parser-cases/recursive_union.thrift', use_slots=True)
rec_slots = rec.Dynamic.__slots__
assert rec_slots == ['boolean', 'integer', 'doubl', 'str', 'arr', 'object']
dyn = rec.Dynamic()
with pytest.raises(AttributeError):
dyn.attr_not_exist = "shouldn't work"
8 changes: 4 additions & 4 deletions thriftpy/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .parser import parse, parse_fp


def load(path, module_name=None, include_dirs=None, include_dir=None):
def load(path, module_name=None, include_dirs=None, include_dir=None, use_slots=False):
"""Load thrift file as a module.

The module loaded and objects inside may only be pickled if module_name
Expand All @@ -27,17 +27,17 @@ def load(path, module_name=None, include_dirs=None, include_dir=None):
"""
real_module = bool(module_name)
thrift = parse(path, module_name, include_dirs=include_dirs,
include_dir=include_dir)
include_dir=include_dir, use_slots=use_slots)

if real_module:
sys.modules[module_name] = thrift
return thrift


def load_fp(source, module_name):
def load_fp(source, module_name, use_slots=False):
"""Load thrift file like object as a module.
"""
thrift = parse_fp(source, module_name)
thrift = parse_fp(source, module_name, use_slots=use_slots)
sys.modules[module_name] = thrift
return thrift

Expand Down
51 changes: 39 additions & 12 deletions thriftpy/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .lexer import * # noqa
from .exc import ThriftParserError, ThriftGrammerError
from thriftpy._compat import urlopen, urlparse
from ..thrift import gen_init, TType, TPayload, TException
from ..thrift import gen_init, TType, TPayload, TSPayload, TException


def p_error(p):
Expand Down Expand Up @@ -215,7 +215,9 @@ def p_struct(p):

def p_seen_struct(p):
'''seen_struct : STRUCT IDENTIFIER '''
val = _make_empty_struct(p[2])
use_slots = p.parser.__use_slots__
base_cls = TSPayload if use_slots else TPayload
val = _make_empty_struct(p[2], base_cls=base_cls)
setattr(thrift_stack[-1], p[2], val)
p[0] = val

Expand All @@ -228,7 +230,9 @@ def p_union(p):

def p_seen_union(p):
'''seen_union : UNION IDENTIFIER '''
val = _make_empty_struct(p[2])
use_slots = p.parser.__use_slots__
base_cls = TSPayload if use_slots else TPayload
val = _make_empty_struct(p[2], base_cls=base_cls)
setattr(thrift_stack[-1], p[2], val)
p[0] = val

Expand Down Expand Up @@ -262,7 +266,8 @@ def p_service(p):
else:
extends = None

val = _make_service(p[2], p[len(p) - 2], extends)
use_slots = p.parser.__use_slots__
val = _make_service(p[2], p[len(p) - 2], extends, use_slots=use_slots)
setattr(thrift, p[2], val)
_add_thrift_meta('services', val)

Expand Down Expand Up @@ -430,8 +435,12 @@ def p_definition_type(p):
thrift_cache = {}


def _get_cache_key(prefix, use_slots=False):
return ('%s:slotted' % prefix) if use_slots else prefix


def parse(path, module_name=None, include_dirs=None, include_dir=None,
lexer=None, parser=None, enable_cache=True):
lexer=None, parser=None, enable_cache=True, use_slots=False):
"""Parse a single thrift file to module object, e.g.::

>>> from thriftpy.parser.parser import parse
Expand All @@ -452,6 +461,7 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None,
:param enable_cache: if this is set to be `True`, parsed module will be
cached, this is enabled by default. If `module_name`
is provided, use it as cache key, else use the `path`.
:param use_slots: if set to `True` uses slots for struct members
"""
if os.name == 'nt' and sys.version_info < (3, 2):
os.path.samefile = lambda f1, f2: os.stat(f1) == os.stat(f2)
Expand All @@ -464,7 +474,8 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None,

global thrift_cache

cache_key = module_name or os.path.normpath(path)
cache_prefix = module_name or os.path.normpath(path)
cache_key = _get_cache_key(cache_prefix, use_slots)

if enable_cache and cache_key in thrift_cache:
return thrift_cache[cache_key]
Expand All @@ -474,6 +485,8 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None,
if parser is None:
parser = yacc.yacc(debug=False, write_tables=0)

parser.__use_slots__ = use_slots

global include_dirs_

if include_dirs is not None:
Expand Down Expand Up @@ -515,7 +528,7 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None,
return thrift


def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True):
def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True, use_slots=False):
"""Parse a file-like object to thrift module object, e.g.::

>>> from thriftpy.parser.parser import parse_fp
Expand All @@ -530,13 +543,16 @@ def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True):
:param parser: ply parser to use, if not provided, `parse` will new one.
:param enable_cache: if this is set to be `True`, parsed module will be
cached by `module_name`, this is enabled by default.
:param use_slots: if set to `True` uses slots for struct members
"""
if not module_name.endswith('_thrift'):
raise ThriftParserError('ThriftPy can only generate module with '
'\'_thrift\' suffix')

cache_key = _get_cache_key(module_name, use_slots)

if enable_cache and module_name in thrift_cache:
return thrift_cache[module_name]
return thrift_cache[cache_key]

if not hasattr(source, 'read'):
raise ThriftParserError('Except `source` to be a file-like object with'
Expand All @@ -547,6 +563,8 @@ def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True):
if parser is None:
parser = yacc.yacc(debug=False, write_tables=0)

parser.__use_slots__ = use_slots

data = source.read()

thrift = types.ModuleType(module_name)
Expand All @@ -557,7 +575,7 @@ def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True):
thrift_stack.pop()

if enable_cache:
thrift_cache[module_name] = thrift
thrift_cache[cache_key] = thrift
return thrift


Expand Down Expand Up @@ -749,6 +767,8 @@ def _make_enum(name, kvs):

def _make_empty_struct(name, ttype=TType.STRUCT, base_cls=TPayload):
attrs = {'__module__': thrift_stack[-1].__name__, '_ttype': ttype}
if issubclass(base_cls, TSPayload):
attrs['__slots__'] = []
return type(name, (base_cls, ), attrs)


Expand All @@ -769,6 +789,9 @@ def _fill_in_struct(cls, fields, _gen_init=True):
setattr(cls, 'thrift_spec', thrift_spec)
setattr(cls, 'default_spec', default_spec)
setattr(cls, '_tspec', _tspec)
# add __slots__ for easy introspection
if issubclass(cls, TSPayload):
cls.__slots__ = [field for field, _ in default_spec]
if _gen_init:
gen_init(cls, thrift_spec, default_spec)
return cls
Expand All @@ -780,11 +803,13 @@ def _make_struct(name, fields, ttype=TType.STRUCT, base_cls=TPayload,
return _fill_in_struct(cls, fields, _gen_init=_gen_init)


def _make_service(name, funcs, extends):
def _make_service(name, funcs, extends, use_slots=False):
if extends is None:
extends = object

attrs = {'__module__': thrift_stack[-1].__name__}
base_cls = TSPayload if use_slots else TPayload
# service class itself will not be created with slots
cls = type(name, (extends, ), attrs)
thrift_services = []

Expand All @@ -793,21 +818,23 @@ def _make_service(name, funcs, extends):
# args payload cls
args_name = '%s_args' % func_name
args_fields = func[3]
args_cls = _make_struct(args_name, args_fields)
args_cls = _make_struct(args_name, args_fields, base_cls=base_cls)
setattr(cls, args_name, args_cls)
# result payload cls
result_name = '%s_result' % func_name
result_type = func[1]
result_throws = func[4]
result_oneway = func[0]
result_cls = _make_struct(result_name, result_throws,
_gen_init=False)
_gen_init=False, base_cls=base_cls)
setattr(result_cls, 'oneway', result_oneway)
if result_type != TType.VOID:
result_cls.thrift_spec[0] = _ttype_spec(result_type, 'success')
result_cls.default_spec.insert(0, ('success', None))
gen_init(result_cls, result_cls.thrift_spec, result_cls.default_spec)
setattr(cls, result_name, result_cls)
# default spec is modified after making struct so add slots here
result_cls.__slots__ = [f for f, _ in result_cls.default_spec]
thrift_services.append(func_name)
if extends is not None and hasattr(extends, 'thrift_services'):
thrift_services.extend(extends.thrift_services)
Expand Down
61 changes: 61 additions & 0 deletions thriftpy/thrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@

from __future__ import absolute_import

try:
import copy_reg as copyreg
except ImportError:
import copyreg

import functools
import linecache
import types
Expand Down Expand Up @@ -126,12 +131,35 @@ class TMessageType(object):

class TPayloadMeta(type):

_class_cache = {}

def __new__(cls, name, bases, attrs):
if "default_spec" in attrs:
spec = attrs.pop("default_spec")
attrs["__init__"] = init_func_generator(cls, spec)
return super(TPayloadMeta, cls).__new__(cls, name, bases, attrs)

def __call__(cls, *args, **kw):
if not issubclass(cls, TSPayload):
return type.__call__(cls, *args, **kw)
cls_name = cls.__name__.split('.')[-1]
cache_key = '%s:%s' % (cls.__module__, cls_name)
kls = TPayloadMeta._class_cache.get(cache_key)
if not kls:
fields = [field for field, _ in cls.default_spec]
kls = type(
cls_name,
(cls,),
{
'__slots__': fields,
'__module__': cls.__module__,
}
)
TPayloadMeta._class_cache[cache_key] = kls
fn = lambda obj: (cls, tuple(getattr(obj, f) for f in fields))
copyreg.pickle(kls, fn)
return type.__call__(kls, *args, **kw)


def gen_init(cls, thrift_spec=None, default_spec=None):
if thrift_spec is not None:
Expand Down Expand Up @@ -167,6 +195,39 @@ def __ne__(self, other):
return not self.__eq__(other)


class TSPayload(with_metaclass(TPayloadMeta, object)):

__slots__ = tuple()

__hash__ = None

def read(self, iprot):
iprot.read_struct(self)

def write(self, oprot):
oprot.write_struct(self)

def __repr__(self):
keys = self.__slots__
values = [getattr(self, k) for k in keys]
l = ['%s=%r' % (key, value) for key, value in zip(keys, values)]
return '%s(%s)' % (self.__class__.__name__, ', '.join(l))

def __str__(self):
return repr(self)

def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
keys = self.__slots__
this = [getattr(self, k) for k in keys]
other_ = [getattr(other, k) for k in keys]
return this == other_

def __ne__(self, other):
return not self.__eq__(other)


class TClient(object):

def __init__(self, service, iprot, oprot=None):
Expand Down