diff --git a/docs/reference.rst b/docs/reference.rst index 76f913fd..faf3a196 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -32,11 +32,15 @@ Meta options .. attribute:: model This optional attribute describes the class of objects to generate. + It could be a class or the fully qualified import path to it. If unset, it will be inherited from parent :class:`Factory` subclasses. .. versionadded:: 2.4.0 + .. versionadded:: 3.3 + Support fully qualified import path to the class + .. method:: get_model_class() Returns the actual model class (:attr:`FactoryOptions.model` might be the diff --git a/factory/base.py b/factory/base.py index 36b2359a..cef66804 100644 --- a/factory/base.py +++ b/factory/base.py @@ -373,7 +373,7 @@ def get_model_class(self): This can be overridden in framework-specific subclasses to hook into existing model repositories, for instance. """ - return self.model + return utils.resolve_type(self.model) if isinstance(self.model, str) else self.model def __str__(self): return "<%s for %s>" % (self.__class__.__name__, self.factory.__name__) diff --git a/factory/declarations.py b/factory/declarations.py index fe2e34d9..52678286 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -343,31 +343,21 @@ class _FactoryWrapper: path for that subclass (e.g 'myapp.factories.MyFactory'). """ def __init__(self, factory_or_path): - self.factory = None - self.module = self.name = '' - if isinstance(factory_or_path, type): - self.factory = factory_or_path - else: - if not (isinstance(factory_or_path, str) and '.' in factory_or_path): - raise ValueError( - "A factory= argument must receive either a class " - "or the fully qualified path to a Factory subclass; got " - "%r instead." % factory_or_path) - self.module, self.name = factory_or_path.rsplit('.', 1) + + if not (isinstance(factory_or_path, type) or (isinstance(factory_or_path, str) and '.' in factory_or_path)): + raise ValueError( + "A factory= argument must receive either a class " + "or the fully qualified path to a Factory subclass; got " + "%r instead." % factory_or_path) + self.factory = factory_or_path def get(self): - if self.factory is None: - self.factory = utils.import_object( - self.module, - self.name, - ) + if isinstance(self.factory, str): + self.factory = utils.resolve_type(self.factory) return self.factory def __repr__(self): - if self.factory is None: - return f'<_FactoryImport: {self.module}.{self.name}>' - else: - return f'<_FactoryImport: {self.factory.__class__}>' + return f'<_FactoryImport: {self.factory}>' class SubFactory(BaseDeclaration): diff --git a/factory/utils.py b/factory/utils.py index a74e0b35..f813b1eb 100644 --- a/factory/utils.py +++ b/factory/utils.py @@ -16,6 +16,14 @@ def import_object(module_name, attribute_name): return getattr(module, attribute_name) +def resolve_type(type_or_path): + if isinstance(type_or_path, type): + return type_or_path + if not (isinstance(type_or_path, str) and '.' in type_or_path): + raise ValueError("Must receive either an object or the fully qualified path") + return import_object(*type_or_path.rsplit('.', 1)) + + class log_pprint: """Helper for properly printing args / kwargs passed to an object. diff --git a/tests/test_base.py b/tests/test_base.py index 0b9ffa15..c409b40f 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,5 +1,6 @@ # Copyright: See the LICENSE file. +import mailbox import unittest from factory import base, declarations, enums, errors @@ -215,6 +216,16 @@ class Meta: with self.assertRaises(TypeError): type("SecondFactory", (base.Factory,), {"Meta": Meta}) + def test_meta_model_as_path(self): + class MailboxFactory(base.Factory): + class Meta: + model = "mailbox.Mailbox" + path = "/tmp/mail" + + box = MailboxFactory() + assert isinstance(box, mailbox.Mailbox) + assert box._path == "/tmp/mail" + class DeclarationParsingTests(unittest.TestCase): def test_classmethod(self): diff --git a/tests/test_declarations.py b/tests/test_declarations.py index c9458ffe..b379bab0 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -196,7 +196,7 @@ def test_path(self): def test_lazyness(self): f = declarations._FactoryWrapper('factory.declarations.Sequence') - self.assertEqual(None, f.factory) + self.assertEqual('factory.declarations.Sequence', f.factory) factory_class = f.get() self.assertEqual(declarations.Sequence, factory_class) @@ -205,7 +205,7 @@ def test_cache(self): """Ensure that _FactoryWrapper tries to import only once.""" orig_date = datetime.date w = declarations._FactoryWrapper('datetime.date') - self.assertEqual(None, w.factory) + self.assertEqual('datetime.date', w.factory) factory_class = w.get() self.assertEqual(orig_date, factory_class) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1d54eefe..f0625d45 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,6 @@ # Copyright: See the LICENSE file. - +import datetime import itertools import unittest @@ -10,9 +10,7 @@ class ImportObjectTestCase(unittest.TestCase): def test_datetime(self): imported = utils.import_object('datetime', 'date') - import datetime - d = datetime.date - self.assertEqual(d, imported) + self.assertEqual(datetime.date, imported) def test_unknown_attribute(self): with self.assertRaises(AttributeError): @@ -23,6 +21,23 @@ def test_invalid_module(self): utils.import_object('this-is-an-invalid-module', '__name__') +class ResolveTypeTestCase(unittest.TestCase): + def test_datetime(self): + imported = utils.resolve_type('datetime.date') + self.assertEqual(datetime.date, imported) + + def test_unknown_attribute(self): + with self.assertRaises(AttributeError): + utils.resolve_type('datetime.foo') + + def test_invalid_module(self): + with self.assertRaises(ImportError): + utils.resolve_type('this-is-an-invalid-module.__name__') + + def test_is_a_class(self): + return utils.resolve_type(datetime.date) is datetime.date + + class LogPPrintTestCase(unittest.TestCase): def test_nothing(self): txt = str(utils.log_pprint())