diff --git a/waffle/models.py b/waffle/models.py index fa76b6eb..40fff1b7 100644 --- a/waffle/models.py +++ b/waffle/models.py @@ -56,6 +56,13 @@ def get_from_db(cls, name): objects = cls.objects if get_setting('READ_FROM_WRITE_DB'): objects = objects.using(router.db_for_write(cls)) + + if get_setting(cls.CREATE_MISSING_SETTING): + obj, _created = objects.get_or_create( + name=name, defaults=cls._defaults() + ) + return obj + return objects.get(name=name) @classmethod @@ -197,12 +204,17 @@ class AbstractBaseFlag(BaseModel): SINGLE_CACHE_KEY = 'FLAG_CACHE_KEY' ALL_CACHE_KEY = 'ALL_FLAGS_CACHE_KEY' + CREATE_MISSING_SETTING = 'CREATE_MISSING_FLAGS' class Meta: abstract = True verbose_name = _('Flag') verbose_name_plural = _('Flags') + @classmethod + def _defaults(cls): + return {'everyone': get_setting('FLAG_DEFAULT')} + def flush(self): cache = get_cache() keys = self.get_flush_keys() @@ -247,16 +259,6 @@ def is_active(self, request): log_level = get_setting('LOG_MISSING_FLAGS') if log_level: logger.log(log_level, 'Flag %s not found', self.name) - if get_setting('CREATE_MISSING_FLAGS'): - flag, _created = get_waffle_flag_model().objects.get_or_create( - name=self.name, - defaults={ - 'everyone': get_setting('FLAG_DEFAULT') - } - ) - cache = get_cache() - cache.set(self._cache_key(self.name), flag) - return get_setting('FLAG_DEFAULT') if get_setting('OVERRIDE'): @@ -443,25 +445,21 @@ class Switch(BaseModel): SINGLE_CACHE_KEY = 'SWITCH_CACHE_KEY' ALL_CACHE_KEY = 'ALL_SWITCHES_CACHE_KEY' + CREATE_MISSING_SETTING = 'CREATE_MISSING_SWITCHES' class Meta: verbose_name = _('Switch') verbose_name_plural = _('Switches') + @classmethod + def _defaults(cls): + return {'active': get_setting('SWITCH_DEFAULT')} + def is_active(self): if not self.pk: log_level = get_setting('LOG_MISSING_SWITCHES') if log_level: logger.log(log_level, 'Switch %s not found', self.name) - if get_setting('CREATE_MISSING_SWITCHES'): - switch, _created = Switch.objects.get_or_create( - name=self.name, - defaults={ - 'active': get_setting('SWITCH_DEFAULT') - } - ) - cache = get_cache() - cache.set(self._cache_key(self.name), switch) return get_setting('SWITCH_DEFAULT') @@ -510,28 +508,22 @@ class Sample(BaseModel): SINGLE_CACHE_KEY = 'SAMPLE_CACHE_KEY' ALL_CACHE_KEY = 'ALL_SAMPLES_CACHE_KEY' + CREATE_MISSING_SETTING = 'CREATE_MISSING_SAMPLES' class Meta: verbose_name = _('Sample') verbose_name_plural = _('Samples') + @classmethod + def _defaults(cls): + default_percent = 100 if get_setting('SAMPLE_DEFAULT') else 0 + return {'percent': default_percent} + def is_active(self): if not self.pk: log_level = get_setting('LOG_MISSING_SAMPLES') if log_level: logger.log(log_level, 'Sample %s not found', self.name) - if get_setting('CREATE_MISSING_SAMPLES'): - - default_percent = 100 if get_setting('SAMPLE_DEFAULT') else 0 - - sample, _created = Sample.objects.get_or_create( - name=self.name, - defaults={ - 'percent': default_percent - } - ) - cache = get_cache() - cache.set(self._cache_key(self.name), sample) return get_setting('SAMPLE_DEFAULT') return Decimal(str(random.uniform(0, 100))) <= self.percent diff --git a/waffle/tests/test_waffle.py b/waffle/tests/test_waffle.py index 1b906eed..684f0809 100644 --- a/waffle/tests/test_waffle.py +++ b/waffle/tests/test_waffle.py @@ -441,6 +441,16 @@ def test_flag_created_dynamically_default_false(self): def test_flag_created_dynamically_default_true(self): self.assert_flag_dynamically_created_with_value(True) + @override_settings(WAFFLE_CREATE_MISSING_FLAGS=True) + @override_settings(WAFFLE_FLAG_DEFAULT=True) + def test_flag_created_dynamically_upon_retrieval(self): + FLAG_NAME = 'myflag' + flag_model = waffle.get_waffle_flag_model() + flag = flag_model.get(FLAG_NAME) + + assert flag.is_active(get()) + assert flag_model.objects.filter(name=FLAG_NAME).exists() + @mock.patch('waffle.models.logger') def test_no_logging_missing_flag_by_default(self, mock_logger): request = get()