diff --git a/corehq/motech/repeaters/models.py b/corehq/motech/repeaters/models.py index de79fdb559bf..56860797d8c2 100644 --- a/corehq/motech/repeaters/models.py +++ b/corehq/motech/repeaters/models.py @@ -618,6 +618,10 @@ def form_class_name(self): """ return self._repeater_type + @property + def can_merge_records(self): + return hasattr(self, 'merge_records') + class FormRepeater(Repeater): @@ -722,6 +726,46 @@ def get_headers(self, repeat_record): }) return headers + # TODO: Test + def merge_records(self): + """ + ``CaseRepeater`` and ``UpdateCaseRepeater`` forward a case as it + is at the time of sending (not as it is at the time that the + repeat record was registered). This method merges repeat records + to send each case only once and cancel any duplicate repeat + records for the same case. + """ + # Get only the payload IDs to be sent + payload_ids = { + record.payload_id + for record in self.repeat_records_ready[:self.num_workers] + } + payload_records = ( + self.repeat_records_ready + .filter(payload_id__in=payload_ids) + ) + if len(payload_records) == len(payload_ids): + # There are no duplicates + return + + records_by_payload_id = defaultdict(list) + for record in payload_records: + records_by_payload_id[record.payload_id].append(record) + + for payload_id, records in records_by_payload_id.items(): + if len(records) > 1: + new_record = RepeatRecord( + repeater_id=self.id, + domain=self.domain, + registered_at=records[0].registered_at, + next_check=records[0].next_check, + payload_id=payload_id, + ) + new_record.save() + for old_record in records: + old_record.cancel() + old_record.save() + class CreateCaseRepeater(CaseRepeater): class Meta: @@ -733,6 +777,11 @@ def allowed_to_forward(self, payload): # assume if there's exactly 1 xform_id that modified the case it's being created return super().allowed_to_forward(payload) and len(payload.xform_ids) == 1 + @property + def can_merge_records(self): + # CreateCaseRepeater will not have duplicate repeat records + return False + class UpdateCaseRepeater(CaseRepeater): """ diff --git a/corehq/motech/repeaters/tasks.py b/corehq/motech/repeaters/tasks.py index b7020c173b12..3ef5f6d251ac 100644 --- a/corehq/motech/repeaters/tasks.py +++ b/corehq/motech/repeaters/tasks.py @@ -362,7 +362,7 @@ def process_repeater(repeater, lock_token): """ Initiates a Celery task to process a repeater. """ - if hasattr(repeater, 'merge_records'): + if repeater.can_merge_records: _merge_records_and_process_repeater.delay(repeater.repeater_id, lock_token) else: _process_repeater(repeater, lock_token)