Skip to content

Commit

Permalink
Merge pull request #33807 from dimagi/mjr/add-change-meta-context
Browse files Browse the repository at this point in the history
  • Loading branch information
mjriley authored Dec 5, 2023
2 parents 1df6804 + 2f6871b commit 416e7f1
Show file tree
Hide file tree
Showing 13 changed files with 200 additions and 29 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@ on:
branches:
- master
- hotfix-deploy
- ap/sql-repeater/phase-2
- mjr/erm-update-rules
- mjr/erm-fixtures
- mjr/erm-custom-roles
- mjr/erm-roles
- mjr/add-change-meta-context
schedule:
# see corehq/apps/hqadmin/management/commands/static_analysis.py
- cron: '47 12 * * *'
Expand Down
24 changes: 22 additions & 2 deletions corehq/apps/data_interfaces/pillow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

from corehq.apps.data_interfaces.deduplication import is_dedupe_xmlns
from corehq.apps.data_interfaces.models import AutomaticUpdateRule
from corehq.form_processor.exceptions import XFormNotFound
from corehq.form_processor.models import CommCareCase
from corehq.form_processor.models.forms import XFormInstance
from corehq.toggles import CASE_DEDUPE
from corehq.util.soft_assert import soft_assert


class CaseDeduplicationProcessor(PillowProcessor):
Expand All @@ -29,14 +32,15 @@ def process_change(self, change):
if change.deleted:
return

if is_dedupe_xmlns(change.get_document().get('xmlns')):
associated_form = self._get_associated_form(change)
if not associated_form or is_dedupe_xmlns(associated_form.xmlns):
return

rules = self._get_rules(domain)
if not rules:
return

for case_update in get_case_updates(change.get_document()):
for case_update in get_case_updates(associated_form, for_case=change.id):
self._process_case_update(domain, case_update)

def _get_rules(self, domain):
Expand All @@ -54,3 +58,19 @@ def _process_action(self, domain, rule, action, changed_properties, case_id):
case = CommCareCase.objects.get_case(case_id, domain)
if case.type == rule.case_type:
rule.run_rule(case, datetime.utcnow())

def _get_associated_form(self, change):
associated_form_id = change.metadata.associated_document_id
associated_form = None
if associated_form_id:
try:
associated_form = XFormInstance.objects.get_form(associated_form_id)
except XFormNotFound:
_assert = soft_assert(['mriley_at_dimagi_dot_com'.replace('_at_', '@').replace('_dot_', '.')])
_assert(False, 'Associated form not found', {
'case_id': change.id,
'form_id': associated_form_id
})
associated_form = None

return associated_form
8 changes: 4 additions & 4 deletions corehq/apps/data_interfaces/tests/test_case_deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from corehq.apps.users.models import CommCareUser
from corehq.apps.users.tasks import tag_cases_as_deleted_and_remove_indices
from corehq.form_processor.models import CommCareCase, XFormInstance
from corehq.pillows.xform import get_xform_pillow
from corehq.pillows.case import get_case_pillow
from corehq.util.test_utils import flag_enabled, set_parent_case


Expand Down Expand Up @@ -714,10 +714,10 @@ def setUpClass(cls):
cls.domain = 'naboo'
cls.case_type = 'people'
cls.factory = CaseFactory(cls.domain)
cls.pillow = get_xform_pillow(skip_ucr=True)
cls.pillow = get_case_pillow(skip_ucr=True)

def setUp(self):
self.kafka_offset = get_topic_offset(topics.FORM_SQL)
self.kafka_offset = get_topic_offset(topics.CASE_SQL)

@patch("corehq.apps.data_interfaces.models.find_duplicate_case_ids")
def test_pillow_processes_changes(self, find_duplicate_cases_mock):
Expand All @@ -729,7 +729,7 @@ def test_pillow_processes_changes(self, find_duplicate_cases_mock):

find_duplicate_cases_mock.return_value = [case1.case_id, case2.case_id]

new_kafka_sec = get_topic_offset(topics.FORM_SQL)
new_kafka_sec = get_topic_offset(topics.CASE_SQL)
self.pillow.process_changes(since=self.kafka_offset, forever=False)

self._assert_case_duplicate_pair(case1.case_id, [case2.case_id])
Expand Down
2 changes: 1 addition & 1 deletion corehq/apps/hqcase/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def bulk_update_cases(domain, case_changes, device_id, xmlns=None):

def resave_case(domain, case, send_post_save_signal=True):
from corehq.form_processor.change_publishers import publish_case_saved
publish_case_saved(case, send_post_save_signal)
publish_case_saved(case, send_post_save_signal=send_post_save_signal)


def get_last_non_blank_value(case, case_property):
Expand Down
82 changes: 82 additions & 0 deletions corehq/ex-submodules/casexml/apps/case/tests/test_xform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from django.test import SimpleTestCase
from casexml.apps.case.xml import V2
from casexml.apps.case.xform import get_case_updates
from casexml.apps.case.xml.parser import CaseUpdate


class TestGetCaseUpdates(SimpleTestCase):
default_case_id = '1111'
default_user_id = '2222'
default_modified_time = '2023-11-28T15:26:55.859000Z'

def test_processes_single_case(self):
case_block = self._create_case_block(
case_id='case1',
user_id='abc',
modified_on='2023-11-28T15:26:55.859000Z',
create_block={'case_name': 'test', 'case_type': 'test_type'}
)
xform = {
'case': case_block
}

updates = get_case_updates(xform)
expected_case = self._create_case_update(
case_id='case1',
user_id='abc',
modified_on='2023-11-28T15:26:55.859000Z',
create_block={'case_name': 'test', 'case_type': 'test_type'}
)
self.assertEqual(expected_case, updates[0])

def test_processes_sub_case(self):
case1 = self._create_case_block(case_id='1')
case2 = self._create_case_block(case_id='2')
xform = {
'case': case1,
'sub_case': {
'case': case2
}
}

updates = get_case_updates(xform)
self.assertEqual(updates, [self._create_case_update(case_id='1'), self._create_case_update(case_id='2')])

def test_can_restrict_by_id(self):
case1 = self._create_case_block(case_id='1')
case2 = self._create_case_block(case_id='2')
xform = {
'case': case1,
'sub_case': {
'case': case2
}
}

updates = get_case_updates(xform, for_case='1')
self.assertEqual(updates, [self._create_case_update(case_id='1')])

def _create_case_block(
self, case_id=None, user_id=None, modified_on=None, create_block=None, update_block=None):
block = {
'@case_id': case_id or self.default_case_id,
'@date_modified': modified_on or self.default_modified_time,
'@user_id': user_id or self.default_user_id,
'@xmlns': 'http://commcarehq.org/case/transaction/v2',
}

if create_block:
block['create'] = create_block

if update_block:
block['update'] = update_block

return block

def _create_case_update(
self, case_id=None, user_id=None, modified_on=None, create_block=None, update_block=None):
block = self._create_case_block(case_id, user_id, modified_on, create_block, update_block)

return CaseUpdate(
case_id or self.default_case_id, V2, block,
user_id=(user_id or self.default_user_id),
modified_on_str=modified_on or self.default_modified_time)
42 changes: 41 additions & 1 deletion corehq/ex-submodules/casexml/apps/case/tests/xml/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from casexml.apps.case.xml.parser import CaseUpdate
from casexml.apps.case.xml.parser import CaseUpdate, CaseCreateAction
from casexml.apps.case.xml import V2
from django.test import SimpleTestCase

Expand Down Expand Up @@ -40,6 +40,29 @@ def test_get_normalized_updates(self):
self.assertEqual(case_update.get_normalized_update_property_names(),
{'name', 'owner_id', 'type'})

def test_equality(self):
create_block = {
'case_name': 'test_case',
'owner_id': '12345',
'case_type': 'test_case_type'
}
case_block = self._create_case_block(create_block)

case_update_1 = CaseUpdate('case_id', V2, case_block)
case_update_2 = CaseUpdate('case_id', V2, case_block)
self.assertEqual(case_update_1, case_update_2)

def test_non_equality(self):
create_block = {
'case_name': 'test_case',
'owner_id': '12345',
'case_type': 'test_case_type'
}
case_block = self._create_case_block(create_block)
case_update_1 = CaseUpdate('case_id', V2, case_block)
case_update_2 = CaseUpdate('case_id2', V2, case_block)
self.assertNotEqual(case_update_1, case_update_2)

def _create_case_block(self, create_block=None, update_block=None):
block = {
'@case_id': '1111',
Expand All @@ -55,3 +78,20 @@ def _create_case_block(self, create_block=None, update_block=None):
block['update'] = update_block

return block


class CaseActionTests(SimpleTestCase):
def test_equality(self):
block = {
'case_name': 'test'
}
action1 = CaseCreateAction(block)
action2 = CaseCreateAction(block)

self.assertEqual(action1, action2)

def test_non_equality(self):
action1 = CaseCreateAction({'case_name': 'one'})
action2 = CaseCreateAction({'case_name': 'two'})

self.assertNotEqual(action1, action2)
20 changes: 14 additions & 6 deletions corehq/ex-submodules/casexml/apps/case/xform.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,22 @@ def _extract_case_blocks(data, path=None, form_id=Ellipsis):
yield from _extract_case_blocks(value, new_path, form_id=form_id)


def get_case_updates(xform):
def get_case_updates(xform, for_case=None):
if not xform:
return []
updates = sorted(
[case_update_from_block(cb) for cb in extract_case_blocks(xform)],
key=lambda update: update.id
)
by_case_id = groupby(updates, lambda update: update.id)

updates = [case_update_from_block(cb) for cb in extract_case_blocks(xform)]

if for_case:
updates = [update for update in updates if update.id == for_case]
by_case_id = [(for_case, updates)]
else:
updates = sorted(
updates,
key=lambda update: update.id
)
by_case_id = groupby(updates, lambda update: update.id)

return list(itertools.chain(
*[order_updates(updates) for case_id, updates in by_case_id]
))
Expand Down
17 changes: 17 additions & 0 deletions corehq/ex-submodules/casexml/apps/case/xml/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ def get_known_properties(self):
def __repr__(self):
return f"{type(self).__name__}(block={self.raw_block!r})"

def __eq__(self, other):
return (isinstance(other, self.__class__)
and self.__dict__ == other.__dict__)

def __ne__(self, other):
return not self.__eq__(other)

@classmethod
def _from_block_and_mapping(cls, block, mapping):
def _normalize(val):
Expand Down Expand Up @@ -345,6 +352,16 @@ def has_attachments(self):
def __str__(self):
return "%s: %s" % (self.version, self.id)

def __repr__(self):
return str(self.__dict__)

def __eq__(self, other):
return (isinstance(other, self.__class__)
and self.__dict__ == other.__dict__)

def __ne__(self, other):
return not self.__eq__(other)

def _filtered_action(self, func):
# filters the actions, assumes exactly 0 or 1 match.
filtered = list(filter(func, self.actions))
Expand Down
9 changes: 6 additions & 3 deletions corehq/ex-submodules/pillowtop/feed/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class ChangeMeta(jsonobject.JsonObject):
# track when first published (will not get updated on retry, unlike publish_timestamp)
original_publication_datetime = jsonobject.DateTimeProperty(default=datetime.utcnow)

# available to hold any associated document. For cases, this is the form ID responsible for the change
associated_document_id = jsonobject.StringProperty()


class Change(object):
"""
Expand Down Expand Up @@ -121,10 +124,10 @@ def __getitem__(self, key):
return self._dict[key]

def __setitem__(self, key, value):
raise NotImplemented('This is a read-only dictionary!')
raise NotImplementedError('This is a read-only dictionary!')

def __delitem__(self, key, value):
raise NotImplemented('This is a read-only dictionary!')
raise NotImplementedError('This is a read-only dictionary!')

def __iter__(self):
return iter(self._dict)
Expand All @@ -136,7 +139,7 @@ def get(self, key, default=None):
return self._dict.get(key, default)

def pop(self, key, default):
raise NotImplemented('This is a read-only dictionary!')
raise NotImplementedError('This is a read-only dictionary!')

def to_dict(self):
return self._dict
Expand Down
6 changes: 5 additions & 1 deletion corehq/form_processor/backends/sql/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,11 @@ def publish_changes_to_kafka(processed_forms, cases, stock_result):
publish_form_saved(processed_forms.submitted)
cases = cases or []
for case in cases:
publish_case_saved(case, send_post_save_signal=False)
publish_case_saved(
case,
associated_form_id=processed_forms.submitted.form_id,
send_post_save_signal=False
)

if stock_result:
for ledger in stock_result.models_to_save:
Expand Down
7 changes: 4 additions & 3 deletions corehq/form_processor/change_publishers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ def publish_form_deleted(domain, form_id):
))


def publish_case_saved(case, send_post_save_signal=True):
def publish_case_saved(case, associated_form_id=None, send_post_save_signal=True):
"""
Publish the change to kafka and run case post-save signals.
"""
producer.send_change(topics.CASE_SQL, change_meta_from_sql_case(case))
producer.send_change(topics.CASE_SQL, change_meta_from_sql_case(case, associated_form_id))
if send_post_save_signal:
sql_case_post_save.send(case.__class__, case=case)


def change_meta_from_sql_case(case):
def change_meta_from_sql_case(case, associated_form_id=None):
return ChangeMeta(
document_id=case.case_id,
data_source_type=data_sources.SOURCE_SQL,
Expand All @@ -49,6 +49,7 @@ def change_meta_from_sql_case(case):
document_subtype=case.type,
domain=case.domain,
is_deletion=case.is_deleted,
associated_document_id=associated_form_id
)


Expand Down
3 changes: 3 additions & 0 deletions corehq/pillows/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_ucr_processor,
)
from corehq.form_processor.backends.sql.dbaccessors import CaseReindexAccessor
from corehq.apps.data_interfaces.pillow import CaseDeduplicationProcessor
from corehq.messaging.pillow import CaseMessagingSyncProcessor
from corehq.pillows.base import is_couch_change_for_sql_domain
from corehq.pillows.case_search import get_case_search_processor
Expand Down Expand Up @@ -119,6 +120,8 @@ def get_case_pillow(
processors = [case_to_es_processor, CaseMessagingSyncProcessor()]
if settings.RUN_CASE_SEARCH_PILLOW:
processors.append(case_search_processor)
if settings.RUN_DEDUPLICATION_PILLOW:
processors.append(CaseDeduplicationProcessor())
if not skip_ucr:
# this option is useful in tests to avoid extra UCR setup where unneccessary
processors = [ucr_processor, ucr_dr_processor] + processors
Expand Down
Loading

0 comments on commit 416e7f1

Please sign in to comment.