From 99bfa0d96455d28ec86329e4d5ca1167df32beb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arturo=20Filast=C3=B2?= Date: Fri, 6 Sep 2024 14:28:00 -0400 Subject: [PATCH] Refactoring of observation related activities (#86) * Add support for performing observation generation using multiple cores, instead of multiple threads since it's CPU bound * Separate observation activities into distinct smaller activities allowing for more narrowly scoped scheduling and retry policies * Change the type of PrevRange so that it's possible to serialize it in JSON allowing to pass it as a parameter to activities * Move update_assets into observation activity * Add support for passing config file via `CONFIG_FILE` environment variable * Improvements to the CLI commands * Drop several CLI arguments that should only be read from the config file * Other improvements related to typing --- oonidata/src/oonidata/dataclient.py | 8 +- oonidata/src/oonidata/s3client.py | 328 -------------- oonipipeline/pyproject.toml | 2 +- .../src/oonipipeline/analysis/control.py | 1 + oonipipeline/src/oonipipeline/cli/commands.py | 421 ++++-------------- oonipipeline/src/oonipipeline/settings.py | 48 ++ .../temporal/activities/analysis.py | 5 +- .../temporal/activities/common.py | 39 +- .../temporal/activities/ground_truths.py | 5 +- .../temporal/activities/observations.py | 352 +++++++++------ .../temporal/client_operations.py | 139 +++--- .../src/oonipipeline/temporal/common.py | 44 +- .../src/oonipipeline/temporal/schedules.py | 245 ++++++++++ .../src/oonipipeline/temporal/workers.py | 49 +- .../src/oonipipeline/temporal/workflows.py | 395 ---------------- .../temporal/workflows/__init__.py | 0 .../temporal/workflows/analysis.py | 118 +++++ .../oonipipeline/temporal/workflows/common.py | 19 + .../oonipipeline/temporal/workflows/ctrl.py | 49 ++ .../temporal/workflows/observations.py | 114 +++++ oonipipeline/tests/test_cli.py | 4 +- oonipipeline/tests/test_ctrl.py | 9 +- oonipipeline/tests/test_temporal_e2e.py | 117 +++++ oonipipeline/tests/test_workflows.py | 127 ++++-- oonipipeline/tests/utils.py | 15 + 25 files changed, 1301 insertions(+), 1352 deletions(-) delete mode 100644 oonidata/src/oonidata/s3client.py create mode 100644 oonipipeline/src/oonipipeline/settings.py create mode 100644 oonipipeline/src/oonipipeline/temporal/schedules.py delete mode 100644 oonipipeline/src/oonipipeline/temporal/workflows.py create mode 100644 oonipipeline/src/oonipipeline/temporal/workflows/__init__.py create mode 100644 oonipipeline/src/oonipipeline/temporal/workflows/analysis.py create mode 100644 oonipipeline/src/oonipipeline/temporal/workflows/common.py create mode 100644 oonipipeline/src/oonipipeline/temporal/workflows/ctrl.py create mode 100644 oonipipeline/src/oonipipeline/temporal/workflows/observations.py create mode 100644 oonipipeline/tests/test_temporal_e2e.py create mode 100644 oonipipeline/tests/utils.py diff --git a/oonidata/src/oonidata/dataclient.py b/oonidata/src/oonidata/dataclient.py index b3fc21fb..b039c9b1 100644 --- a/oonidata/src/oonidata/dataclient.py +++ b/oonidata/src/oonidata/dataclient.py @@ -42,9 +42,6 @@ def create_s3_client(): return boto3.client("s3", config=botoConfig(signature_version=botoSigUNSIGNED)) -s3 = create_s3_client() - - def date_interval(start_day: date, end_day: date): """ A generator for a date_interval. 
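The hunks in `dataclient.py` drop the module-level `s3 = create_s3_client()` and instead create a client inside each function that needs one (`stream_measurements`, `list_all_testnames`, `get_v2_search_prefixes`). A plausible motivation, given that observation generation now runs on multiple processes rather than threads, is that a boto3 client created at import time would be inherited by forked workers, which boto3 does not guarantee to be safe. A minimal sketch of the per-worker pattern follows; `count_objects` and the example prefixes are illustrative, not part of this patch:

```python
from concurrent.futures import ProcessPoolExecutor

import boto3
from botocore import UNSIGNED
from botocore.config import Config


def create_s3_client():
    # Same anonymous client construction as dataclient.create_s3_client()
    return boto3.client("s3", config=Config(signature_version=UNSIGNED))


def count_objects(prefix: str) -> int:
    # Each worker process builds its own client instead of reusing one
    # created in the parent process before the fork.
    s3 = create_s3_client()
    paginator = s3.get_paginator("list_objects_v2")
    return sum(
        len(page.get("Contents", []))
        for page in paginator.paginate(Bucket="ooni-data-eu-fra", Prefix=prefix)
    )


if __name__ == "__main__":
    prefixes = ["raw/20240301", "raw/20240302"]
    with ProcessPoolExecutor() as pool:
        print(list(pool.map(count_objects, prefixes)))
```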
@@ -243,6 +240,7 @@ def stream_oldcan(body: io.BytesIO, s3path: str) -> Generator[dict, None, None]: def stream_measurements(bucket_name, s3path, ext): + s3 = create_s3_client() body = s3.get_object(Bucket=bucket_name, Key=s3path)["Body"] log.debug(f"streaming file s3://{bucket_name}/{s3path}") if ext == "jsonl.gz": @@ -334,6 +332,7 @@ def from_obj_dict(bucket_name: str, obj_dict: dict) -> "FileEntry": def list_all_testnames() -> Set[str]: + s3 = create_s3_client() testnames = set() paginator = s3.get_paginator("list_objects_v2") for r in paginator.paginate(Bucket=MC_BUCKET_NAME, Prefix="jsonl/", Delimiter="/"): @@ -354,6 +353,7 @@ def get_v2_search_prefixes(testnames: Set[str], ccs: Set[str]) -> List[Prefix]: If the ccs list is empty we will return prefixes for all countries for which that particular testname as measurements. """ + s3 = create_s3_client() prefixes = [] paginator = s3.get_paginator("list_objects_v2") for tn in testnames: @@ -577,7 +577,7 @@ def list_file_entries_batches( probe_cc: CSVList = None, test_name: CSVList = None, from_cans: bool = True, -) -> Tuple[List[Tuple], int]: +) -> Tuple[List[List[Tuple]], int]: if isinstance(start_day, str): start_day = datetime.strptime(start_day, "%Y-%m-%d").date() if isinstance(end_day, str): diff --git a/oonidata/src/oonidata/s3client.py b/oonidata/src/oonidata/s3client.py deleted file mode 100644 index 95bebfe7..00000000 --- a/oonidata/src/oonidata/s3client.py +++ /dev/null @@ -1,328 +0,0 @@ -import gzip -import io -import logging -import pathlib -import shutil -import tarfile -from datetime import date, datetime -from typing import Generator, List, Optional, Set -from urllib.parse import urlparse - -import boto3 -from botocore import UNSIGNED as botoSigUNSIGNED -from botocore.config import Config as botoConfig - -import lz4.frame -import orjson - -from .datautils import trivial_id -from .legacy.normalize_yamlooni import iter_yaml_msmt_normalized - -log = logging.getLogger(__name__) - -LEGACY_BUCKET_NAME = "ooni-data" -NEW_BUCKET_NAME = "ooni-data-eu-fra" - - -def read_to_bytesio(body: io.BytesIO) -> io.BytesIO: - read_body = io.BytesIO() - shutil.copyfileobj(body, read_body) - read_body.seek(0) - del body - return read_body - - -def stream_jsonl(body: io.BytesIO) -> Generator[dict, None, None]: - """ - JSONL is the most simple OONI measurement format. They are basically just a - bunch of report files concatenated together for the same test_name, country, - timestamp and compressed with gzip. - """ - with gzip.GzipFile(fileobj=body) as in_file: - for line in in_file: - if line == "": - continue - yield orjson.loads(line) - - -def stream_postcan(body: io.BytesIO) -> Generator[dict, None, None]: - """ - Postcans are the newer format, where each individual measurement entry has - it's own file inside of a tarball. 
- - Here is an example of the tar layout: - - $ tar tf 2024030100_AM_webconnectivity.n1.0.tar.gz | head -n 3 - var/lib/ooniapi/measurements/incoming/2024030100_AM_webconnectivity/20240301003627.966169_AM_webconnectivity_76f66893a38a3de6.post - var/lib/ooniapi/measurements/incoming/2024030100_AM_webconnectivity/20240301003629.092464_AM_webconnectivity_48be39a609d1dcb7.post - var/lib/ooniapi/measurements/incoming/2024030100_AM_webconnectivity/20240301003630.694204_AM_webconnectivity_a95c7da2775bf109.post - - You should not expect the prefix of - `var/lib/ooniapi/measurement/incoming/XXX` to remain constant over time, but - rather you should only use the last of the file name to determine the - measurement ID - `20240301003630.694204_AM_webconnectivity_a95c7da2775bf109.post`. - - Each `.post` file is a JSON document, which contains two top level keys: - - { - "format": "json", - "content": { - } - } - - Format is always set to `json` and content as the data for the actual - measurement. - - FIXME - Some older postcans have the .gz extension, but are actually not - compressed, tar needs to be able to re-seek back to the beginning of the - file in the event of it not finding the gzip magic header when operating in - "transparent compression mode". When we we fix that in the source data, we - might be able to avoid this. - """ - with tarfile.open(fileobj=body, mode="r|*") as tar: - for m in tar: - assert m.name.endswith(".post"), f"{m.name} doesn't end with .post" - in_file = tar.extractfile(m) - - assert in_file is not None, "found empty tarfile in {m.name}" - - j = orjson.loads(in_file.read()) - assert j["format"] == "json", "postcan with non json format" - - msmt = j["content"] - # extract msmt_uid from filename e.g: - # ... /20210614004521.999962_JO_signal_68eb19b439326d60.post - msmt_uid = pathlib.PurePath(m.name).with_suffix("").name - msmt["measurement_uid"] = msmt_uid - yield msmt - - -def stream_jsonlz4(body: io.BytesIO): - """ - lz4.frame requires the input stream to be seekable, so we need to load it - in memory - """ - read_body = read_to_bytesio(body) - - with lz4.frame.open(read_body, mode="rb") as in_file: - for line in in_file: - try: - msmt = orjson.loads(line) - except ValueError: - log.error("stream_jsonlz4: unable to parse json measurement") - continue - - msmt_uid = trivial_id(line, msmt) # type: ignore due to bad types in lz4 - msmt["measurement_uid"] = msmt_uid - yield msmt - - -def stream_yamllz4(body: io.BytesIO): - """ - lz4.frame requires the input stream to be seekable, so we need to load it - in memory - """ - read_body = read_to_bytesio(body) - - with lz4.frame.open(read_body) as in_file: - # The normalize function already add the measurement_uid - yield from iter_yaml_msmt_normalized(in_file) - - -def stream_oldcan(body: io.BytesIO) -> Generator[dict, None, None]: - """ - lz4.frame requires the input stream to be seekable, so we need to load it - in memory. 
- """ - read_body = read_to_bytesio(body) - - with lz4.frame.open(read_body) as lz4_file: - with tarfile.open(fileobj=lz4_file) as tar: # type: ignore due to bad types in lz4 - for m in tar: - in_file = tar.extractfile(m) - assert in_file is not None, "{m.name} is None" - - if m.name.endswith(".json"): - for line in in_file: - msmt = orjson.loads(line) - msmt_uid = trivial_id(line, msmt) - msmt["measurement_uid"] = msmt_uid - yield msmt - - elif m.name.endswith(".yaml"): - # The normalize function already add the measurement_uid - yield from iter_yaml_msmt_normalized(in_file) - - -def create_s3_anonymous_client(): - return boto3.client("s3", config=botoConfig(signature_version=botoSigUNSIGNED)) - - -class OONIMeasurementLister: - def __init__( - self, - *, - probe_cc_filter: Optional[set] = None, - test_name_filter: Optional[set] = None, - ): - self.s3 = create_s3_anonymous_client() - self.probe_cc_filter = probe_cc_filter - self.test_name_filter = test_name_filter - - def get_body(self, bucket_name: str, key: str) -> io.BytesIO: - return self.s3.get_object(Bucket=bucket_name, Key=key)["Body"] - - def apply_filter(self, stream_func, body): - """ - Some of the formats don't support doing filtering by path, so we - do it here as well. - """ - for msmt in stream_func(body): - - if ( - self.probe_cc_filter is not None - and msmt["probe_cc"] not in self.probe_cc_filter - ): - continue - - if ( - self.test_name_filter is not None - and msmt["test_name"] not in self.test_name_filter - ): - continue - - yield msmt - - def measurements(self, s3_url: str): - u = urlparse(s3_url) - bucket_name = u.netloc - assert u.scheme == "s3", "must be s3 URL" - assert bucket_name in ["ooni-data-eu-fra", "ooni-data"] - - body = self.get_body(bucket_name=bucket_name, key=u.path) - s3path = pathlib.PurePath(u.path) - - stream_func = None - if s3path.name.endswith("jsonl.gz"): - stream_func = stream_jsonl - elif s3path.name.endswith("tar.gz"): - stream_func = stream_postcan - elif s3path.name.endswith("tar.lz4"): - stream_func = stream_oldcan - elif s3path.name.endswith("json.lz4"): - stream_func = stream_jsonlz4 - elif s3path.name.endswith("yaml.lz4"): - stream_func = stream_yamllz4 - - assert stream_func is not None, f"invalid format for {s3path.name}" - yield from self.apply_filter(stream_func, body) - - -def get_v2_search_prefixes(s3, testnames: Set[str], ccs: Set[str]) -> List[str]: - """ - get_search_prefixes will return all the prefixes inside of the new jsonl - bucket that match the given testnames and ccs. - If the ccs list is empty we will return prefixes for all countries for - which that particular testname as measurements. 
- """ - prefixes = [] - paginator = s3.get_paginator("list_objects_v2") - for tn in testnames: - for r in paginator.paginate( - Bucket=NEW_BUCKET_NAME, Prefix=f"jsonl/{tn}/", Delimiter="/" - ): - for f in r.get("CommonPrefixes", []): - prefix = f["Prefix"] - cc = prefix.split("/")[-2] - if ccs and cc not in ccs: - continue - prefixes.append(prefix) - return prefixes - - -def get_v2_prefixes( - ccs: Set[str], testnames: Set[str], start_day: date, end_day: date -) -> List[str]: - legacy_prefixes = [ - Prefix(bucket_name=NEW_BUCKET_NAME, prefix=f"raw/{d:%Y%m%d}") - for d in date_interval(max(date(2020, 10, 20), start_day), end_day) - ] - if not testnames: - testnames = list_all_testnames() - prefixes = [] - if start_day < date(2020, 10, 21): - prefixes = get_v2_search_prefixes(testnames, ccs) - combos = list(itertools.product(prefixes, date_interval(start_day, end_day))) - # This results in a faster listing in cases where we need only a small time - # window or few testnames. For larger windows of time, we are better off - # just listing everything. - if len(combos) < 1_000_000: # XXX we might want to tweak this parameter a bit - prefixes = [ - Prefix(bucket_name=NEW_BUCKET_NAME, prefix=f"{p}{d:%Y%m%d}") - for p, d in combos - ] - - return prefixes + legacy_prefixes - - -def get_can_prefixes(start_day: date, end_day: date) -> List[Prefix]: - """ - Returns the list of search prefixes for cans. In most cases, since we don't - have the country code or test name in the path, all we do is return the - range of dates. - """ - new_cans = [ - Prefix(prefix=f"canned/{d:%Y-%m-%d}", bucket_name=NEW_BUCKET_NAME) - for d in date_interval( - # The new cans are between 2020-06-02 and 2020-10-21 inclusive - max(date(2020, 6, 2), start_day), - min(date(2020, 10, 22), end_day), - ) - ] - old_cans = [ - Prefix(prefix=f"canned/{d:%Y-%m-%d}", bucket_name=LEGACY_BUCKET_NAME) - for d in date_interval( - # The new cans are between 2020-06-02 and 2020-10-21 inclusive - # Note: the cans between 2020-06-02 and 2020-10-21 appears to be duplicated between the new and old cans. - # TODO: check if they are actually identical or not. 
- max(date(2012, 12, 5), start_day), - min(date(2020, 6, 2), end_day), - ) - ] - return old_cans + new_cans - - -def iter_file_entries(prefix: Prefix) -> Generator[FileEntry, None, None]: - s3_client = create_s3_client() - paginator = s3_client.get_paginator("list_objects_v2") - for r in paginator.paginate(Bucket=prefix.bucket_name, Prefix=prefix.prefix): - for obj_dict in r.get("Contents", []): - try: - if obj_dict["Key"].endswith(".json.gz"): - # We ignore the legacy can index files - continue - - yield FileEntry.from_obj_dict(prefix.bucket_name, obj_dict) - except ValueError as exc: - log.error(exc) - - -def make_measurement_listers( - bucket_start_day: date, - bucket_end_day: date, - probe_cc_filter: Optional[set] = None, - test_name_filter: Optional[set] = None, -): - start_timestamp = datetime.combine(bucket_start_day, datetime.min.time()) - end_timestamp = datetime.combine(bucket_end_day, datetime.min.time()) - - prefix_list = get_v2_prefixes(ccs, testnames, start_day, end_day) - if from_cans == True: - prefix_list = get_can_prefixes(start_day, end_day) + prefix_list - - log.debug(f"using prefix list {prefix_list}") - file_entries = [] - prefix_idx = 0 - total_prefixes = len(prefix_list) diff --git a/oonipipeline/pyproject.toml b/oonipipeline/pyproject.toml index bbd2abf4..96203f36 100644 --- a/oonipipeline/pyproject.toml +++ b/oonipipeline/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "opentelemetry-exporter-otlp-proto-grpc ~= 1.18.0", "bokeh ~= 3.5.2", "uvicorn ~= 0.25.0", - "pydantic-settings ~= 2.1.0", + "pydantic-settings ~= 2.4.0", ] [tool.hatch.build.targets.sdist] diff --git a/oonipipeline/src/oonipipeline/analysis/control.py b/oonipipeline/src/oonipipeline/analysis/control.py index 6bed39d6..63f39a9b 100644 --- a/oonipipeline/src/oonipipeline/analysis/control.py +++ b/oonipipeline/src/oonipipeline/analysis/control.py @@ -208,6 +208,7 @@ def build_from_existing(self, db_str: str): with sqlite3.connect(db_str) as src_db: self.db = sqlite3.connect(":memory:") src_db.backup(self.db) + self.db.commit() def close(self): self.db.close() diff --git a/oonipipeline/src/oonipipeline/cli/commands.py b/oonipipeline/src/oonipipeline/cli/commands.py index 5f92b5bf..4fe02a60 100644 --- a/oonipipeline/src/oonipipeline/cli/commands.py +++ b/oonipipeline/src/oonipipeline/cli/commands.py @@ -1,10 +1,7 @@ -from configparser import ConfigParser import logging -import multiprocessing -import os from pathlib import Path from typing import List, Optional -from datetime import date, timedelta, datetime, timezone, time +from datetime import date, timedelta, datetime, timezone from typing import List, Optional from oonipipeline.temporal.client_operations import ( @@ -12,24 +9,19 @@ run_backfill, run_create_schedules, run_status, + run_reschedule, ) from oonipipeline.temporal.workers import start_workers import click from click_loglevel import LogLevel -from ..temporal.workflows import ( - GroundTruthsWorkflow, - GroundTruthsWorkflowParams, - ObservationsWorkflowParams, - AnalysisWorkflowParams, -) from ..__about__ import VERSION from ..db.connections import ClickhouseConnection from ..db.create_tables import make_create_queries, list_all_table_diffs from ..netinfo import NetinfoDB - +from ..settings import config def _parse_csv(ctx, param, s: Optional[str]) -> List[str]: if s: @@ -64,7 +56,6 @@ def _parse_csv(ctx, param, s: Optional[str]) -> List[str]: Note: this is the upload date, which doesn't necessarily match the measurement date. 
""", ) - start_at_option = click.option( "--start-at", type=click.DateTime(), @@ -83,59 +74,8 @@ def _parse_csv(ctx, param, s: Optional[str]) -> List[str]: Note: this is the upload date, which doesn't necessarily match the measurement date. """, ) - -clickhouse_option = click.option( - "--clickhouse", type=str, required=True, default="clickhouse://localhost" -) -clickhouse_buffer_min_time_option = click.option( - "--clickhouse-buffer-min-time", - type=int, - required=True, - default=10, - help="min_time for the Buffer tables in clickhouse. only applied during create. see: https://clickhouse.com/docs/en/engines/table-engines/special/buffer", -) -clickhouse_buffer_max_time_option = click.option( - "--clickhouse-buffer-max-time", - type=int, - required=True, - default=60, - help="max_time for the Buffer tables in clickhouse. only applied during create. see: https://clickhouse.com/docs/en/engines/table-engines/special/buffer", -) -telemetry_endpoint_option = click.option( - "--telemetry-endpoint", type=str, required=False, default=None -) -prometheus_bind_address_option = click.option( - "--prometheus-bind-address", type=str, required=False, default=None -) -temporal_address_option = click.option( - "--temporal-address", type=str, required=True, default="localhost:7233" -) -temporal_namespace_option = click.option( - "--temporal-namespace", type=str, required=False, default=None -) -temporal_tls_client_cert_path_option = click.option( - "--temporal-tls-client-cert-path", type=str, required=False, default=None -) -temporal_tls_client_key_path_option = click.option( - "--temporal-tls-client-key-path", type=str, required=False, default=None -) start_workers_option = click.option("--start-workers/--no-start-workers", default=True) -datadir_option = click.option( - "--data-dir", - type=str, - required=True, - default="tests/data/datadir", - help="data directory to store fingerprint and geoip databases", -) -parallelism_option = click.option( - "--parallelism", - type=int, - default=multiprocessing.cpu_count() + 2, - help="number of processes to use. 
Only works when writing to a database", -) - - def maybe_create_delete_tables( clickhouse_url: str, create_tables: bool, @@ -158,29 +98,6 @@ def maybe_create_delete_tables( db.execute(query) -def parse_config_file(ctx, path): - cfg = ConfigParser() - cfg.read(path) - ctx.default_map = {} - - try: - default_options = cfg["options"] - for name, _ in cli.commands.items(): - ctx.default_map.setdefault(name, {}) - ctx.default_map[name].update(default_options) - except KeyError: - # No default section - pass - - for sect in cfg.sections(): - command_path = sect.split(".") - defaults = ctx.default_map - for cmdname in command_path[1:]: - defaults = defaults.setdefault(cmdname, {}) - defaults.update(cfg[sect]) - return ctx.default_map - - @click.group() @click.option( "-l", @@ -190,35 +107,17 @@ def parse_config_file(ctx, path): help="Set logging level", show_default=True, ) -@click.option( - "-c", - "--config", - type=click.Path(dir_okay=False), - default="config.ini", - help="Read option defaults from the specified INI file", - show_default=True, -) @click.version_option(VERSION) -@click.pass_context -def cli(ctx, log_level: int, config: str): +def cli(log_level: int): logging.basicConfig(level=log_level) - if os.path.exists(config): - ctx.default_map = parse_config_file(ctx, config) @cli.command() @start_at_option @end_at_option -@clickhouse_option -@clickhouse_buffer_min_time_option -@clickhouse_buffer_max_time_option -@telemetry_endpoint_option -@prometheus_bind_address_option -@temporal_address_option -@temporal_namespace_option -@temporal_tls_client_cert_path_option -@temporal_tls_client_key_path_option -@click.option("--schedule-id", type=str, required=True) +@probe_cc_option +@test_name_option +@click.option("--workflow-name", type=str, required=True) @click.option( "--create-tables", is_flag=True, @@ -230,46 +129,41 @@ def cli(ctx, log_level: int, config: str): help="should we drop tables before creating them", ) def backfill( + probe_cc: List[str], + test_name: List[str], + workflow_name: str, start_at: datetime, end_at: datetime, - clickhouse: str, - clickhouse_buffer_min_time: int, - clickhouse_buffer_max_time: int, create_tables: bool, drop_tables: bool, - telemetry_endpoint: Optional[str], - prometheus_bind_address: Optional[str], - temporal_address: str, - temporal_namespace: Optional[str], - temporal_tls_client_cert_path: Optional[str], - temporal_tls_client_key_path: Optional[str], - schedule_id: str, ): """ Backfill for OONI measurements and write them into clickhouse """ - click.echo(f"Runnning backfill of schedule {schedule_id}") + click.echo(f"Runnning backfill of worfklow {workflow_name}") maybe_create_delete_tables( - clickhouse_url=clickhouse, + clickhouse_url=config.clickhouse_url, create_tables=create_tables, drop_tables=drop_tables, - clickhouse_buffer_min_time=clickhouse_buffer_min_time, - clickhouse_buffer_max_time=clickhouse_buffer_max_time, + clickhouse_buffer_min_time=config.clickhouse_buffer_min_time, + clickhouse_buffer_max_time=config.clickhouse_buffer_max_time, ) temporal_config = TemporalConfig( - prometheus_bind_address=prometheus_bind_address, - telemetry_endpoint=telemetry_endpoint, - temporal_address=temporal_address, - temporal_namespace=temporal_namespace, - temporal_tls_client_cert_path=temporal_tls_client_cert_path, - temporal_tls_client_key_path=temporal_tls_client_key_path, + prometheus_bind_address=config.prometheus_bind_address, + telemetry_endpoint=config.telemetry_endpoint, + temporal_address=config.temporal_address, + 
temporal_namespace=config.temporal_namespace, + temporal_tls_client_cert_path=config.temporal_tls_client_cert_path, + temporal_tls_client_key_path=config.temporal_tls_client_key_path, ) run_backfill( - schedule_id=schedule_id, + workflow_name=workflow_name, temporal_config=temporal_config, + probe_cc=probe_cc, + test_name=test_name, start_at=start_at, end_at=end_at, ) @@ -278,251 +172,99 @@ def backfill( @cli.command() @probe_cc_option @test_name_option -@clickhouse_option -@clickhouse_buffer_min_time_option -@clickhouse_buffer_max_time_option -@datadir_option -@telemetry_endpoint_option -@prometheus_bind_address_option -@temporal_address_option -@temporal_namespace_option -@temporal_tls_client_cert_path_option -@temporal_tls_client_key_path_option @click.option( "--fast-fail", is_flag=True, help="should we fail immediately when we encounter an error?", ) -@click.option( - "--analysis/--no-analysis", - is_flag=True, - help="should we schedule an analysis", - default=False, -) -@click.option( - "--observations/--no-observations", - is_flag=True, - help="should we schedule observations", - default=True, -) -@click.option( - "--delete", - is_flag=True, - default=False, - help="if we should delete the schedule instead of creating it", -) -@click.option( - "--create-tables", - is_flag=True, - help="should we attempt to create the required clickhouse tables", -) -@click.option( - "--drop-tables", - is_flag=True, - help="should we drop tables before creating them", -) def schedule( probe_cc: List[str], test_name: List[str], - clickhouse: str, - clickhouse_buffer_min_time: int, - clickhouse_buffer_max_time: int, - data_dir: str, fast_fail: bool, - create_tables: bool, - drop_tables: bool, - telemetry_endpoint: Optional[str], - prometheus_bind_address: Optional[str], - temporal_address: str, - temporal_namespace: Optional[str], - temporal_tls_client_cert_path: Optional[str], - temporal_tls_client_key_path: Optional[str], - analysis: bool, - observations: bool, - delete: bool, ): """ Create schedules for the specified parameters """ - if not observations and not analysis: - click.echo("either observations or analysis should be set") - return 1 + temporal_config = TemporalConfig( + telemetry_endpoint=config.telemetry_endpoint, + prometheus_bind_address=config.prometheus_bind_address, + temporal_address=config.temporal_address, + temporal_namespace=config.temporal_namespace, + temporal_tls_client_cert_path=config.temporal_tls_client_cert_path, + temporal_tls_client_key_path=config.temporal_tls_client_key_path, + ) - maybe_create_delete_tables( - clickhouse_url=clickhouse, - create_tables=create_tables, - drop_tables=drop_tables, - clickhouse_buffer_min_time=clickhouse_buffer_min_time, - clickhouse_buffer_max_time=clickhouse_buffer_max_time, + run_create_schedules( + probe_cc=probe_cc, + test_name=test_name, + clickhouse_url=config.clickhouse_url, + data_dir=config.data_dir, + temporal_config=temporal_config, ) - what_we_schedule = [] - if analysis: - what_we_schedule.append("analysis") - if observations: - what_we_schedule.append("observations") - click.echo(f"Scheduling {' and'.join(what_we_schedule)}") +@cli.command() +@probe_cc_option +@test_name_option +def reschedule( + probe_cc: List[str], + test_name: List[str], +): + """ + Create schedules for the specified parameters + """ temporal_config = TemporalConfig( - telemetry_endpoint=telemetry_endpoint, - prometheus_bind_address=prometheus_bind_address, - temporal_address=temporal_address, - temporal_namespace=temporal_namespace, - 
temporal_tls_client_cert_path=temporal_tls_client_cert_path, - temporal_tls_client_key_path=temporal_tls_client_key_path, + telemetry_endpoint=config.telemetry_endpoint, + prometheus_bind_address=config.prometheus_bind_address, + temporal_address=config.temporal_address, + temporal_namespace=config.temporal_namespace, + temporal_tls_client_cert_path=config.temporal_tls_client_cert_path, + temporal_tls_client_key_path=config.temporal_tls_client_key_path, ) - obs_params = None - if observations: - obs_params = ObservationsWorkflowParams( - probe_cc=probe_cc, - test_name=test_name, - clickhouse=clickhouse, - data_dir=str(data_dir), - fast_fail=fast_fail, - ) - analysis_params = None - if analysis: - analysis_params = AnalysisWorkflowParams( - probe_cc=probe_cc, - test_name=test_name, - clickhouse=clickhouse, - data_dir=str(data_dir), - ) - run_create_schedules( - obs_params=obs_params, - analysis_params=analysis_params, + run_reschedule( + probe_cc=probe_cc, + test_name=test_name, + clickhouse_url=config.clickhouse_url, + data_dir=config.data_dir, temporal_config=temporal_config, - delete=delete, ) @cli.command() -@prometheus_bind_address_option -@telemetry_endpoint_option -@temporal_address_option -@temporal_namespace_option -@temporal_tls_client_cert_path_option -@temporal_tls_client_key_path_option -def status( - telemetry_endpoint: Optional[str], - prometheus_bind_address: Optional[str], - temporal_address: str, - temporal_namespace: Optional[str], - temporal_tls_client_cert_path: Optional[str], - temporal_tls_client_key_path: Optional[str], -): - click.echo(f"getting status from {temporal_address}") +def status(): + click.echo(f"getting status from {config.temporal_address}") temporal_config = TemporalConfig( - prometheus_bind_address=prometheus_bind_address, - telemetry_endpoint=telemetry_endpoint, - temporal_address=temporal_address, - temporal_namespace=temporal_namespace, - temporal_tls_client_cert_path=temporal_tls_client_cert_path, - temporal_tls_client_key_path=temporal_tls_client_key_path, + prometheus_bind_address=config.prometheus_bind_address, + telemetry_endpoint=config.telemetry_endpoint, + temporal_address=config.temporal_address, + temporal_namespace=config.temporal_namespace, + temporal_tls_client_cert_path=config.temporal_tls_client_cert_path, + temporal_tls_client_key_path=config.temporal_tls_client_key_path, ) run_status(temporal_config=temporal_config) @cli.command() -@datadir_option -@parallelism_option -@prometheus_bind_address_option -@telemetry_endpoint_option -@temporal_address_option -@temporal_namespace_option -@temporal_tls_client_cert_path_option -@temporal_tls_client_key_path_option -def startworkers( - data_dir: Path, - parallelism: int, - prometheus_bind_address: Optional[str], - telemetry_endpoint: Optional[str], - temporal_address: str, - temporal_namespace: Optional[str], - temporal_tls_client_cert_path: Optional[str], - temporal_tls_client_key_path: Optional[str], -): - click.echo(f"starting {parallelism} workers") - click.echo(f"downloading NetinfoDB to {data_dir}") - NetinfoDB(datadir=Path(data_dir), download=True) +def startworkers(): + click.echo(f"starting workers") + click.echo(f"downloading NetinfoDB to {config.data_dir}") + NetinfoDB(datadir=Path(config.data_dir), download=True) click.echo("done downloading netinfodb") temporal_config = TemporalConfig( - prometheus_bind_address=prometheus_bind_address, - telemetry_endpoint=telemetry_endpoint, - temporal_address=temporal_address, - temporal_namespace=temporal_namespace, - 
temporal_tls_client_cert_path=temporal_tls_client_cert_path, - temporal_tls_client_key_path=temporal_tls_client_key_path, + prometheus_bind_address=config.prometheus_bind_address, + telemetry_endpoint=config.telemetry_endpoint, + temporal_address=config.temporal_address, + temporal_namespace=config.temporal_namespace, + temporal_tls_client_cert_path=config.temporal_tls_client_cert_path, + temporal_tls_client_key_path=config.temporal_tls_client_key_path, ) start_workers(temporal_config=temporal_config) @cli.command() -@probe_cc_option -@test_name_option -@start_day_option -@end_day_option -@clickhouse_option -@datadir_option -@click.option("--archives-dir", type=Path, required=True) -@click.option( - "--parallelism", - type=int, - default=multiprocessing.cpu_count() + 2, - help="number of processes to use. Only works when writing to a database", -) -def mkbodies( - probe_cc: List[str], - test_name: List[str], - start_day: date, - end_day: date, - clickhouse: str, - data_dir: Path, - archives_dir: Path, - parallelism: int, -): - """ - Make response body archives - """ - # start_response_archiver( - # probe_cc=probe_cc, - # test_name=test_name, - # start_day=start_day, - # end_day=end_day, - # data_dir=data_dir, - # archives_dir=archives_dir, - # clickhouse=clickhouse, - # parallelism=parallelism, - # ) - raise NotImplemented("TODO(art)") - - -@cli.command() -@datadir_option -@click.option("--archives-dir", type=Path, required=True) -@click.option( - "--parallelism", - type=int, - default=multiprocessing.cpu_count() + 2, - help="number of processes to use", -) -def fphunt(data_dir: Path, archives_dir: Path, parallelism: int): - click.echo("🏹 starting the hunt for blockpage fingerprints!") - # start_fingerprint_hunter( - # archives_dir=archives_dir, - # data_dir=data_dir, - # parallelism=parallelism, - # ) - raise NotImplemented("TODO(art)") - - -@cli.command() -@clickhouse_buffer_min_time_option -@clickhouse_buffer_max_time_option -@clickhouse_option @click.option( "--create-tables", is_flag=True, @@ -534,9 +276,6 @@ def fphunt(data_dir: Path, archives_dir: Path, parallelism: int): help="should we drop tables before creating them", ) def checkdb( - clickhouse: str, - clickhouse_buffer_min_time: int, - clickhouse_buffer_max_time: int, create_tables: bool, drop_tables: bool, ): @@ -545,12 +284,12 @@ def checkdb( is not specified, it will not perform any operations. 
""" maybe_create_delete_tables( - clickhouse_url=clickhouse, + clickhouse_url=config.clickhouse_url, create_tables=create_tables, drop_tables=drop_tables, - clickhouse_buffer_min_time=clickhouse_buffer_min_time, - clickhouse_buffer_max_time=clickhouse_buffer_max_time, + clickhouse_buffer_min_time=config.clickhouse_buffer_min_time, + clickhouse_buffer_max_time=config.clickhouse_buffer_max_time, ) - with ClickhouseConnection(clickhouse) as db: + with ClickhouseConnection(config.clickhouse_url) as db: list_all_table_diffs(db) diff --git a/oonipipeline/src/oonipipeline/settings.py b/oonipipeline/src/oonipipeline/settings.py new file mode 100644 index 00000000..4fe5f073 --- /dev/null +++ b/oonipipeline/src/oonipipeline/settings.py @@ -0,0 +1,48 @@ +import os +from typing import Optional, Tuple, Type +from pydantic import Field + +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, + TomlConfigSettingsSource, +) + + +class Settings(BaseSettings): + model_config = SettingsConfigDict() + + data_dir: str = "tests/data/datadir" + + clickhouse_url: str = "clickhouse://localhost" + clickhouse_buffer_min_time: int = 10 + clickhouse_buffer_max_time: int = 60 + clickhouse_write_batch_size: int = 200_000 + + telemetry_endpoint: Optional[str] = None + prometheus_bind_address: Optional[str] = None + temporal_address: str = "localhost:7233" + temporal_namespace: Optional[str] = None + temporal_tls_client_cert_path: Optional[str] = None + temporal_tls_client_key_path: Optional[str] = None + + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + return ( + init_settings, + env_settings, + TomlConfigSettingsSource( + settings_cls, toml_file=os.environ.get("CONFIG_FILE", "") + ), + ) + + +config = Settings() diff --git a/oonipipeline/src/oonipipeline/temporal/activities/analysis.py b/oonipipeline/src/oonipipeline/temporal/activities/analysis.py index 7382c183..36756d7d 100644 --- a/oonipipeline/src/oonipipeline/temporal/activities/analysis.py +++ b/oonipipeline/src/oonipipeline/temporal/activities/analysis.py @@ -5,6 +5,7 @@ from datetime import datetime from typing import Dict, List +from oonipipeline.temporal.common import TS_FORMAT import opentelemetry.trace from temporalio import workflow, activity @@ -121,7 +122,7 @@ def make_analysis_in_a_day(params: MakeAnalysisParams) -> dict: get_prev_range( db=db_lookup, table_name=WebAnalysis.__table_name__, - timestamp=datetime.combine(day, datetime.min.time()), + timestamp=datetime.combine(day, datetime.min.time()).strftime(TS_FORMAT), test_name=[], probe_cc=probe_cc, timestamp_column="measurement_start_time", @@ -129,7 +130,7 @@ def make_analysis_in_a_day(params: MakeAnalysisParams) -> dict: get_prev_range( db=db_lookup, table_name=MeasurementExperimentResult.__table_name__, - timestamp=datetime.combine(day, datetime.min.time()), + timestamp=datetime.combine(day, datetime.min.time()).strftime(TS_FORMAT), test_name=[], probe_cc=probe_cc, timestamp_column="timeofday", diff --git a/oonipipeline/src/oonipipeline/temporal/activities/common.py b/oonipipeline/src/oonipipeline/temporal/activities/common.py index dfef2082..f52085cc 100644 --- a/oonipipeline/src/oonipipeline/temporal/activities/common.py +++ 
b/oonipipeline/src/oonipipeline/temporal/activities/common.py @@ -2,6 +2,7 @@ from datetime import datetime, timezone, timedelta from dataclasses import dataclass from typing import Dict, List, Tuple +from concurrent.futures import ProcessPoolExecutor from threading import Lock @@ -16,6 +17,8 @@ log = activity.logger +process_pool_executor = ProcessPoolExecutor() + @dataclass class ClickhouseParams: clickhouse_url: str @@ -24,23 +27,34 @@ class ClickhouseParams: @activity.defn def optimize_all_tables(params: ClickhouseParams): with ClickhouseConnection(params.clickhouse_url) as db: - for _, table_name in make_create_queries(): - if table_name.startswith("buffer_"): - continue + table_names = [table_name for _, table_name in make_create_queries()] + # We first flush the buffer_ tables and then the non-buffer tables + for table_name in filter(lambda x: x.startswith("buffer_"), table_names): + db.execute(f"OPTIMIZE TABLE {table_name}") + for table_name in filter(lambda x: not x.startswith("buffer_"), table_names): db.execute(f"OPTIMIZE TABLE {table_name}") @dataclass -class UpdateAssetsParams: - data_dir: str - refresh_hours: int = 10 - force_update: bool = False +class OptimizeTablesParams: + clickhouse: str + table_names: List[str] @activity.defn -def update_assets(params: UpdateAssetsParams): +def optimize_tables(params: OptimizeTablesParams): + with ClickhouseConnection(params.clickhouse) as db: + for table_name in params.table_names: + db.execute(f"OPTIMIZE TABLE {table_name}") + + +def update_assets( + data_dir: str, + refresh_hours: int = 10, + force_update: bool = False, +): last_updated_at = datetime(1984, 1, 1).replace(tzinfo=timezone.utc) - datadir = pathlib.Path(params.data_dir) + datadir = pathlib.Path(data_dir) last_updated_path = datadir / "last_updated.txt" @@ -53,10 +67,7 @@ def update_assets(params: UpdateAssetsParams): now = datetime.now(timezone.utc) last_updated_delta = now - last_updated_at - if ( - last_updated_delta > timedelta(hours=params.refresh_hours) - or params.force_update - ): + if last_updated_delta > timedelta(hours=refresh_hours) or force_update: lock = Lock() with lock: log.info("triggering update of netinfodb") @@ -64,7 +75,7 @@ def update_assets(params: UpdateAssetsParams): last_updated_path.write_text(now.strftime(DATETIME_UTC_FORMAT)) else: log.info( - f"skipping updating netinfodb because {last_updated_delta} < {params.refresh_hours}h" + f"skipping updating netinfodb because {last_updated_delta} < {refresh_hours}h" ) diff --git a/oonipipeline/src/oonipipeline/temporal/activities/ground_truths.py b/oonipipeline/src/oonipipeline/temporal/activities/ground_truths.py index e864ae1a..739b0573 100644 --- a/oonipipeline/src/oonipipeline/temporal/activities/ground_truths.py +++ b/oonipipeline/src/oonipipeline/temporal/activities/ground_truths.py @@ -42,9 +42,9 @@ def make_ground_truths_in_day(params: MakeGroundTruthsParams): dst_path = get_ground_truth_db_path(data_dir=params.data_dir, day=params.day) - if dst_path.exists() or params.force_rebuild: + if dst_path.exists() and params.force_rebuild: dst_path.unlink() - else: + elif dst_path.exists(): return t = PerfTimer() @@ -54,4 +54,5 @@ def make_ground_truths_in_day(params: MakeGroundTruthsParams): web_ground_truth_db.build_from_rows( rows=iter_web_ground_truths(db=db, measurement_day=day, netinfodb=netinfodb) ) + web_ground_truth_db.close() log.info(f"built ground truth DB {day} in {t.pretty}") diff --git a/oonipipeline/src/oonipipeline/temporal/activities/observations.py 
b/oonipipeline/src/oonipipeline/temporal/activities/observations.py index 840872e4..871dc658 100644 --- a/oonipipeline/src/oonipipeline/temporal/activities/observations.py +++ b/oonipipeline/src/oonipipeline/temporal/activities/observations.py @@ -1,6 +1,8 @@ +import asyncio +import concurrent.futures from dataclasses import dataclass -import dataclasses -from typing import List, Sequence, Tuple +import functools +from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict from oonidata.dataclient import ( ccs_set, list_file_entries_batches, @@ -12,11 +14,12 @@ from oonipipeline.db.connections import ClickhouseConnection from oonipipeline.netinfo import NetinfoDB from oonipipeline.temporal.common import ( + PrevRange, get_prev_range, - make_db_rows, maybe_delete_prev_range, ) - +from oonipipeline.temporal.activities.common import process_pool_executor, update_assets +from oonipipeline.settings import config from opentelemetry import trace from temporalio import activity @@ -40,166 +43,239 @@ class MakeObservationsParams: bucket_date: str -def write_observations_to_db( - msmt: SupportedDataformats, - netinfodb: NetinfoDB, +FileEntryBatchType = Tuple[str, str, str, int] + + +@dataclass +class MakeObservationsFileEntryBatch: + batch_idx: int + clickhouse: str + write_batch_size: int + data_dir: str + bucket_date: str + probe_cc: List[str] + test_name: List[str] + bucket_date: str + fast_fail: bool + + +def make_observations_for_file_entry( db: ClickhouseConnection, + netinfodb: NetinfoDB, bucket_date: str, + bucket_name: str, + s3path: str, + ext: str, + ccs: set, + fast_fail: bool, ): - for observations in measurement_to_observations( - msmt=msmt, netinfodb=netinfodb, bucket_date=bucket_date + failure_count = 0 + measurement_count = 0 + for msmt_dict in stream_measurements( + bucket_name=bucket_name, s3path=s3path, ext=ext ): - if len(observations) == 0: + # Legacy cans don't allow us to pre-filter on the probe_cc, so + # we need to check for probe_cc consistency in here. 
+ if ccs and msmt_dict["probe_cc"] not in ccs: continue - column_names = [f.name for f in dataclasses.fields(observations[0])] - table_name, rows = make_db_rows( - bucket_date=bucket_date, - dc_list=observations, - column_names=column_names, - ) - db.write_rows(table_name=table_name, rows=rows, column_names=column_names) + measurement_uid = msmt_dict.get("measurement_uid", None) + report_id = msmt_dict.get("report_id", None) + msmt_str = f"muid={measurement_uid} (rid={report_id})" + + if not msmt_dict.get("test_keys", None): + log.error( + f"measurement with empty test_keys: ({msmt_str})", + exc_info=True, + ) + continue + try: + msmt = load_measurement(msmt_dict) + obs_tuple = measurement_to_observations( + msmt=msmt, + netinfodb=netinfodb, + bucket_date=bucket_date, + ) + for obs_list in obs_tuple: + db.write_table_model_rows(obs_list, use_buffer_table=False) + measurement_count += 1 + except Exception as exc: + log.error(f"failed at idx: {measurement_count} ({msmt_str})", exc_info=True) + failure_count += 1 + if fast_fail: + db.close() + raise exc + log.debug(f"done processing file s3://{bucket_name}/{s3path}") + return measurement_count, failure_count def make_observations_for_file_entry_batch( - file_entry_batch: Sequence[Tuple[str, str, str, int]], - clickhouse: str, - write_batch_size: int, - data_dir: pathlib.Path, + file_entry_batch: List[FileEntryBatchType], bucket_date: str, probe_cc: List[str], - fast_fail: bool, -): - netinfodb = NetinfoDB(datadir=data_dir, download=False) + data_dir: pathlib.Path, + clickhouse: str, + write_batch_size: int, + fast_fail: bool = False, +) -> int: tbatch = PerfTimer() - - tracer = trace.get_tracer(__name__) - total_failure_count = 0 - current_span = trace.get_current_span() + ccs = ccs_set(probe_cc) + total_measurement_count = 0 + netinfodb = NetinfoDB(datadir=data_dir, download=False) with ClickhouseConnection(clickhouse, write_batch_size=write_batch_size) as db: - ccs = ccs_set(probe_cc) - idx = 0 for bucket_name, s3path, ext, fe_size in file_entry_batch: failure_count = 0 - # Nest the traced span within the current span - with tracer.start_span("MakeObservations:stream_file_entry") as span: - log.debug(f"processing file s3://{bucket_name}/{s3path}") - t = PerfTimer() - try: - for msmt_dict in stream_measurements( - bucket_name=bucket_name, s3path=s3path, ext=ext - ): - # Legacy cans don't allow us to pre-filter on the probe_cc, so - # we need to check for probe_cc consistency in here. - if ccs and msmt_dict["probe_cc"] not in ccs: - continue - msmt = None - try: - t = PerfTimer() - msmt = load_measurement(msmt_dict) - if not msmt.test_keys: - log.error( - f"measurement with empty test_keys: ({msmt.measurement_uid})", - exc_info=True, - ) - continue - obs_tuple = measurement_to_observations( - msmt=msmt, - netinfodb=netinfodb, - bucket_date=bucket_date, - ) - for obs_list in obs_tuple: - db.write_table_model_rows(obs_list) - idx += 1 - except Exception as exc: - msmt_str = msmt_dict.get("report_id", None) - if msmt: - msmt_str = msmt.measurement_uid - log.error( - f"failed at idx: {idx} ({msmt_str})", exc_info=True - ) - failure_count += 1 - - if fast_fail: - db.close() - raise exc - log.debug(f"done processing file s3://{bucket_name}/{s3path}") - except Exception as exc: - log.error( - f"failed to stream measurements from s3://{bucket_name}/{s3path}" - ) - log.error(exc) - # TODO(art): figure out if the rate of these metrics is too - # much. For each processed file a telemetry event is generated. 
- span.set_attribute("kb_per_sec", fe_size / 1024 / t.s) - span.set_attribute("fe_size", fe_size) - span.set_attribute("failure_count", failure_count) - span.add_event(f"s3_path: s3://{bucket_name}/{s3path}") - total_failure_count += failure_count - - current_span.set_attribute("total_runtime_ms", tbatch.ms) - current_span.set_attribute("total_failure_count", total_failure_count) - return idx + log.debug(f"processing file s3://{bucket_name}/{s3path}") + measurement_count, failure_count = make_observations_for_file_entry( + db=db, + netinfodb=netinfodb, + bucket_date=bucket_date, + bucket_name=bucket_name, + s3path=s3path, + ext=ext, + fast_fail=fast_fail, + ccs=ccs, + ) + total_measurement_count += measurement_count + total_failure_count += failure_count + log.info( + f"finished batch for bucket_date={bucket_date}\n" + f" {len(file_entry_batch)} entries \n" + f" in {tbatch.s:.3f} seconds \n" + f" msmt/s: {total_measurement_count / tbatch.s}" + ) + return total_measurement_count -@activity.defn -def make_observation_in_day(params: MakeObservationsParams) -> dict: - day = datetime.strptime(params.bucket_date, "%Y-%m-%d").date() - # TODO(art): this previous range search and deletion makes the idempotence - # of the activity not 100% accurate. - # We should look into fixing it. - with ClickhouseConnection(params.clickhouse) as db: - prev_ranges = [] - for table_name in ["obs_web"]: - prev_ranges.append( - ( - table_name, - get_prev_range( - db=db, - table_name=table_name, - bucket_date=params.bucket_date, - test_name=params.test_name, - probe_cc=params.probe_cc, - ), - ) - ) - log.info(f"prev_ranges: {prev_ranges}") +ObservationBatches = TypedDict( + "ObservationBatches", + {"batches": List[List[FileEntryBatchType]], "total_size": int}, +) + + +def make_observation_batches( + bucket_date: str, probe_cc: List[str], test_name: List[str] +) -> ObservationBatches: + day = datetime.strptime(bucket_date, "%Y-%m-%d").date() t = PerfTimer() - total_t = PerfTimer() file_entry_batches, total_size = list_file_entries_batches( - probe_cc=params.probe_cc, - test_name=params.test_name, + probe_cc=probe_cc, + test_name=test_name, start_day=day, end_day=day + timedelta(days=1), ) - log.info(f"running {len(file_entry_batches)} batches took {t.pretty}") - - total_msmt_count = 0 - for batch in file_entry_batches: - msmt_cnt = make_observations_for_file_entry_batch( - batch, - params.clickhouse, - 500_000, - pathlib.Path(params.data_dir), - params.bucket_date, - params.probe_cc, - params.fast_fail, - ) - total_msmt_count += msmt_cnt - - mb_per_sec = round(total_size / total_t.s / 10**6, 1) - msmt_per_sec = round(total_msmt_count / total_t.s) log.info( - f"finished processing all batches in {total_t.pretty} speed: {mb_per_sec}MB/s ({msmt_per_sec}msmt/s)" + f"listing bucket_date={bucket_date} {len(file_entry_batches)} batches took {t.pretty}" + ) + return {"batches": file_entry_batches, "total_size": total_size} + + +MakeObservationsResult = TypedDict( + "MakeObservationsResult", + { + "measurement_count": int, + "measurement_per_sec": float, + "mb_per_sec": float, + "total_size": int, + }, +) + + +@activity.defn +async def make_observations(params: MakeObservationsParams) -> MakeObservationsResult: + loop = asyncio.get_running_loop() + + tbatch = PerfTimer() + current_span = trace.get_current_span() + activity.logger.info(f"starting update_assets for {params.bucket_date}") + await loop.run_in_executor( + None, + functools.partial( + update_assets, + data_dir=params.data_dir, + refresh_hours=10, + 
force_update=False, + ), + ) + batches = await loop.run_in_executor( + None, + functools.partial( + make_observation_batches, + probe_cc=params.probe_cc, + test_name=params.test_name, + bucket_date=params.bucket_date, + ), ) + awaitables = [] + for file_entry_batch in batches["batches"]: + awaitables.append( + loop.run_in_executor( + process_pool_executor, + functools.partial( + make_observations_for_file_entry_batch, + file_entry_batch=file_entry_batch, + bucket_date=params.bucket_date, + probe_cc=params.probe_cc, + data_dir=pathlib.Path(params.data_dir), + clickhouse=params.clickhouse, + write_batch_size=config.clickhouse_write_batch_size, + fast_fail=False, + ), + ), + ) + measurement_count = sum(await asyncio.gather(*awaitables)) + + current_span.set_attribute("total_runtime_ms", tbatch.ms) + # current_span.set_attribute("total_failure_count", total_failure_count) + + return { + "measurement_count": measurement_count, + "mb_per_sec": float(batches["total_size"]) / 1024 / 1024 / tbatch.s, + "measurement_per_sec": measurement_count / tbatch.s, + "total_size": batches["total_size"], + } - if len(prev_ranges) > 0: - with ClickhouseConnection(params.clickhouse) as db: - for table_name, pr in prev_ranges: - log.info("deleting previous range of {pr}") - maybe_delete_prev_range(db=db, prev_range=pr) - return {"size": total_size, "measurement_count": total_msmt_count} +@dataclass +class GetPreviousRangeParams: + clickhouse: str + bucket_date: str + test_name: List[str] + probe_cc: List[str] + tables: List[str] + + +@activity.defn +def get_previous_range(params: GetPreviousRangeParams) -> List[PrevRange]: + with ClickhouseConnection(params.clickhouse) as db: + prev_ranges = [] + for table_name in params.tables: + prev_ranges.append( + get_prev_range( + db=db, + table_name=table_name, + bucket_date=params.bucket_date, + test_name=params.test_name, + probe_cc=params.probe_cc, + ), + ) + return prev_ranges + + +@dataclass +class DeletePreviousRangeParams: + clickhouse: str + previous_ranges: List[PrevRange] + + +@activity.defn +def delete_previous_range(params: DeletePreviousRangeParams) -> List[str]: + delete_queries = [] + with ClickhouseConnection(params.clickhouse) as db: + for pr in params.previous_ranges: + log.info("deleting previous range of {pr}") + delete_queries.append(maybe_delete_prev_range(db=db, prev_range=pr)) + return delete_queries diff --git a/oonipipeline/src/oonipipeline/temporal/client_operations.py b/oonipipeline/src/oonipipeline/temporal/client_operations.py index e2a595fb..fd5235e7 100644 --- a/oonipipeline/src/oonipipeline/temporal/client_operations.py +++ b/oonipipeline/src/oonipipeline/temporal/client_operations.py @@ -1,14 +1,14 @@ import asyncio import logging from dataclasses import dataclass -from datetime import datetime, timezone, timedelta +from datetime import datetime from typing import List, Optional, Tuple -from oonipipeline.temporal.workflows import ( - AnalysisWorkflowParams, - ObservationsWorkflowParams, - schedule_analysis, - schedule_observations, +from oonipipeline.temporal.schedules import ( + ScheduleIdMap, + schedule_all, + schedule_backfill, + reschedule_all, ) from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter @@ -19,8 +19,6 @@ from temporalio.client import ( Client as TemporalClient, - ScheduleBackfill, - ScheduleOverlapPolicy, WorkflowExecution, ) from temporalio.service import TLSConfig @@ -115,63 +113,65 @@ async def temporal_connect( async def execute_backfill( - schedule_id: str, - temporal_config: 
TemporalConfig, + probe_cc: List[str], + test_name: List[str], start_at: datetime, end_at: datetime, + workflow_name: str, + temporal_config: TemporalConfig, ): - log.info(f"running backfill for schedule_id={schedule_id}") + log.info(f"creating all schedules") client = await temporal_connect(temporal_config=temporal_config) - found_schedule_id = None - schedule_list = await client.list_schedules() - async for sched in schedule_list: - if sched.id.startswith(schedule_id): - found_schedule_id = sched.id - break - if not found_schedule_id: - log.error(f"schedule ID not found for prefix {schedule_id}") - return - - handle = client.get_schedule_handle(found_schedule_id) - await handle.backfill( - ScheduleBackfill( - start_at=start_at + timedelta(hours=1), - end_at=end_at + timedelta(hours=1), - overlap=ScheduleOverlapPolicy.ALLOW_ALL, - ), + return await schedule_backfill( + client=client, + probe_cc=probe_cc, + test_name=test_name, + start_at=start_at, + end_at=end_at, + workflow_name=workflow_name, ) async def create_schedules( - obs_params: Optional[ObservationsWorkflowParams], - analysis_params: Optional[AnalysisWorkflowParams], + probe_cc: List[str], + test_name: List[str], + clickhouse_url: str, + data_dir: str, temporal_config: TemporalConfig, - delete: bool = False, -) -> dict: +) -> ScheduleIdMap: log.info(f"creating all schedules") client = await temporal_connect(temporal_config=temporal_config) - obs_schedule_id = None - if obs_params is not None: - obs_schedule_id = await schedule_observations( - client=client, params=obs_params, delete=delete - ) - log.info(f"created schedule observations schedule with ID={obs_schedule_id}") + return await schedule_all( + client=client, + probe_cc=probe_cc, + test_name=test_name, + clickhouse_url=clickhouse_url, + data_dir=data_dir, + ) - analysis_schedule_id = None - if analysis_params is not None: - analysis_schedule_id = await schedule_analysis( - client=client, params=analysis_params, delete=delete - ) - log.info(f"created schedule analysis schedule with ID={analysis_schedule_id}") - return { - "analysis_schedule_id": analysis_schedule_id, - "observations_schedule_id": obs_schedule_id, - } +async def reschedule( + probe_cc: List[str], + test_name: List[str], + clickhouse_url: str, + data_dir: str, + temporal_config: TemporalConfig, +) -> ScheduleIdMap: + log.info(f"rescheduling everything") + + client = await temporal_connect(temporal_config=temporal_config) + + return await reschedule_all( + client=client, + probe_cc=probe_cc, + test_name=test_name, + clickhouse_url=clickhouse_url, + data_dir=data_dir, + ) async def get_status( @@ -213,7 +213,9 @@ async def get_status( def run_backfill( temporal_config: TemporalConfig, - schedule_id: str, + probe_cc: List[str], + test_name: List[str], + workflow_name: str, start_at: datetime, end_at: datetime, ): @@ -221,7 +223,9 @@ def run_backfill( asyncio.run( execute_backfill( temporal_config=temporal_config, - schedule_id=schedule_id, + workflow_name=workflow_name, + probe_cc=probe_cc, + test_name=test_name, start_at=start_at, end_at=end_at, ) @@ -231,18 +235,41 @@ def run_backfill( def run_create_schedules( - obs_params: Optional[ObservationsWorkflowParams], - analysis_params: Optional[AnalysisWorkflowParams], + probe_cc: List[str], + test_name: List[str], + clickhouse_url: str, + data_dir: str, temporal_config: TemporalConfig, - delete: bool, ): try: asyncio.run( create_schedules( - obs_params=obs_params, - analysis_params=analysis_params, + probe_cc=probe_cc, + test_name=test_name, + 
clickhouse_url=clickhouse_url, + data_dir=data_dir, + temporal_config=temporal_config, + ) + ) + except KeyboardInterrupt: + print("shutting down") + + +def run_reschedule( + probe_cc: List[str], + test_name: List[str], + clickhouse_url: str, + data_dir: str, + temporal_config: TemporalConfig, +): + try: + asyncio.run( + reschedule( + probe_cc=probe_cc, + test_name=test_name, + clickhouse_url=clickhouse_url, + data_dir=data_dir, temporal_config=temporal_config, - delete=delete, ) ) except KeyboardInterrupt: diff --git a/oonipipeline/src/oonipipeline/temporal/common.py b/oonipipeline/src/oonipipeline/temporal/common.py index 28b6e44a..6e497b43 100644 --- a/oonipipeline/src/oonipipeline/temporal/common.py +++ b/oonipipeline/src/oonipipeline/temporal/common.py @@ -3,6 +3,7 @@ from datetime import datetime, timedelta +import time from typing import ( Any, Callable, @@ -16,13 +17,14 @@ log = logging.getLogger("oonidata.processing") +TS_FORMAT = "%Y-%m-%d %H:%M:%S" @dataclass class BatchParameters: test_name: List[str] probe_cc: List[str] bucket_date: Optional[str] - timestamp: Optional[datetime] + timestamp: Optional[str] @dataclass @@ -31,8 +33,8 @@ class PrevRange: batch_parameters: BatchParameters timestamp_column: Optional[str] probe_cc_column: Optional[str] - max_created_at: Optional[datetime] = None - min_created_at: Optional[datetime] = None + max_created_at: Optional[str] = None + min_created_at: Optional[str] = None def format_query(self): start_timestamp = None @@ -46,7 +48,9 @@ def format_query(self): q_args["bucket_date"] = self.batch_parameters.bucket_date elif self.batch_parameters.timestamp: - start_timestamp = self.batch_parameters.timestamp + start_timestamp = datetime.strptime( + self.batch_parameters.timestamp, TS_FORMAT + ) end_timestamp = start_timestamp + timedelta(days=1) q_args["start_timestamp"] = start_timestamp q_args["end_timestamp"] = end_timestamp @@ -64,14 +68,25 @@ def format_query(self): return where, q_args -def maybe_delete_prev_range(db: ClickhouseConnection, prev_range: PrevRange): +def wait_for_mutations(db, table_name): + while True: + res = db.execute( + f"SELECT * FROM system.mutations WHERE is_done=0 AND table='{table_name}';" + ) + if len(res) == 0: # type: ignore + break + time.sleep(1) + + +def maybe_delete_prev_range(db: ClickhouseConnection, prev_range: PrevRange) -> str: """ We perform a lightweight delete of all the rows which have been regenerated, so we don't have any duplicates in the table """ if not prev_range.max_created_at or not prev_range.min_created_at: - return + return "" + wait_for_mutations(db, prev_range.table_name) # Disabled due to: https://github.com/ClickHouse/ClickHouse/issues/40651 # db.execute("SET allow_experimental_lightweight_delete = true;") @@ -84,7 +99,8 @@ def maybe_delete_prev_range(db: ClickhouseConnection, prev_range: PrevRange): q = f"ALTER TABLE {prev_range.table_name} DELETE " final_query = q + where - return db.execute(final_query, q_args) + db.execute(final_query, q_args) + return final_query def get_prev_range( @@ -93,7 +109,7 @@ def get_prev_range( test_name: List[str], probe_cc: List[str], bucket_date: Optional[str] = None, - timestamp: Optional[datetime] = None, + timestamp: Optional[str] = None, timestamp_column: str = "timestamp", probe_cc_column: str = "probe_cc", ) -> PrevRange: @@ -146,11 +162,15 @@ def get_prev_range( # We pad it by 1 second to take into account the time resolution downgrade # happening when going from clickhouse to python data types if max_created_at and min_created_at: - 
prev_range.max_created_at = (max_created_at + timedelta(seconds=1)).replace( - tzinfo=None + prev_range.max_created_at = ( + (max_created_at + timedelta(seconds=1)) + .replace(tzinfo=None) + .strftime(TS_FORMAT) ) - prev_range.min_created_at = (min_created_at - timedelta(seconds=1)).replace( - tzinfo=None + prev_range.min_created_at = ( + (min_created_at - timedelta(seconds=1)) + .replace(tzinfo=None) + .strftime(TS_FORMAT) ) return prev_range diff --git a/oonipipeline/src/oonipipeline/temporal/schedules.py b/oonipipeline/src/oonipipeline/temporal/schedules.py new file mode 100644 index 00000000..732d657f --- /dev/null +++ b/oonipipeline/src/oonipipeline/temporal/schedules.py @@ -0,0 +1,245 @@ +from dataclasses import dataclass +from typing import List, Optional, TypedDict + +import logging +from datetime import datetime, timedelta, timezone + + +from oonipipeline.temporal.workflows.analysis import AnalysisWorkflowParams +from oonipipeline.temporal.workflows.analysis import AnalysisWorkflow +from oonipipeline.temporal.workflows.common import ( + MAKE_OBSERVATIONS_START_TO_CLOSE_TIMEOUT, +) +from oonipipeline.temporal.workflows.common import TASK_QUEUE_NAME +from oonipipeline.temporal.workflows.common import MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT +from oonipipeline.temporal.workflows.observations import ObservationsWorkflow +from oonipipeline.temporal.workflows.observations import ObservationsWorkflowParams +from temporalio import workflow +from temporalio.client import ( + Client as TemporalClient, + Schedule, + ScheduleBackfill, + ScheduleActionStartWorkflow, + ScheduleIntervalSpec, + ScheduleSpec, + ScheduleState, + SchedulePolicy, + ScheduleOverlapPolicy, +) + +log = logging.getLogger("oonipipeline.workflows") + +OBSERVATIONS_SCHED_PREFIX = "oopln-sched-observations" +OBSERVATIONS_WF_PREFIX = "oopln-wf-observations" +ANALYSIS_WF_PREFIX = "oopln-wf-analysis" +ANALYSIS_SCHED_PREFIX = "oopln-sched-analysis" + + +def gen_schedule_filter_id(probe_cc: List[str], test_name: List[str]): + probe_cc_key = "ALLCCS" + if len(probe_cc) > 0: + probe_cc_key = ".".join(map(lambda x: x.lower(), sorted(probe_cc))) + test_name_key = "ALLTNS" + if len(test_name) > 0: + test_name_key = ".".join(map(lambda x: x.lower(), sorted(test_name))) + + return f"{probe_cc_key}-{test_name_key}" + + +@dataclass +class ScheduleIdMap: + observations: Optional[str] = None + analysis: Optional[str] = None + + +@dataclass +class ScheduleIdMapList: + observations: List[str] + analysis: List[str] + + +async def list_existing_schedules( + client: TemporalClient, + probe_cc: List[str], + test_name: List[str], +): + schedule_id_map_list = ScheduleIdMapList( + observations=[], + analysis=[], + ) + filter_id = gen_schedule_filter_id(probe_cc, test_name) + + schedule_list = await client.list_schedules() + async for sched in schedule_list: + if sched.id.startswith(f"{OBSERVATIONS_SCHED_PREFIX}-{filter_id}"): + schedule_id_map_list.observations.append(sched.id) + elif sched.id.startswith(f"{ANALYSIS_SCHED_PREFIX}-{filter_id}"): + schedule_id_map_list.analysis.append(sched.id) + + return schedule_id_map_list + + +async def schedule_all( + client: TemporalClient, + probe_cc: List[str], + test_name: List[str], + clickhouse_url: str, + data_dir: str, +) -> ScheduleIdMap: + schedule_id_map = ScheduleIdMap() + filter_id = gen_schedule_filter_id(probe_cc, test_name) + # We need to append a timestamp to the schedule so that we are able to rerun + # the backfill operations by deleting the existing schedule and + # re-scheduling it. 
Not doing so will mean that temporal will believe the + # workflow has already been executed and will refuse to re-run it. + # TODO(art): check if there is a more idiomatic way of implementing this + ts = datetime.now(timezone.utc).strftime("%y.%m.%d_%H%M%S") + + existing_schedules = await list_existing_schedules( + client=client, probe_cc=probe_cc, test_name=test_name + ) + assert ( + len(existing_schedules.observations) < 2 + ), f"duplicate schedule for observations: {existing_schedules.observations}" + assert ( + len(existing_schedules.analysis) < 2 + ), f"duplicate schedule for analysis: {existing_schedules.analysis}" + + if len(existing_schedules.observations) == 1: + schedule_id_map.observations = existing_schedules.observations[0] + if len(existing_schedules.analysis) == 1: + schedule_id_map.analysis = existing_schedules.analysis[0] + + if schedule_id_map.observations is None: + obs_params = ObservationsWorkflowParams( + probe_cc=probe_cc, + test_name=test_name, + clickhouse=clickhouse_url, + data_dir=data_dir, + fast_fail=False, + ) + sched_handle = await client.create_schedule( + id=f"{OBSERVATIONS_SCHED_PREFIX}-{filter_id}-{ts}", + schedule=Schedule( + action=ScheduleActionStartWorkflow( + ObservationsWorkflow.run, + obs_params, + id=f"{OBSERVATIONS_WF_PREFIX}-{filter_id}-{ts}", + task_queue=TASK_QUEUE_NAME, + execution_timeout=MAKE_OBSERVATIONS_START_TO_CLOSE_TIMEOUT, + task_timeout=MAKE_OBSERVATIONS_START_TO_CLOSE_TIMEOUT, + run_timeout=MAKE_OBSERVATIONS_START_TO_CLOSE_TIMEOUT, + ), + spec=ScheduleSpec( + intervals=[ + ScheduleIntervalSpec( + every=timedelta(days=1), offset=timedelta(hours=2) + ) + ], + ), + policy=SchedulePolicy(overlap=ScheduleOverlapPolicy.ALLOW_ALL), + state=ScheduleState( + note="Run the observations workflow every day with an offset of 2 hours to ensure the files have been written to s3" + ), + ), + ) + schedule_id_map.observations = sched_handle.id + + if schedule_id_map.analysis is None: + analysis_params = AnalysisWorkflowParams( + probe_cc=probe_cc, + test_name=test_name, + clickhouse=clickhouse_url, + data_dir=data_dir, + fast_fail=False, + ) + sched_handle = await client.create_schedule( + id=f"{ANALYSIS_SCHED_PREFIX}-{filter_id}-{ts}", + schedule=Schedule( + action=ScheduleActionStartWorkflow( + AnalysisWorkflow.run, + analysis_params, + id=f"{ANALYSIS_WF_PREFIX}-{filter_id}-{ts}", + task_queue=TASK_QUEUE_NAME, + execution_timeout=MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT, + task_timeout=MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT, + run_timeout=MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT, + ), + spec=ScheduleSpec( + intervals=[ + ScheduleIntervalSpec( + # We start the Analysis workflow 4 hours after the Observations + # workflow (a 6 hour offset overall), assuming that the observation + # generation will take less than 4 hours to complete.
+ # TODO(art): it's probably better to refactor this into some + # kind of DAG + every=timedelta(days=1), + offset=timedelta(hours=6), + ) + ], + ), + policy=SchedulePolicy(overlap=ScheduleOverlapPolicy.ALLOW_ALL), + state=ScheduleState( + note="Run the analysis workflow every day with an offset of 6 hours to ensure the observation workflow has completed" + ), + ), + ) + schedule_id_map.analysis = sched_handle.id + + return schedule_id_map + + +async def reschedule_all( + client: TemporalClient, + probe_cc: List[str], + test_name: List[str], + clickhouse_url: str, + data_dir: str, +) -> ScheduleIdMap: + existing_schedules = await list_existing_schedules( + client=client, probe_cc=probe_cc, test_name=test_name + ) + for schedule_id in existing_schedules.observations + existing_schedules.analysis: + await client.get_schedule_handle(schedule_id).delete() + + return await schedule_all( + client=client, + probe_cc=probe_cc, + test_name=test_name, + clickhouse_url=clickhouse_url, + data_dir=data_dir, + ) + + +async def schedule_backfill( + client: TemporalClient, + workflow_name: str, + start_at: datetime, + end_at: datetime, + probe_cc: List[str], + test_name: List[str], +): + existing_schedules = await list_existing_schedules( + client=client, probe_cc=probe_cc, test_name=test_name + ) + if workflow_name == "observations": + assert ( + len(existing_schedules.observations) == 1 + ), "Expected one schedule for observations" + schedule_id = existing_schedules.observations[0] + elif workflow_name == "analysis": + assert ( + len(existing_schedules.analysis) == 1 + ), "Expected one schedule for analysis" + schedule_id = existing_schedules.analysis[0] + else: + raise ValueError(f"Unknown workflow name: {workflow_name}") + + handle = client.get_schedule_handle(schedule_id) + await handle.backfill( + ScheduleBackfill( + start_at=start_at + timedelta(hours=1), + end_at=end_at + timedelta(hours=1), + overlap=ScheduleOverlapPolicy.BUFFER_ALL, + ), + ) diff --git a/oonipipeline/src/oonipipeline/temporal/workers.py b/oonipipeline/src/oonipipeline/temporal/workers.py index 39c8afc1..2622c96b 100644 --- a/oonipipeline/src/oonipipeline/temporal/workers.py +++ b/oonipipeline/src/oonipipeline/temporal/workers.py @@ -8,25 +8,27 @@ from oonipipeline.temporal.activities.common import ( get_obs_count_by_cc, optimize_all_tables, - update_assets, + optimize_tables, ) from oonipipeline.temporal.activities.ground_truths import make_ground_truths_in_day -from oonipipeline.temporal.activities.observations import make_observation_in_day +from oonipipeline.temporal.activities.observations import ( + delete_previous_range, + get_previous_range, + make_observations, +) from oonipipeline.temporal.client_operations import ( TemporalConfig, log, temporal_connect, ) -from oonipipeline.temporal.workflows import ( - TASK_QUEUE_NAME, - AnalysisWorkflow, - GroundTruthsWorkflow, - ObservationsWorkflow, -) +from oonipipeline.temporal.workflows.common import TASK_QUEUE_NAME +from oonipipeline.temporal.workflows.analysis import AnalysisWorkflow +from oonipipeline.temporal.workflows.ctrl import GroundTruthsWorkflow +from oonipipeline.temporal.workflows.observations import ObservationsWorkflow log = logging.getLogger("oonipipeline.workers") -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, Executor interrupt_event = asyncio.Event() @@ -37,25 +39,27 @@ ] ACTIVTIES = [ - make_observation_in_day, + delete_previous_range, + get_previous_range, + make_observations, 
make_ground_truths_in_day, make_analysis_in_a_day, optimize_all_tables, get_obs_count_by_cc, - update_assets, + optimize_tables, ] -async def worker_main(temporal_config: TemporalConfig): +async def worker_main( + temporal_config: TemporalConfig, max_workers: int, executor: Executor +): client = await temporal_connect(temporal_config=temporal_config) - max_workers = max(os.cpu_count() or 4, 4) - log.info(f"starting workers with max_workers={max_workers}") async with Worker( client, task_queue=TASK_QUEUE_NAME, workflows=WORKFLOWS, activities=ACTIVTIES, - activity_executor=ThreadPoolExecutor(max_workers=max_workers + 2), + activity_executor=executor, max_concurrent_activities=max_workers, max_concurrent_workflow_tasks=max_workers, ): @@ -65,12 +69,25 @@ async def worker_main(temporal_config: TemporalConfig): def start_workers(temporal_config: TemporalConfig): + max_workers = max(os.cpu_count() or 4, 4) + log.info(f"starting workers with max_workers={max_workers}") + executor = ThreadPoolExecutor(max_workers=max_workers + 2) + loop = asyncio.new_event_loop() + loop.set_default_executor(executor) # TODO(art): Investigate if we want to upgrade to python 3.12 and use this # instead # loop.set_task_factory(asyncio.eager_task_factory) try: - loop.run_until_complete(worker_main(temporal_config=temporal_config)) + loop.run_until_complete( + worker_main( + temporal_config=temporal_config, + max_workers=max_workers, + executor=executor, + ) + ) except KeyboardInterrupt: interrupt_event.set() loop.run_until_complete(loop.shutdown_asyncgens()) + executor.shutdown(wait=True, cancel_futures=True) + log.info("shut down thread pool") diff --git a/oonipipeline/src/oonipipeline/temporal/workflows.py b/oonipipeline/src/oonipipeline/temporal/workflows.py deleted file mode 100644 index 7f844460..00000000 --- a/oonipipeline/src/oonipipeline/temporal/workflows.py +++ /dev/null @@ -1,395 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional - -import logging -import asyncio -from datetime import datetime, timedelta, timezone - - -from temporalio import workflow -from temporalio.common import ( - SearchAttributeKey, -) -from temporalio.client import ( - Client as TemporalClient, - Schedule, - ScheduleActionStartWorkflow, - ScheduleIntervalSpec, - ScheduleSpec, - ScheduleState, - SchedulePolicy, - ScheduleOverlapPolicy, -) - - -with workflow.unsafe.imports_passed_through(): - import clickhouse_driver - - from oonidata.dataclient import date_interval - from oonidata.datautils import PerfTimer - from oonipipeline.db.connections import ClickhouseConnection - from oonipipeline.temporal.activities.analysis import ( - MakeAnalysisParams, - log, - make_analysis_in_a_day, - make_cc_batches, - ) - from oonipipeline.temporal.activities.common import ( - get_obs_count_by_cc, - ObsCountParams, - update_assets, - UpdateAssetsParams, - ) - from oonipipeline.temporal.activities.observations import ( - MakeObservationsParams, - make_observation_in_day, - ) - - from oonipipeline.temporal.activities.ground_truths import ( - MakeGroundTruthsParams, - make_ground_truths_in_day, - ) - from oonipipeline.temporal.activities.common import ( - optimize_all_tables, - ClickhouseParams, - ) - from oonipipeline.temporal.activities.ground_truths import get_ground_truth_db_path - -# Handle temporal sandbox violations related to calls to self.processName = -# mp.current_process().name in logger, see: -# https://github.com/python/cpython/blob/1316692e8c7c1e1f3b6639e51804f9db5ed892ea/Lib/logging/__init__.py#L362 
-logging.logMultiprocessing = False - -log = logging.getLogger("oonipipeline.workflows") - -TASK_QUEUE_NAME = "oonipipeline-task-queue" - -# TODO(art): come up with a nicer way to nest workflows so we don't need such a high global timeout -MAKE_OBSERVATIONS_START_TO_CLOSE_TIMEOUT = timedelta(hours=48) -MAKE_GROUND_TRUTHS_START_TO_CLOSE_TIMEOUT = timedelta(hours=1) -MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT = timedelta(hours=10) - - -def get_workflow_start_time() -> datetime: - workflow_start_time = workflow.info().typed_search_attributes.get( - SearchAttributeKey.for_datetime("TemporalScheduledStartTime") - ) - assert workflow_start_time is not None, "TemporalScheduledStartTime not set" - return workflow_start_time - - -@dataclass -class ObservationsWorkflowParams: - probe_cc: List[str] - test_name: List[str] - clickhouse: str - data_dir: str - fast_fail: bool - log_level: int = logging.INFO - bucket_date: Optional[str] = None - - -@workflow.defn -class ObservationsWorkflow: - @workflow.run - async def run(self, params: ObservationsWorkflowParams) -> dict: - await workflow.execute_activity( - update_assets, - UpdateAssetsParams(data_dir=params.data_dir), - start_to_close_timeout=timedelta(hours=1), - ) - - if params.bucket_date is None: - params.bucket_date = ( - get_workflow_start_time() - timedelta(days=1) - ).strftime("%Y-%m-%d") - - await workflow.execute_activity( - optimize_all_tables, - ClickhouseParams(clickhouse_url=params.clickhouse), - start_to_close_timeout=timedelta(minutes=5), - ) - - workflow.logger.info( - f"Starting observation making with probe_cc={params.probe_cc},test_name={params.test_name} bucket_date={params.bucket_date}" - ) - res = await workflow.execute_activity( - activity=make_observation_in_day, - arg=MakeObservationsParams( - probe_cc=params.probe_cc, - test_name=params.test_name, - clickhouse=params.clickhouse, - data_dir=params.data_dir, - fast_fail=params.fast_fail, - bucket_date=params.bucket_date, - ), - task_queue=TASK_QUEUE_NAME, - start_to_close_timeout=MAKE_OBSERVATIONS_START_TO_CLOSE_TIMEOUT, - ) - res["bucket_date"] = params.bucket_date - return res - - -@dataclass -class GroundTruthsWorkflowParams: - start_day: str - end_day: str - clickhouse: str - data_dir: str - - -@workflow.defn -class GroundTruthsWorkflow: - @workflow.run - async def run( - self, - params: GroundTruthsWorkflowParams, - ): - await workflow.execute_activity( - update_assets, - UpdateAssetsParams(data_dir=params.data_dir), - start_to_close_timeout=timedelta(hours=1), - ) - - start_day = datetime.strptime(params.start_day, "%Y-%m-%d").date() - end_day = datetime.strptime(params.end_day, "%Y-%m-%d").date() - - async with asyncio.TaskGroup() as tg: - for day in date_interval(start_day, end_day): - tg.create_task( - workflow.execute_activity( - make_ground_truths_in_day, - MakeGroundTruthsParams( - clickhouse=params.clickhouse, - data_dir=params.data_dir, - day=day.strftime("%Y-%m-%d"), - ), - start_to_close_timeout=MAKE_GROUND_TRUTHS_START_TO_CLOSE_TIMEOUT, - ) - ) - - -@dataclass -class AnalysisWorkflowParams: - probe_cc: List[str] - test_name: List[str] - clickhouse: str - data_dir: str - parallelism: int = 10 - fast_fail: bool = False - day: Optional[str] = None - force_rebuild_ground_truths: bool = False - log_level: int = logging.INFO - - -@workflow.defn -class AnalysisWorkflow: - @workflow.run - async def run(self, params: AnalysisWorkflowParams) -> dict: - if params.day is None: - params.day = (get_workflow_start_time() - timedelta(days=1)).strftime( - "%Y-%m-%d" - ) - - await 
workflow.execute_activity( - update_assets, - UpdateAssetsParams(data_dir=params.data_dir), - start_to_close_timeout=timedelta(hours=1), - ) - - await workflow.execute_activity( - optimize_all_tables, - ClickhouseParams(clickhouse_url=params.clickhouse), - start_to_close_timeout=timedelta(minutes=5), - ) - - workflow.logger.info("building ground truth databases") - t = PerfTimer() - - await workflow.execute_activity( - make_ground_truths_in_day, - MakeGroundTruthsParams( - clickhouse=params.clickhouse, - data_dir=params.data_dir, - day=params.day, - force_rebuild=params.force_rebuild_ground_truths, - ), - start_to_close_timeout=timedelta(minutes=30), - ) - workflow.logger.info(f"built ground truth db in {t.pretty}") - - start_day = datetime.strptime(params.day, "%Y-%m-%d").date() - cnt_by_cc = await workflow.execute_activity( - get_obs_count_by_cc, - ObsCountParams( - clickhouse_url=params.clickhouse, - start_day=start_day.strftime("%Y-%m-%d"), - end_day=(start_day + timedelta(days=1)).strftime("%Y-%m-%d"), - ), - start_to_close_timeout=timedelta(minutes=30), - ) - - cc_batches = make_cc_batches( - cnt_by_cc=cnt_by_cc, - probe_cc=params.probe_cc, - parallelism=params.parallelism, - ) - - workflow.logger.info( - f"starting processing of {len(cc_batches)} batches for {params.day} days (parallelism = {params.parallelism})" - ) - workflow.logger.info(f"({cc_batches})") - - task_list = [] - async with asyncio.TaskGroup() as tg: - for probe_cc in cc_batches: - task = tg.create_task( - workflow.execute_activity( - make_analysis_in_a_day, - MakeAnalysisParams( - probe_cc=probe_cc, - test_name=params.test_name, - clickhouse=params.clickhouse, - data_dir=params.data_dir, - fast_fail=params.fast_fail, - day=params.day, - ), - start_to_close_timeout=MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT, - ) - ) - task_list.append(task) - - total_obs_count = sum(map(lambda x: x.result()["count"], task_list)) - return {"obs_count": total_obs_count, "day": params.day} - - -def gen_schedule_id(probe_cc: List[str], test_name: List[str], name: str): - probe_cc_key = "ALLCCS" - if len(probe_cc) > 0: - probe_cc_key = ".".join(map(lambda x: x.lower(), sorted(probe_cc))) - test_name_key = "ALLTNS" - if len(test_name) > 0: - test_name_key = ".".join(map(lambda x: x.lower(), sorted(test_name))) - - return f"oonipipeline-{name}-schedule-{probe_cc_key}-{test_name_key}" - - -async def schedule_observations( - client: TemporalClient, params: ObservationsWorkflowParams, delete: bool -) -> List[str]: - base_schedule_id = gen_schedule_id( - params.probe_cc, params.test_name, "observations" - ) - - existing_schedules = [] - schedule_list = await client.list_schedules() - async for sched in schedule_list: - if sched.id.startswith(base_schedule_id): - existing_schedules.append(sched.id) - - if delete is True: - for sched_id in existing_schedules: - schedule_handle = client.get_schedule_handle(sched_id) - await schedule_handle.delete() - return existing_schedules - - if len(existing_schedules) == 1: - return existing_schedules - elif len(existing_schedules) > 0: - print("WARNING: multiple schedules detected") - return existing_schedules - - ts = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S") - schedule_id = f"{base_schedule_id}-{ts}" - - await client.create_schedule( - id=schedule_id, - schedule=Schedule( - action=ScheduleActionStartWorkflow( - ObservationsWorkflow.run, - params, - id=schedule_id.replace("-schedule-", "-workflow-"), - task_queue=TASK_QUEUE_NAME, - execution_timeout=MAKE_OBSERVATIONS_START_TO_CLOSE_TIMEOUT, - 
task_timeout=MAKE_OBSERVATIONS_START_TO_CLOSE_TIMEOUT, - run_timeout=MAKE_OBSERVATIONS_START_TO_CLOSE_TIMEOUT, - ), - spec=ScheduleSpec( - intervals=[ - ScheduleIntervalSpec( - every=timedelta(days=1), offset=timedelta(hours=2) - ) - ], - ), - policy=SchedulePolicy(overlap=ScheduleOverlapPolicy.TERMINATE_OTHER), - state=ScheduleState( - note="Run the observations workflow every day with an offset of 2 hours to ensure the files have been written to s3" - ), - ), - ) - return [schedule_id] - - -async def schedule_analysis( - client: TemporalClient, params: AnalysisWorkflowParams, delete: bool -) -> List[str]: - base_schedule_id = gen_schedule_id(params.probe_cc, params.test_name, "analysis") - - existing_schedules = [] - schedule_list = await client.list_schedules() - async for sched in schedule_list: - if sched.id.startswith(base_schedule_id): - existing_schedules.append(sched.id) - - if delete is True: - for sched_id in existing_schedules: - schedule_handle = client.get_schedule_handle(sched_id) - await schedule_handle.delete() - return existing_schedules - - if len(existing_schedules) == 1: - return existing_schedules - elif len(existing_schedules) > 0: - print("WARNING: multiple schedules detected") - return existing_schedules - - # We need to append a timestamp to the schedule so that we are able to rerun - # the backfill operations by deleting the existing schedule and - # re-scheduling it. Not doing so will mean that temporal will believe the - # workflow has already been execututed and will refuse to re-run it. - # TODO(art): check if there is a more idiomatic way of implementing this - ts = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S") - schedule_id = f"{base_schedule_id}-{ts}" - - await client.create_schedule( - id=schedule_id, - schedule=Schedule( - action=ScheduleActionStartWorkflow( - AnalysisWorkflow.run, - params, - id=schedule_id.replace("-schedule-", "-workflow-"), - task_queue=TASK_QUEUE_NAME, - execution_timeout=MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT, - task_timeout=MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT, - run_timeout=MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT, - ), - spec=ScheduleSpec( - intervals=[ - ScheduleIntervalSpec( - # We offset the Analysis workflow by 4 hours assuming - # that the observation generation will take less than 4 - # hours to complete. 
- # TODO(art): it's probably better to refactor this into some - # kind of DAG - every=timedelta(days=1), - offset=timedelta(hours=6), - ) - ], - ), - policy=SchedulePolicy(overlap=ScheduleOverlapPolicy.TERMINATE_OTHER), - state=ScheduleState( - note="Run the analysis workflow every day with an offset of 6 hours to ensure the observation workflow has completed" - ), - ), - ) - return [schedule_id] diff --git a/oonipipeline/src/oonipipeline/temporal/workflows/__init__.py b/oonipipeline/src/oonipipeline/temporal/workflows/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/oonipipeline/src/oonipipeline/temporal/workflows/analysis.py b/oonipipeline/src/oonipipeline/temporal/workflows/analysis.py new file mode 100644 index 00000000..97c653b9 --- /dev/null +++ b/oonipipeline/src/oonipipeline/temporal/workflows/analysis.py @@ -0,0 +1,118 @@ +import asyncio +from datetime import datetime, timedelta +import logging +from dataclasses import dataclass +from typing import List, Optional + + +from temporalio import workflow + +with workflow.unsafe.imports_passed_through(): + from oonidata.datautils import PerfTimer + from oonipipeline.temporal.activities.analysis import ( + MakeAnalysisParams, + make_analysis_in_a_day, + make_cc_batches, + ) + from oonipipeline.temporal.activities.common import ( + ClickhouseParams, + ObsCountParams, + get_obs_count_by_cc, + optimize_all_tables, + ) + from oonipipeline.temporal.activities.ground_truths import ( + MakeGroundTruthsParams, + make_ground_truths_in_day, + ) + from oonipipeline.temporal.workflows.common import ( + MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT, + get_workflow_start_time, + ) + + +@dataclass +class AnalysisWorkflowParams: + probe_cc: List[str] + test_name: List[str] + clickhouse: str + data_dir: str + parallelism: int = 10 + fast_fail: bool = False + day: Optional[str] = None + force_rebuild_ground_truths: bool = False + log_level: int = logging.INFO + + +@workflow.defn +class AnalysisWorkflow: + @workflow.run + async def run(self, params: AnalysisWorkflowParams) -> dict: + if params.day is None: + params.day = (get_workflow_start_time() - timedelta(days=1)).strftime( + "%Y-%m-%d" + ) + + await workflow.execute_activity( + optimize_all_tables, + ClickhouseParams(clickhouse_url=params.clickhouse), + start_to_close_timeout=timedelta(minutes=5), + ) + + workflow.logger.info("building ground truth databases") + t = PerfTimer() + + await workflow.execute_activity( + make_ground_truths_in_day, + MakeGroundTruthsParams( + clickhouse=params.clickhouse, + data_dir=params.data_dir, + day=params.day, + force_rebuild=params.force_rebuild_ground_truths, + ), + start_to_close_timeout=timedelta(minutes=30), + ) + workflow.logger.info(f"built ground truth db in {t.pretty}") + + start_day = datetime.strptime(params.day, "%Y-%m-%d").date() + cnt_by_cc = await workflow.execute_activity( + get_obs_count_by_cc, + ObsCountParams( + clickhouse_url=params.clickhouse, + start_day=start_day.strftime("%Y-%m-%d"), + end_day=(start_day + timedelta(days=1)).strftime("%Y-%m-%d"), + ), + start_to_close_timeout=timedelta(minutes=30), + ) + + cc_batches = make_cc_batches( + cnt_by_cc=cnt_by_cc, + probe_cc=params.probe_cc, + parallelism=params.parallelism, + ) + + workflow.logger.info( + f"starting processing of {len(cc_batches)} batches for {params.day} days (parallelism = {params.parallelism})" + ) + workflow.logger.info(f"({cc_batches})") + + task_list = [] + async with asyncio.TaskGroup() as tg: + for probe_cc in cc_batches: + task = tg.create_task( + 
workflow.execute_activity( + make_analysis_in_a_day, + MakeAnalysisParams( + probe_cc=probe_cc, + test_name=params.test_name, + clickhouse=params.clickhouse, + data_dir=params.data_dir, + fast_fail=params.fast_fail, + day=params.day, + ), + start_to_close_timeout=MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT, + ) + ) + task_list.append(task) + + total_obs_count = sum(map(lambda x: x.result()["count"], task_list)) + return {"obs_count": total_obs_count, "day": params.day} diff --git a/oonipipeline/src/oonipipeline/temporal/workflows/common.py b/oonipipeline/src/oonipipeline/temporal/workflows/common.py new file mode 100644 index 00000000..9bf07f83 --- /dev/null +++ b/oonipipeline/src/oonipipeline/temporal/workflows/common.py @@ -0,0 +1,19 @@ +from datetime import datetime, timedelta + +from temporalio import workflow +from temporalio.common import SearchAttributeKey + + +def get_workflow_start_time() -> datetime: + workflow_start_time = workflow.info().typed_search_attributes.get( + SearchAttributeKey.for_datetime("TemporalScheduledStartTime") + ) + assert workflow_start_time is not None, "TemporalScheduledStartTime not set" + return workflow_start_time + + +# TODO(art): come up with a nicer way to nest workflows so we don't need such a high global timeout +MAKE_OBSERVATIONS_START_TO_CLOSE_TIMEOUT = timedelta(hours=48) +TASK_QUEUE_NAME = "oonipipeline-task-queue" +MAKE_GROUND_TRUTHS_START_TO_CLOSE_TIMEOUT = timedelta(hours=1) +MAKE_ANALYSIS_START_TO_CLOSE_TIMEOUT = timedelta(hours=10) diff --git a/oonipipeline/src/oonipipeline/temporal/workflows/ctrl.py b/oonipipeline/src/oonipipeline/temporal/workflows/ctrl.py new file mode 100644 index 00000000..1a8dbcaf --- /dev/null +++ b/oonipipeline/src/oonipipeline/temporal/workflows/ctrl.py @@ -0,0 +1,49 @@ +import asyncio +from dataclasses import dataclass + +from datetime import datetime, timedelta + +from temporalio import workflow + +with workflow.unsafe.imports_passed_through(): + from oonidata.dataclient import date_interval + from oonipipeline.temporal.activities.ground_truths import ( + MakeGroundTruthsParams, + make_ground_truths_in_day, + ) + from oonipipeline.temporal.workflows.common import ( + MAKE_GROUND_TRUTHS_START_TO_CLOSE_TIMEOUT, + ) + + +@dataclass +class GroundTruthsWorkflowParams: + start_day: str + end_day: str + clickhouse: str + data_dir: str + + +@workflow.defn +class GroundTruthsWorkflow: + @workflow.run + async def run( + self, + params: GroundTruthsWorkflowParams, + ): + start_day = datetime.strptime(params.start_day, "%Y-%m-%d").date() + end_day = datetime.strptime(params.end_day, "%Y-%m-%d").date() + + async with asyncio.TaskGroup() as tg: + for day in date_interval(start_day, end_day): + tg.create_task( + workflow.execute_activity( + make_ground_truths_in_day, + MakeGroundTruthsParams( + clickhouse=params.clickhouse, + data_dir=params.data_dir, + day=day.strftime("%Y-%m-%d"), + ), + start_to_close_timeout=MAKE_GROUND_TRUTHS_START_TO_CLOSE_TIMEOUT, + ) + ) diff --git a/oonipipeline/src/oonipipeline/temporal/workflows/observations.py b/oonipipeline/src/oonipipeline/temporal/workflows/observations.py new file mode 100644 index 00000000..d9482865 --- /dev/null +++ b/oonipipeline/src/oonipipeline/temporal/workflows/observations.py @@ -0,0 +1,114 @@ +import asyncio +from dataclasses import dataclass +from typing import List, Optional + +from datetime import timedelta + +from temporalio import workflow +from temporalio.common import RetryPolicy + +with workflow.unsafe.imports_passed_through(): + from oonidata.datautils import 
PerfTimer + from oonipipeline.temporal.activities.common import ( + OptimizeTablesParams, + optimize_tables, + ) + from oonipipeline.temporal.activities.observations import ( + DeletePreviousRangeParams, + GetPreviousRangeParams, + MakeObservationsParams, + delete_previous_range, + get_previous_range, + make_observations, + ) + from oonipipeline.temporal.workflows.common import ( + TASK_QUEUE_NAME, + get_workflow_start_time, + ) + + +@dataclass +class ObservationsWorkflowParams: + probe_cc: List[str] + test_name: List[str] + clickhouse: str + data_dir: str + fast_fail: bool + bucket_date: Optional[str] = None + + +@workflow.defn +class ObservationsWorkflow: + @workflow.run + async def run(self, params: ObservationsWorkflowParams) -> dict: + if params.bucket_date is None: + params.bucket_date = ( + get_workflow_start_time() - timedelta(days=1) + ).strftime("%Y-%m-%d") + + total_t = PerfTimer() + params_make_observations = MakeObservationsParams( + probe_cc=params.probe_cc, + test_name=params.test_name, + clickhouse=params.clickhouse, + data_dir=params.data_dir, + fast_fail=params.fast_fail, + bucket_date=params.bucket_date, + ) + + await workflow.execute_activity( + optimize_tables, + OptimizeTablesParams(clickhouse=params.clickhouse, table_names=["obs_web"]), + start_to_close_timeout=timedelta(minutes=20), + retry_policy=RetryPolicy(maximum_attempts=10), + ) + + previous_ranges = await workflow.execute_activity( + get_previous_range, + GetPreviousRangeParams( + clickhouse=params.clickhouse, + bucket_date=params.bucket_date, + test_name=params.test_name, + probe_cc=params.probe_cc, + tables=["obs_web"], + ), + start_to_close_timeout=timedelta(minutes=2), + retry_policy=RetryPolicy(maximum_attempts=4), + ) + workflow.logger.info( + f"finished get_previous_range for bucket_date={params.bucket_date}" + ) + + obs_res = await workflow.execute_activity( + make_observations, + params_make_observations, + start_to_close_timeout=timedelta(hours=48), + retry_policy=RetryPolicy(maximum_attempts=3), + ) + + workflow.logger.info( + f"finished make_observations for bucket_date={params.bucket_date} in " + f"{total_t.pretty} speed: {obs_res['mb_per_sec']}MB/s ({obs_res['measurement_per_sec']}msmt/s)" + ) + + workflow.logger.info( + f"finished optimize_tables for bucket_date={params.bucket_date}" + ) + + await workflow.execute_activity( + delete_previous_range, + DeletePreviousRangeParams( + clickhouse=params.clickhouse, + previous_ranges=previous_ranges, + ), + start_to_close_timeout=timedelta(minutes=10), + retry_policy=RetryPolicy(maximum_attempts=10), + ) + + return { + "measurement_count": obs_res["measurement_count"], + "size": obs_res["total_size"], + "mb_per_sec": obs_res["mb_per_sec"], + "bucket_date": params.bucket_date, + "measurement_per_sec": obs_res["measurement_per_sec"], + } diff --git a/oonipipeline/tests/test_cli.py b/oonipipeline/tests/test_cli.py index 5806fc7d..745753e5 100644 --- a/oonipipeline/tests/test_cli.py +++ b/oonipipeline/tests/test_cli.py @@ -5,7 +5,6 @@ import textwrap from oonipipeline.cli.commands import cli -from oonipipeline.cli.commands import parse_config_file from oonipipeline.temporal.client_operations import TemporalConfig, get_status import pytest @@ -38,6 +37,7 @@ def __init__(self): self.default_map = {} +@pytest.mark.skip("TODO(art): maybe test new settings parsing") def test_parse_config(tmp_path): ctx = MockContext() @@ -59,6 +59,7 @@ def test_parse_config(tmp_path): assert defaults["backfill"]["something"] == "other" +@pytest.mark.skip("TODO(art): moved into 
temporal_e2e") def test_full_workflow( db, cli_runner, @@ -259,6 +260,7 @@ def test_full_workflow( # We wait on the table buffers to be flushed wait_for_backfill() # assert len(list(tmp_path.glob("*.warc.gz"))) == 1 + db.execute("OPTIMIZE TABLE measurement_experiment_result") db.execute("OPTIMIZE TABLE buffer_measurement_experiment_result") wait_for_mutations(db, "measurement_experiment_result") diff --git a/oonipipeline/tests/test_ctrl.py b/oonipipeline/tests/test_ctrl.py index 44db3e71..a073397b 100644 --- a/oonipipeline/tests/test_ctrl.py +++ b/oonipipeline/tests/test_ctrl.py @@ -11,6 +11,7 @@ iter_web_ground_truths, ) from oonipipeline.temporal.activities.observations import ( + MakeObservationsFileEntryBatch, make_observations_for_file_entry_batch, ) @@ -59,7 +60,13 @@ def test_web_ground_truth_from_clickhouse(db, datadir, netinfodb, tmp_path): ) ] obs_msmt_count = make_observations_for_file_entry_batch( - file_entry_batch, db.clickhouse_url, 100, datadir, "2023-10-31", ["US"], False + file_entry_batch=file_entry_batch, + clickhouse=db.clickhouse_url, + write_batch_size=1, + data_dir=datadir, + bucket_date="2023-10-31", + probe_cc=["US"], + fast_fail=False, ) assert obs_msmt_count == 299 # Wait for buffers to flush diff --git a/oonipipeline/tests/test_temporal_e2e.py b/oonipipeline/tests/test_temporal_e2e.py new file mode 100644 index 00000000..f09cf1f2 --- /dev/null +++ b/oonipipeline/tests/test_temporal_e2e.py @@ -0,0 +1,117 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from oonipipeline.temporal.schedules import schedule_all, reschedule_all +import pytest + +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker + +from oonipipeline.temporal.workflows.common import TASK_QUEUE_NAME + +from oonipipeline.temporal.workflows.observations import ( + ObservationsWorkflow, + ObservationsWorkflowParams, +) +from oonipipeline.temporal.workers import ACTIVTIES + +from .utils import wait_for_mutations + + +@pytest.mark.asyncio +async def test_scheduling(datadir, db): + async with await WorkflowEnvironment.start_local() as env: + sched_res = await schedule_all( + client=env.client, + probe_cc=[], + test_name=[], + clickhouse_url=db.clickhouse_url, + data_dir=str(datadir), + ) + assert sched_res.analysis + assert sched_res.observations + + # Wait 1 second for the ID to change + await asyncio.sleep(1) + + reschedule_res = await reschedule_all( + client=env.client, + probe_cc=[], + test_name=[], + clickhouse_url=db.clickhouse_url, + data_dir=str(datadir), + ) + assert reschedule_res.observations != sched_res.observations + assert reschedule_res.analysis != sched_res.analysis + + +@pytest.mark.asyncio +async def test_observation_workflow(datadir, db): + obs_params = ObservationsWorkflowParams( + probe_cc=["BA"], + test_name=["web_connectivity"], + clickhouse=db.clickhouse_url, + data_dir=str(datadir.absolute()), + fast_fail=False, + bucket_date="2022-10-21", + ) + async with await WorkflowEnvironment.start_local() as env: + async with Worker( + env.client, + task_queue=TASK_QUEUE_NAME, + workflows=[ObservationsWorkflow], + activities=ACTIVTIES, + activity_executor=ThreadPoolExecutor(max_workers=4 + 2), + ): + wf_res = await env.client.execute_workflow( + ObservationsWorkflow.run, + obs_params, + id="obs-wf", + task_queue=TASK_QUEUE_NAME, + ) + db.execute("OPTIMIZE TABLE buffer_obs_web") + assert wf_res["measurement_count"] == 613 + assert wf_res["size"] == 11381440 + assert wf_res["bucket_date"] == "2022-10-21" + + res = db.execute( + """ + 
SELECT bucket_date, + COUNT(DISTINCT(measurement_uid)) + FROM obs_web WHERE probe_cc = 'BA' + GROUP BY bucket_date + """ + ) + bucket_dict = dict(res) + assert bucket_dict[wf_res["bucket_date"]] == wf_res["measurement_count"] + res = db.execute( + """ + SELECT bucket_date, + COUNT() + FROM obs_web WHERE probe_cc = 'BA' + GROUP BY bucket_date + """ + ) + bucket_dict = dict(res) + obs_count = bucket_dict[wf_res["bucket_date"]] + assert obs_count == 2548 + + wf_res = await env.client.execute_workflow( + ObservationsWorkflow.run, + obs_params, + id="obs-wf-2", + task_queue=TASK_QUEUE_NAME, + ) + db.execute("OPTIMIZE TABLE obs_web") + wait_for_mutations(db, "obs_web") + res = db.execute( + """ + SELECT bucket_date, + COUNT() + FROM obs_web WHERE probe_cc = 'BA' + GROUP BY bucket_date + """ + ) + bucket_dict = dict(res) + obs_count_2 = bucket_dict[wf_res["bucket_date"]] + + assert obs_count == obs_count_2 diff --git a/oonipipeline/tests/test_workflows.py b/oonipipeline/tests/test_workflows.py index 643a8a4c..df0534c9 100644 --- a/oonipipeline/tests/test_workflows.py +++ b/oonipipeline/tests/test_workflows.py @@ -4,7 +4,6 @@ import sqlite3 from typing import List, Tuple from unittest.mock import MagicMock -import time from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -20,12 +19,15 @@ from oonipipeline.temporal.activities.common import ( ClickhouseParams, - UpdateAssetsParams, + OptimizeTablesParams, get_obs_count_by_cc, ObsCountParams, ) from oonipipeline.temporal.activities.observations import ( + DeletePreviousRangeParams, + GetPreviousRangeParams, MakeObservationsParams, + MakeObservationsResult, make_observations_for_file_entry_batch, ) from oonipipeline.transforms.measurement_transformer import MeasurementTransformer @@ -36,6 +38,9 @@ make_cc_batches, ) from oonipipeline.temporal.common import ( + TS_FORMAT, + BatchParameters, + PrevRange, get_prev_range, maybe_delete_prev_range, ) @@ -43,27 +48,17 @@ MakeGroundTruthsParams, make_ground_truths_in_day, ) -from oonipipeline.temporal.workflows import ( +from oonipipeline.temporal.workflows.analysis import ( AnalysisWorkflowParams, - ObservationsWorkflow, AnalysisWorkflow, +) +from oonipipeline.temporal.workflows.observations import ( ObservationsWorkflowParams, - TASK_QUEUE_NAME, + ObservationsWorkflow, ) +from oonipipeline.temporal.workflows.common import TASK_QUEUE_NAME -# from oonipipeline.workflows.response_archiver import ResponseArchiver -# from oonipipeline.workflows.fingerprint_hunter import fingerprint_hunter - - -def wait_for_mutations(db, table_name): - while True: - res = db.execute( - f"SELECT * FROM system.mutations WHERE is_done=0 AND table='{table_name}';" - ) - if len(res) == 0: # type: ignore - break - time.sleep(1) - +from .utils import wait_for_mutations def test_get_prev_range(db): db.execute("DROP TABLE IF EXISTS test_range") @@ -97,8 +92,13 @@ def test_get_prev_range(db): probe_cc=[probe_cc], ) assert prev_range.min_created_at and prev_range.max_created_at - assert prev_range.min_created_at == (min_time - timedelta(seconds=1)) - assert prev_range.max_created_at == (rows[-1][0] + timedelta(seconds=1)) + assert prev_range.min_created_at == (min_time - timedelta(seconds=1)).strftime( + TS_FORMAT + ) + assert prev_range.max_created_at == (rows[-1][0] + timedelta(seconds=1)).strftime( + TS_FORMAT + ) + db.execute("TRUNCATE TABLE test_range") bucket_date = "2000-03-01" @@ -126,8 +126,12 @@ def test_get_prev_range(db): probe_cc=[probe_cc], ) assert prev_range.min_created_at and 
prev_range.max_created_at - assert prev_range.min_created_at == (min_time - timedelta(seconds=1)) - assert prev_range.max_created_at == (rows[-1][0] + timedelta(seconds=1)) + assert prev_range.min_created_at == (min_time - timedelta(seconds=1)).strftime( + TS_FORMAT + ) + assert prev_range.max_created_at == (rows[-1][0] + timedelta(seconds=1)).strftime( + TS_FORMAT + ) maybe_delete_prev_range( db=db, @@ -160,9 +164,18 @@ def test_make_file_entry_batch(datadir, db): ) ] obs_msmt_count = make_observations_for_file_entry_batch( - file_entry_batch, db.clickhouse_url, 100, datadir, "2023-10-31", ["IR"], False + file_entry_batch=file_entry_batch, + clickhouse=db.clickhouse_url, + write_batch_size=1, + data_dir=datadir, + bucket_date="2023-10-31", + probe_cc=["IR"], + fast_fail=False, ) + assert obs_msmt_count == 453 + # Flush buffer table + db.execute("OPTIMIZE TABLE buffer_obs_web") make_ground_truths_in_day( MakeGroundTruthsParams( day=date(2023, 10, 31).strftime("%Y-%m-%d"), @@ -170,7 +183,6 @@ def test_make_file_entry_batch(datadir, db): data_dir=datadir, ), ) - time.sleep(3) analysis_res = make_analysis_in_a_day( MakeAnalysisParams( probe_cc=["IR"], @@ -210,8 +222,8 @@ def test_write_observations(measurements, netinfodb, db): ): db.write_table_model_rows(obs_list) db.close() - # Wait for buffer tables to flush - time.sleep(3) + # Flush buffer table + db.execute("OPTIMIZE TABLE buffer_obs_web") cnt_by_cc = get_obs_count_by_cc( ObsCountParams( clickhouse_url=db.clickhouse_url, @@ -301,19 +313,14 @@ def test_full_processing(raw_measurements, netinfodb): ) -@activity.defn(name="update_assets") -async def update_assets_mocked(params: UpdateAssetsParams): - return - - @activity.defn(name="optimize_all_tables") async def optimize_all_tables_mocked(params: ClickhouseParams): return -@activity.defn(name="make_observation_in_day") -async def make_observation_in_day_mocked(params: MakeObservationsParams): - return {"size": 1000, "measurement_count": 42} +@activity.defn(name="optimize_tables") +async def optimize_tables_mocked(params: OptimizeTablesParams): + return @activity.defn(name="make_ground_truths_in_day") @@ -321,6 +328,30 @@ async def make_ground_truths_in_day_mocked(params: MakeGroundTruthsParams): return +@activity.defn(name="get_previous_range") +async def get_previous_range_mocked(params: GetPreviousRangeParams) -> List[PrevRange]: + return [ + PrevRange( + table_name="obs_web", + batch_parameters=BatchParameters( + test_name=[], + probe_cc=[], + bucket_date="2024-01-01", + timestamp=datetime(2024, 1, 1).strftime(TS_FORMAT), + ), + timestamp_column="timestamp", + probe_cc_column="probe_cc", + max_created_at=datetime(2024, 9, 1, 12, 34, 56).strftime(TS_FORMAT), + min_created_at=datetime(2024, 9, 1, 1, 23, 45).strftime(TS_FORMAT), + ) + ] + + +@activity.defn(name="delete_previous_range") +async def delete_previous_range_mocked(params: DeletePreviousRangeParams) -> None: + return + + @activity.defn(name="get_obs_count_by_cc") async def get_obs_count_by_cc_mocked(params: ObsCountParams): return { @@ -331,8 +362,20 @@ async def get_obs_count_by_cc_mocked(params: ObsCountParams): } +@activity.defn(name="make_observations") +async def make_observations_mocked( + params: MakeObservationsParams, +) -> MakeObservationsResult: + return { + "measurement_count": 100, + "measurement_per_sec": 3.0, + "mb_per_sec": 1.0, + "total_size": 2000, + } + + @activity.defn(name="make_analysis_in_a_day") -async def make_analysis_in_a_day_mocked(params: MakeAnalysisParams): +async def 
make_analysis_in_a_day_mocked(params: MakeAnalysisParams) -> dict: return {"count": 100} @@ -355,12 +398,14 @@ async def test_temporal_workflows(): task_queue=TASK_QUEUE_NAME, workflows=[ObservationsWorkflow, AnalysisWorkflow], activities=[ - update_assets_mocked, + optimize_tables_mocked, optimize_all_tables_mocked, - make_observation_in_day_mocked, make_ground_truths_in_day_mocked, get_obs_count_by_cc_mocked, make_analysis_in_a_day_mocked, + make_observations_mocked, + get_previous_range_mocked, + delete_previous_range_mocked, ], ): res = await env.client.execute_workflow( @@ -369,8 +414,8 @@ async def test_temporal_workflows(): id="obs-wf", task_queue=TASK_QUEUE_NAME, ) - assert res["size"] > 0 - assert res["measurement_count"] > 0 + assert res["size"] == 2000 + assert res["measurement_count"] == 100 assert res["bucket_date"] == "2024-01-02" res = await env.client.execute_workflow( @@ -383,8 +428,8 @@ async def test_temporal_workflows(): assert res["day"] == "2024-01-01" +@pytest.mark.skip(reason="TODO(art): fixme") def test_archive_http_transaction(measurements, tmpdir): - pytest.skip("TODO(art): fixme") db = MagicMock() db.write_row = MagicMock() @@ -423,8 +468,8 @@ def test_archive_http_transaction(measurements, tmpdir): assert res.fetchone()[0] == 1 +@pytest.mark.skip(reason="TODO(art): fixme") def test_fingerprint_hunter(fingerprintdb, measurements, tmpdir): - pytest.skip("TODO(art): fixme") db = MagicMock() db.write_rows = MagicMock() diff --git a/oonipipeline/tests/utils.py b/oonipipeline/tests/utils.py new file mode 100644 index 00000000..22ebc68a --- /dev/null +++ b/oonipipeline/tests/utils.py @@ -0,0 +1,15 @@ +# from oonipipeline.workflows.response_archiver import ResponseArchiver +# from oonipipeline.workflows.fingerprint_hunter import fingerprint_hunter + + +import time + + +def wait_for_mutations(db, table_name): + while True: + res = db.execute( + f"SELECT * FROM system.mutations WHERE is_done=0 AND table='{table_name}';" + ) + if len(res) == 0: # type: ignore + break + time.sleep(1)
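
For readers who want to drive the refactored scheduling path outside of the CLI, here is a minimal sketch of how the new helpers in oonipipeline.temporal.schedules are meant to be used together. The Temporal address, ClickHouse URL and data directory below are placeholder values for illustration, not taken from this patch.

import asyncio
from datetime import datetime, timezone

from temporalio.client import Client

from oonipipeline.temporal.schedules import schedule_all, schedule_backfill


async def main() -> None:
    # Placeholder endpoint: adjust for your environment.
    client = await Client.connect("localhost:7233")

    # Create (or reuse) the daily observations/analysis schedules for all
    # probe_cc and test_name values.
    sched = await schedule_all(
        client=client,
        probe_cc=[],
        test_name=[],
        clickhouse_url="clickhouse://localhost/",
        data_dir="/tmp/oonipipeline",
    )
    print(sched.observations, sched.analysis)

    # Backfill the observations schedule for the first week of January 2024.
    await schedule_backfill(
        client=client,
        workflow_name="observations",
        probe_cc=[],
        test_name=[],
        start_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
        end_at=datetime(2024, 1, 8, tzinfo=timezone.utc),
    )


if __name__ == "__main__":
    asyncio.run(main())

Note that schedule_backfill asserts that exactly one schedule exists for the chosen workflow, so schedule_all (or the reschedule command) has to have run for the same probe_cc/test_name filter first.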
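
Related to the PrevRange change in temporal/common.py: PrevRange values are produced by the get_previous_range activity and handed back to delete_previous_range, so they have to survive Temporal's JSON payload conversion. With the created_at bounds stored as TS_FORMAT strings, the dataclass round-trips through plain JSON. A small illustrative sketch; the field values are made up.

import json
from dataclasses import asdict

from oonipipeline.temporal.common import BatchParameters, PrevRange

prev_range = PrevRange(
    table_name="obs_web",
    batch_parameters=BatchParameters(
        test_name=["web_connectivity"],
        probe_cc=["IT"],
        bucket_date="2024-01-01",
        timestamp=None,
    ),
    timestamp_column="bucket_date",
    probe_cc_column="probe_cc",
    min_created_at="2024-01-02 11:59:59",
    max_created_at="2024-01-02 12:00:01",
)

# Every field is now a JSON-native type, so asdict() + json round-trips
# cleanly between workflow and activities.
data = json.loads(json.dumps(asdict(prev_range)))
restored = PrevRange(
    **{**data, "batch_parameters": BatchParameters(**data["batch_parameters"])}
)
assert restored == prev_range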