diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 69dfe4dac8..30068d66ec 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -416,13 +416,13 @@ async def _get_linearized_receipts_for_rooms( def f( txn: LoggingTransaction, - ) -> List[Tuple[str, str, str, str, Optional[str], str]]: + ) -> List[Tuple[str, str, str, str, Optional[str], str, str]]: if from_key: sql = """ SELECT stream_id, instance_name, room_id, receipt_type, - user_id, event_id, thread_id, data - FROM receipts_linearized WHERE - stream_id > ? AND stream_id <= ? AND + user_id, event_id, thread_id, event_stream_ordering, data + FROM receipts_linearized + WHERE stream_id > ? AND stream_id <= ? AND """ clause, args = make_in_list_sql_clause( self.database_engine, "room_id", room_ids @@ -435,9 +435,9 @@ def f( else: sql = """ SELECT stream_id, instance_name, room_id, receipt_type, - user_id, event_id, thread_id, data - FROM receipts_linearized WHERE - stream_id <= ? AND + user_id, event_id, thread_id, event_stream_ordering, data + FROM receipts_linearized + WHERE stream_id > ? AND stream_id <= ? AND """ clause, args = make_in_list_sql_clause( @@ -447,8 +447,8 @@ def f( txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args)) return [ - (room_id, receipt_type, user_id, event_id, thread_id, data) - for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn + (room_id, receipt_type, user_id, event_id, thread_id, event_stream_ordering, data) + for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, event_stream_ordering, data in txn if MultiWriterStreamToken.is_stream_position_in_range( from_key, to_key, instance_name, stream_id ) @@ -459,7 +459,7 @@ def f( ) results: JsonDict = {} - for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results: + for room_id, receipt_type, user_id, event_id, thread_id, event_stream_ordering, data in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. room_event = results.setdefault( @@ -480,6 +480,10 @@ def f( # This means we will drop some receipts, but MSC4102 is designed to drop semantically # meaningless receipts, so this is okay. Previously, we would drop meaningful data! receipt_data = db_to_json(data) + + # MSC4033: inject event order into receipt + receipt_data["com.beeper.hs.order"] = event_stream_ordering + if user_id in receipt_type_dict: # existing receipt # is the existing receipt threaded and we are currently processing an unthreaded one? if "thread_id" in receipt_type_dict[user_id] and not thread_id: @@ -517,19 +521,19 @@ async def get_linearized_receipts_for_all_rooms( A dictionary of roomids to a list of receipts. """ - def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: + def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str, str]]: if from_key: sql = """ - SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data - FROM receipts_linearized WHERE - stream_id > ? AND stream_id <= ? + SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, event_stream_ordering, data + FROM receipts_linearized + WHERE stream_id > ? AND stream_id <= ? ORDER BY stream_id DESC LIMIT 100 """ txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()]) else: sql = """ - SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data + SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, event_stream_ordering, data FROM receipts_linearized WHERE stream_id <= ? ORDER BY stream_id DESC @@ -539,8 +543,8 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: txn.execute(sql, [to_key.get_max_stream_pos()]) return [ - (room_id, receipt_type, user_id, event_id, data) - for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn + (room_id, receipt_type, user_id, event_id, event_stream_ordering, data) + for stream_id, instance_name, room_id, receipt_type, user_id, event_id, event_stream_ordering, data in txn if MultiWriterStreamToken.is_stream_position_in_range( from_key, to_key, instance_name, stream_id ) @@ -551,7 +555,7 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: ) results: JsonDict = {} - for room_id, receipt_type, user_id, event_id, data in txn_results: + for room_id, receipt_type, user_id, event_id, event_stream_ordering, data in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. room_event = results.setdefault( @@ -566,6 +570,9 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: receipt_type_dict[user_id] = db_to_json(data) + # MSC4033: inject event order into receipt + receipt_type_dict[user_id]["com.beeper.hs.order"] = event_stream_ordering + return results async def get_users_sent_receipts_between(