-
Notifications
You must be signed in to change notification settings - Fork 44
/
atproto_firehose.py
368 lines (306 loc) · 13.7 KB
/
atproto_firehose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
"""ATProto firehose client. Enqueues receive tasks for events for bridged users."""
from collections import namedtuple
from datetime import datetime, timedelta
from io import BytesIO
import itertools
import logging
import os
from queue import Queue
from threading import Event, Lock, Thread, Timer
import threading
import time
from arroba.datastore_storage import AtpRepo
from arroba.util import parse_at_uri
import dag_cbor
import dag_json
from google.cloud import ndb
from google.cloud.ndb.exceptions import ContextError
from granary.bluesky import AT_URI_PATTERN
from lexrpc.client import Client
import libipld
from oauth_dropins.webutil import util
from oauth_dropins.webutil.appengine_config import ndb_client
from oauth_dropins.webutil.appengine_info import DEBUG
from oauth_dropins.webutil.util import json_dumps, json_loads
from atproto import ATProto, Cursor
from common import (
cache_policy,
create_task,
global_cache,
global_cache_policy,
global_cache_timeout_policy,
NDB_CONTEXT_KWARGS,
PROTOCOL_DOMAINS,
report_error,
report_exception,
USER_AGENT,
)
from protocol import DELETE_TASK_DELAY
from web import Web
logger = logging.getLogger(__name__)

# how long subscriber() sleeps after subscribe() exits before reconnecting
RECONNECT_DELAY = timedelta(seconds=30)
# minimum interval between stores of the firehose cursor in subscribe(). also
# used as the interval at which _load_dids() reschedules itself.
STORE_CURSOR_FREQ = timedelta(seconds=10)

# a commit operation. similar to arroba.repo.Write. record is None for deletes.
# subscribe() also emits Ops with action='account' for #account and #identity
# events; those only carry repo, seq, and time.
Op = namedtuple('Op', ['action', 'repo', 'path', 'seq', 'record', 'time'],
                # last four fields are optional
                defaults=[None, None, None, None])

# contains Ops. subscribe() puts them, handle() gets them.
#
# maxsize is important here! if we hit this limit, subscribe will block when it
# tries to add more commits until handle consumes some. this keeps subscribe
# from getting too far ahead of handle and using too much memory in this queue.
commits = Queue(maxsize=1000)

# global so that subscribe can reuse it across calls. subscribe() lazily loads
# it with Cursor.get_or_insert() on first use.
cursor = None

# global: _load_dids populates them, subscribe and handle use them
#
# atproto_dids: ids of ATProto users with status None and enabled_protocols set
# atproto_loaded_at: high-water mark of ATProto.updated for incremental loads.
#   the 1900 initial value makes the first load query everything.
# bridged_dids: ids of AtpRepos with status None
# bridged_loaded_at: high-water mark of AtpRepo.created, same idea as above
# protocol_bot_dids: ATProto DIDs of the Web bot users in PROTOCOL_DOMAINS
# dids_initialized: set by _load_dids after each load; load_dids waits on it
atproto_dids = set()
atproto_loaded_at = datetime(1900, 1, 1)
bridged_dids = set()
bridged_loaded_at = datetime(1900, 1, 1)
protocol_bot_dids = set()
dids_initialized = Event()
def load_dids():
    """Runs :func:`_load_dids` and blocks until it has populated the DID sets.

    The load happens in a fresh thread because :func:`_load_dids` opens its own
    NDB context, which it also needs when it later reruns itself from a
    :class:`Timer` thread. The :data:`dids_initialized` event is reset
    afterward so that it can signal again.
    """
    loader = Thread(target=_load_dids)
    loader.start()

    # _load_dids sets this once the DID sets have been updated
    dids_initialized.wait()
    dids_initialized.clear()
def _load_dids():
    """Incrementally loads new DIDs into the global DID sets.

    Queries only ATProto users updated, and AtpRepos created, since the previous
    load, and adds their ids to :data:`atproto_dids` / :data:`bridged_dids`.
    Also loads the protocol bot users' ATProto DIDs, until at least one is
    found. Sets :data:`dids_initialized` when done.

    Opens its own NDB context. Unless ``DEBUG`` is set, it first schedules
    itself to run again on a new :class:`Timer` thread after
    ``STORE_CURSOR_FREQ``.
    """
    global atproto_dids, atproto_loaded_at, bridged_dids, bridged_loaded_at

    with ndb_client.context(**NDB_CONTEXT_KWARGS):
        if not DEBUG:
            Timer(STORE_CURSOR_FREQ.total_seconds(), _load_dids).start()

        # ATProto users with no status and enabled_protocols set
        atproto_query = ATProto.query(ATProto.status == None,
                                      ATProto.enabled_protocols != None,
                                      ATProto.updated > atproto_loaded_at)
        # read the new high-water mark *before* running the query above, so that
        # anything updated in between gets re-queried next time, not missed.
        # .get() returns None when there are no entities at all (eg a fresh
        # datastore); keep the previous timestamp then instead of crashing.
        latest_atproto = ATProto.query().order(-ATProto.updated).get()
        loaded_at = (latest_atproto.updated if latest_atproto
                     else atproto_loaded_at)
        new_atproto = [key.id() for key in atproto_query.iter(keys_only=True)]
        atproto_dids.update(new_atproto)
        # set *after* we populate atproto_dids so that if we crash earlier, we
        # re-query from the earlier timestamp
        atproto_loaded_at = loaded_at

        # AtpRepos with no status, same incremental approach keyed on created
        bridged_query = AtpRepo.query(AtpRepo.status == None,
                                      AtpRepo.created > bridged_loaded_at)
        latest_repo = AtpRepo.query().order(-AtpRepo.created).get()
        loaded_at = latest_repo.created if latest_repo else bridged_loaded_at
        new_bridged = [key.id() for key in bridged_query.iter(keys_only=True)]
        bridged_dids.update(new_bridged)
        # set *after* we populate bridged_dids so that if we crash earlier, we
        # re-query from the earlier timestamp
        bridged_loaded_at = loaded_at

        # protocol bot users are Web users keyed by domain; keep retrying on
        # every run until at least one of them has an ATProto copy
        if not protocol_bot_dids:
            bot_keys = [Web(id=domain).key for domain in PROTOCOL_DOMAINS]
            for bot in ndb.get_multi(bot_keys):
                if bot:
                    if did := bot.get_copy(ATProto):
                        logger.info(f'Loaded protocol bot user {bot.key.id()} {did}')
                        protocol_bot_dids.add(did)

        dids_initialized.set()

        total = len(atproto_dids) + len(bridged_dids)
        logger.info(f'DIDs: {total} ATProto {len(atproto_dids)} (+{len(new_atproto)}), AtpRepo {len(bridged_dids)} (+{len(new_bridged)}); commits {commits.qsize()}')
def subscriber():
    """Runs :func:`subscribe` forever, reconnecting whenever it exits.

    Loads the DID sets first. Exceptions from :func:`subscribe` are reported
    and swallowed. Every disconnect, whether clean or not, is followed by a
    ``RECONNECT_DELAY`` pause before the next attempt.
    """
    relay = os.environ["BGS_HOST"]
    logger.info(f'started thread to subscribe to {relay} firehose')
    load_dids()

    pause_s = RECONNECT_DELAY.total_seconds()
    with ndb_client.context(**NDB_CONTEXT_KWARGS):
        while True:
            try:
                subscribe()
            except BaseException:
                report_exception()

            logger.info(f'disconnected! waiting {RECONNECT_DELAY} and then reconnecting')
            time.sleep(pause_s)
def _is_ours(ref, also_atproto_users=False):
    """Returns True if a record reference points at a bridged user's record.

    Args:
      ref (dict): record reference whose ``uri`` field is matched against
        ``AT_URI_PATTERN``. Anything malformed (not a dict, missing or
        non-string ``uri``, non-matching URI) returns False.
      also_atproto_users (bool): whether records owned by DIDs in
        :data:`atproto_dids` count too, not just :data:`bridged_dids`

    Returns:
      bool:
    """
    uri = ref.get('uri') if isinstance(ref, dict) else None
    if not isinstance(uri, str):
        return False

    match = AT_URI_PATTERN.match(uri)
    if not match:
        return False

    did = match.group('repo')
    return bool(did) and (did in bridged_dids
                          or (also_atproto_users and did in atproto_dids))


def _wants_record(op, record_type):
    """Returns True if a create/update record should be passed to handle().

    Args:
      op (Op): with ``record`` populated as a dict
      record_type (str): the record's ``$type``

    Returns:
      bool:
    """
    record = op.record
    subject = record.get('subject')

    # generally we only want records from bridged Bluesky users. the one
    # exception is follows of protocol bot users. records come from the whole
    # network, so don't assume subject is present or well-formed.
    if op.repo not in atproto_dids and not (
            record_type == 'app.bsky.graph.follow'
            and isinstance(subject, str)
            and subject in protocol_bot_dids):
        return False

    if record_type == 'app.bsky.feed.repost':
        return _is_ours(subject, also_atproto_users=True)
    elif record_type == 'app.bsky.feed.like':
        return _is_ours(subject, also_atproto_users=False)
    elif record_type in ('app.bsky.graph.block', 'app.bsky.graph.follow'):
        return isinstance(subject, str) and subject in bridged_dids
    elif record_type == 'app.bsky.feed.post':
        if reply := record.get('reply'):
            parent = reply.get('parent') if isinstance(reply, dict) else None
            return _is_ours(parent, also_atproto_users=True)

    return True


def subscribe():
    """Subscribes to the relay's firehose and queues events we care about.

    Relay hostname comes from the ``BGS_HOST`` environment variable. Resumes
    from the stored :class:`Cursor` for that relay. Stores the cursor back at
    most every ``STORE_CURSOR_FREQ``. Puts :class:`Op` tuples for relevant
    events onto :data:`commits` for :func:`handle`; this blocks when the queue
    is full.

    Returns if the subscription iterator ends. Exceptions propagate to the
    caller; :func:`subscriber` reports them and reconnects.
    """
    global cursor

    host = os.environ['BGS_HOST']
    if not cursor:
        cursor = Cursor.get_or_insert(f'{host} com.atproto.sync.subscribeRepos')
        # TODO: remove? does this make us skip events? if we remove it, will we
        # infinite loop when we fail on an event?
        if cursor.cursor:
            cursor.cursor += 1

    last_stored_cursor = cur_timestamp = None

    client = Client(f'https://{host}', headers={'User-Agent': USER_AGENT})
    for frame in client.com.atproto.sync.subscribeRepos(decode=False,
                                                         cursor=cursor.cursor):
        # parse header
        header = libipld.decode_dag_cbor(frame)
        if header.get('op') == -1:
            _, payload = libipld.decode_dag_cbor_multi(frame)
            logger.warning(f'Got error from relay! {payload}')
            continue

        t = header.get('t')
        if t not in ('#commit', '#account', '#identity'):
            if t not in ('#handle', '#tombstone'):
                logger.info(f'Got {t} from relay')
            continue

        # parse payload
        _, payload = libipld.decode_dag_cbor_multi(frame)
        repo = payload.get('repo') or payload.get('did')
        if not repo:
            logger.warning(f'Payload missing repo! {payload}')
            continue

        seq = payload.get('seq')
        if not seq:
            logger.warning(f'Payload missing seq! {payload}')
            continue

        cur_timestamp = payload['time']

        # if we fail processing this commit and raise an exception up to subscriber,
        # skip it and start with the next commit when we're restarted
        cursor.cursor = seq + 1

        elapsed = util.now().replace(tzinfo=None) - cursor.updated
        if elapsed > STORE_CURSOR_FREQ:
            events_s = 0
            if last_stored_cursor:
                events_s = int((cursor.cursor - last_stored_cursor) /
                               elapsed.total_seconds())
            last_stored_cursor = cursor.cursor

            behind = util.now() - util.parse_iso8601(cur_timestamp)

            # it's been long enough, update our stored cursor and metrics
            logger.info(f'updating stored cursor to {cursor.cursor}, {events_s} events/s, {behind} ({int(behind.total_seconds())} s) behind')
            cursor.put()
            # when running locally, comment out put above and uncomment this
            # cursor.updated = util.now().replace(tzinfo=None)

        if t in ('#account', '#identity'):
            if repo in atproto_dids or repo in bridged_dids:
                logger.debug(f'Got {t[1:]} {repo}')
                commits.put(Op(action='account', repo=repo, seq=seq,
                               time=cur_timestamp))
            continue

        blocks = {}  # maps base32 str CID to dict block
        if block_bytes := payload.get('blocks'):
            _, blocks = libipld.decode_car(block_bytes)

        # detect records from bridged ATProto users that we should handle
        for p_op in payload.get('ops') or []:
            op = Op(repo=repo, action=p_op.get('action'),
                    path=p_op.get('path'), seq=seq, time=cur_timestamp)
            if not op.action or not op.path:
                logger.info(
                    f'bad payload! seq {op.seq} action {op.action} path {op.path}!')
                continue

            if op.repo in atproto_dids and op.action == 'delete':
                logger.debug(f'Got delete from our ATProto user: {op}')
                # TODO: also detect deletes of records that *reference* our bridged
                # users, eg a delete of a follow or like or repost of them.
                # not easy because we need to getRecord the record to check
                commits.put(op)
                continue

            cid = p_op.get('cid')
            block = blocks.get(cid)
            # our own commits are sometimes missing the record
            # https://github.com/snarfed/bridgy-fed/issues/1016
            if not cid or not block or not isinstance(block, dict):
                continue

            op = op._replace(record=block)
            record_type = op.record.get('$type')
            if not record_type:
                logger.warning(f'commit record missing $type! {op.action} {op.repo} {op.path} {cid}')
                logger.warning(dag_json.encode(op.record).decode())
                continue
            elif record_type not in ATProto.SUPPORTED_RECORD_TYPES:
                continue

            if not _wants_record(op, record_type):
                continue

            logger.debug(f'Got {op.action} {op.repo} {op.path}')
            commits.put(op)
def handler():
    """Runs :func:`handle` in a loop, restarting it after any exception.

    Each attempt gets a brand-new NDB context. Exceptions are reported and then
    retried. A clean return from :func:`handle` ends the loop.
    """
    logger.info(f'started handle thread to store objects and enqueue receive tasks')

    while True:
        with ndb_client.context(**NDB_CONTEXT_KWARGS):
            try:
                handle()
            except BaseException:
                report_exception()
                # loop around to create a new ndb context in case this is
                # a ContextError
                # https://console.cloud.google.com/errors/detail/CIvwj_7MmsfOWw;time=P1D;locations=global?project=bridgy-federated
            else:
                # a clean return means handle() hit its limit
                break
def handle(limit=None):
def _handle_account(op):
# reload DID doc to fetch new changes
ATProto.load(op.repo, did_doc=True, remote=True)
def _handle(op):
at_uri = f'at://{op.repo}/{op.path}'
type, _ = op.path.strip('/').split('/', maxsplit=1)
if type not in ATProto.SUPPORTED_RECORD_TYPES:
logger.info(f'Skipping unsupported type {type}: {at_uri}')
return
# store object, enqueue receive task
verb = None
if op.action in ('create', 'update'):
record_kwarg = {
'bsky': op.record,
}
obj_id = at_uri
elif op.action == 'delete':
verb = (
'delete' if type in ('app.bsky.actor.profile', 'app.bsky.feed.post')
else 'stop-following' if type == 'app.bsky.graph.follow'
else 'undo')
obj_id = f'{at_uri}#{verb}'
record_kwarg = {
'our_as1': {
'objectType': 'activity',
'verb': verb,
'id': obj_id,
'actor': op.repo,
'object': at_uri,
},
}
else:
logger.error(f'Unknown action {action} for {op.repo} {op.path}')
return
if verb and verb not in ATProto.SUPPORTED_AS1_TYPES:
return
delay = DELETE_TASK_DELAY if op.action == 'delete' else None
try:
create_task(queue='receive', id=obj_id, source_protocol=ATProto.LABEL,
authed_as=op.repo, received_at=op.time, delay=delay,
**record_kwarg)
# when running locally, comment out above and uncomment this
# logger.info(f'enqueuing receive task for {at_uri}')
except ContextError:
raise # handled in handle()
except BaseException:
report_error(obj_id, exception=True)
seen = 0
while op := commits.get():
match op.action:
case 'account':
_handle_account(op)
case _:
_handle(op)
seen += 1
if limit is not None and seen >= limit:
return
assert False, "handle thread shouldn't reach here!"