Add com.beeper.hs.order to initial and incremental sync responses.
adamvy committed May 22, 2024
1 parent 63b773d commit 2f7e55c
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions synapse/storage/databases/main/receipts.py
@@ -416,13 +416,13 @@ async def _get_linearized_receipts_for_rooms(
 
         def f(
             txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
+        ) -> List[Tuple[str, str, str, str, Optional[str], str, str]]:
             if from_key:
                 sql = """
                     SELECT stream_id, instance_name, room_id, receipt_type,
-                        user_id, event_id, thread_id, data
-                    FROM receipts_linearized WHERE
-                        stream_id > ? AND stream_id <= ? AND
+                        user_id, event_id, thread_id, event_stream_ordering, data
+                    FROM receipts_linearized
+                    WHERE stream_id > ? AND stream_id <= ? AND
                 """
                 clause, args = make_in_list_sql_clause(
                     self.database_engine, "room_id", room_ids
@@ -435,9 +435,9 @@ def f(
             else:
                 sql = """
                     SELECT stream_id, instance_name, room_id, receipt_type,
-                        user_id, event_id, thread_id, data
-                    FROM receipts_linearized WHERE
-                        stream_id <= ? AND
+                        user_id, event_id, thread_id, event_stream_ordering, data
+                    FROM receipts_linearized
+                    WHERE stream_id > ? AND stream_id <= ? AND
                 """
 
                 clause, args = make_in_list_sql_clause(
@@ -447,8 +447,8 @@ def f(
                 txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
 
             return [
-                (room_id, receipt_type, user_id, event_id, thread_id, data)
-                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+                (room_id, receipt_type, user_id, event_id, thread_id, event_stream_ordering, data)
+                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, event_stream_ordering, data in txn
                 if MultiWriterStreamToken.is_stream_position_in_range(
                     from_key, to_key, instance_name, stream_id
                 )
@@ -459,7 +459,7 @@ def f(
         )
 
         results: JsonDict = {}
-        for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
+        for room_id, receipt_type, user_id, event_id, thread_id, event_stream_ordering, data in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
             room_event = results.setdefault(
@@ -480,6 +480,10 @@ def f(
             # This means we will drop some receipts, but MSC4102 is designed to drop semantically
             # meaningless receipts, so this is okay. Previously, we would drop meaningful data!
             receipt_data = db_to_json(data)
+
+            # MSC4033: inject event order into receipt
+            receipt_data["com.beeper.hs.order"] = event_stream_ordering
+
             if user_id in receipt_type_dict: # existing receipt
                 # is the existing receipt threaded and we are currently processing an unthreaded one?
                 if "thread_id" in receipt_type_dict[user_id] and not thread_id:
@@ -517,19 +521,19 @@ async def get_linearized_receipts_for_all_rooms(
             A dictionary of roomids to a list of receipts.
         """
 
-        def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
+        def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str, str]]:
             if from_key:
                 sql = """
-                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
-                    FROM receipts_linearized WHERE
-                        stream_id > ? AND stream_id <= ?
+                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, event_stream_ordering, data
+                    FROM receipts_linearized
+                    WHERE stream_id > ? AND stream_id <= ?
                     ORDER BY stream_id DESC
                     LIMIT 100
                 """
                 txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
             else:
                 sql = """
-                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, event_stream_ordering, data
                     FROM receipts_linearized WHERE
                         stream_id <= ?
                     ORDER BY stream_id DESC
@@ -539,8 +543,8 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
                 txn.execute(sql, [to_key.get_max_stream_pos()])
 
             return [
-                (room_id, receipt_type, user_id, event_id, data)
-                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+                (room_id, receipt_type, user_id, event_id, event_stream_ordering, data)
+                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, event_stream_ordering, data in txn
                 if MultiWriterStreamToken.is_stream_position_in_range(
                     from_key, to_key, instance_name, stream_id
                 )
@@ -551,7 +555,7 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
         )
 
         results: JsonDict = {}
-        for room_id, receipt_type, user_id, event_id, data in txn_results:
+        for room_id, receipt_type, user_id, event_id, event_stream_ordering, data in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
@@ -566,6 +570,9 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
 
             receipt_type_dict[user_id] = db_to_json(data)
 
+            # MSC4033: inject event order into receipt
+            receipt_type_dict[user_id]["com.beeper.hs.order"] = event_stream_ordering
+
         return results
 
     async def get_users_sent_receipts_between(
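For context (not part of the commit itself): with this change, each per-user receipt object that Synapse serves in initial and incremental sync responses carries the event's stream ordering under the com.beeper.hs.order key, alongside the usual fields such as ts. The following is a minimal, hypothetical client-side sketch of how that hint could be consumed; the helper name, room and user IDs, and sample values are invented for illustration and do not come from this commit.

from typing import Optional, Tuple

def latest_read_receipt(receipt_edu: dict, user_id: str) -> Optional[Tuple[str, int]]:
    """Return (event_id, order) for the user's best-ordered m.read receipt."""
    best: Optional[Tuple[str, int]] = None
    for event_id, by_type in receipt_edu.get("content", {}).items():
        receipt = by_type.get("m.read", {}).get(user_id)
        if receipt is None:
            continue
        # Prefer the homeserver-provided ordering hint; fall back to the ts field.
        order = receipt.get("com.beeper.hs.order", receipt.get("ts", 0))
        if best is None or order > best[1]:
            best = (event_id, order)
    return best

# Example m.receipt EDU as it might appear in a sync response (values are made up).
receipt_edu = {
    "type": "m.receipt",
    "room_id": "!room:example.org",
    "content": {
        "$event_a": {"m.read": {"@alice:example.org": {"ts": 1000, "com.beeper.hs.order": 41}}},
        "$event_b": {"m.read": {"@alice:example.org": {"ts": 900, "com.beeper.hs.order": 57}}},
    },
}

print(latest_read_receipt(receipt_edu, "@alice:example.org"))  # -> ("$event_b", 57)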
