diff --git a/ooniapi/services/ooniprobe/src/ooniprobe/main.py b/ooniapi/services/ooniprobe/src/ooniprobe/main.py index 6b347358..38607aff 100644 --- a/ooniapi/services/ooniprobe/src/ooniprobe/main.py +++ b/ooniapi/services/ooniprobe/src/ooniprobe/main.py @@ -95,7 +95,5 @@ async def health( @app.get("/") async def root(): # TODO(art): fix this redirect by pointing health monitoring to /health - #return RedirectResponse("/docs") - return { - "msg": "hello from ooniprobe" - } + # return RedirectResponse("/docs") + return {"msg": "hello from ooniprobe"} diff --git a/ooniapi/services/ooniprobe/src/ooniprobe/routers/v1/probe_services.py b/ooniapi/services/ooniprobe/src/ooniprobe/routers/v1/probe_services.py index 0dbafcfa..376fb8bf 100644 --- a/ooniapi/services/ooniprobe/src/ooniprobe/routers/v1/probe_services.py +++ b/ooniapi/services/ooniprobe/src/ooniprobe/routers/v1/probe_services.py @@ -15,22 +15,25 @@ log = logging.getLogger(__name__) + class ProbeLogin(BaseModel): - # Allow None username and password + # Allow None username and password # to deliver informational 401 error when they're missing - username : str | None = None + username: str | None = None # not actually used but necessary to be compliant with the old API schema - password : str | None = None + password: str | None = None + class ProbeLoginResponse(BaseModel): - token : str - expire : str + token: str + expire: str + @router.post("/login", tags=["ooniprobe"], response_model=ProbeLoginResponse) def probe_login_post( - probe_login : ProbeLogin, - response : Response, - settings : Settings = Depends(get_settings), + probe_login: ProbeLogin, + response: Response, + settings: Settings = Depends(get_settings), ) -> ProbeLoginResponse: if probe_login.username is None or probe_login.password is None: @@ -38,7 +41,7 @@ def probe_login_post( token = probe_login.username # TODO: We have to find a way to explicitly log metrics with prometheus. - # We're currently using the instrumentator default metrics, like http response counts + # We're currently using the instrumentator default metrics, like http response counts # Maybe using the same exporter as the instrumentator? try: dec = decode_jwt(token, audience="probe_login", key=settings.jwt_encryption_key) @@ -48,11 +51,11 @@ def probe_login_post( except jwt.exceptions.MissingRequiredClaimError: log.info("probe login: invalid or missing claim") # metrics.incr("probe_login_failed") - raise HTTPException(status_code=401, detail="Invalid credentials") + raise HTTPException(status_code=401, detail="Invalid credentials") except jwt.exceptions.InvalidSignatureError: log.info("probe login: invalid signature") # metrics.incr("probe_login_failed") - raise HTTPException(status_code=401, detail="Invalid credentials") + raise HTTPException(status_code=401, detail="Invalid credentials") except jwt.exceptions.DecodeError: # Not a JWT token: treat it as a "legacy" login # return jerror("Invalid or missing credentials", code=401) @@ -65,30 +68,33 @@ def probe_login_post( token = create_jwt(payload, key=settings.jwt_encryption_key) # expiration string used by the probe e.g. 2006-01-02T15:04:05Z expire = exp.strftime("%Y-%m-%dT%H:%M:%SZ") - login_response = ProbeLoginResponse(token=token, expire = expire) + login_response = ProbeLoginResponse(token=token, expire=expire) setnocacheresponse(response) return login_response + class ProbeRegister(BaseModel): # None of this values is actually used, but I add them # to keep it compliant with the old api - password : str - platform : str - probe_asn : str - probe_cc : str - software_name : str - software_version : str - supported_tests : List[str] + password: str + platform: str + probe_asn: str + probe_cc: str + software_name: str + software_version: str + supported_tests: List[str] + class ProbeRegisterResponse(BaseModel): - client_id: str + client_id: str + @router.post("/register", tags=["ooniprobe"], response_model=ProbeRegisterResponse) def probe_register_post( - probe_register : ProbeRegister, - response : Response, - settings : Settings = Depends(get_settings), + probe_register: ProbeRegister, + response: Response, + settings: Settings = Depends(get_settings), ) -> ProbeRegisterResponse: """Probe Services: Register @@ -96,7 +102,7 @@ def probe_register_post( The client_id/password tuple is saved by the probe and long-lived - Note that most of the request body arguments are not actually + Note that most of the request body arguments are not actually used but are kept here to use the same API as the old version """ @@ -112,13 +118,16 @@ def probe_register_post( return register_response + class ProbeUpdate(BaseModel): pass + class ProbeUpdateResponse(BaseModel): - status : str + status: str + @router.put("/update/{client_id}", tags=["ooniprobe"]) -def probe_update_post(probe_update : ProbeUpdate) -> ProbeUpdateResponse: +def probe_update_post(probe_update: ProbeUpdate) -> ProbeUpdateResponse: log.info("update successful") - return ProbeUpdateResponse(status="ok") \ No newline at end of file + return ProbeUpdateResponse(status="ok") diff --git a/ooniapi/services/ooniprobe/src/ooniprobe/routers/v2/vpn.py b/ooniapi/services/ooniprobe/src/ooniprobe/routers/v2/vpn.py index f9fd79bf..f1c94bd3 100644 --- a/ooniapi/services/ooniprobe/src/ooniprobe/routers/v2/vpn.py +++ b/ooniapi/services/ooniprobe/src/ooniprobe/routers/v2/vpn.py @@ -149,4 +149,4 @@ def get_vpn_config( # Pick 4 random endpoints to serve to the client endpoints=random.sample(endpoints, min(len(endpoints), 4)), date_updated=provider.date_updated.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), - ) \ No newline at end of file + ) diff --git a/ooniapi/services/ooniprobe/src/ooniprobe/utils.py b/ooniapi/services/ooniprobe/src/ooniprobe/utils.py index 532633a3..8d7b920c 100644 --- a/ooniapi/services/ooniprobe/src/ooniprobe/utils.py +++ b/ooniapi/services/ooniprobe/src/ooniprobe/utils.py @@ -3,6 +3,7 @@ Insert VPN credentials into database. """ + import base64 from datetime import datetime, timezone import itertools @@ -27,11 +28,13 @@ class OpenVPNConfig(TypedDict): cert: str key: str + class OpenVPNEndpoint(TypedDict): address: str protocol: str transport: str + def fetch_riseup_ca() -> str: r = httpx.get(RISEUP_CA_URL) r.raise_for_status() @@ -50,6 +53,7 @@ def fetch_openvpn_config() -> OpenVPNConfig: key, cert = pem.parse(pem_cert) return OpenVPNConfig(ca=ca, cert=cert.as_text(), key=key.as_text()) + def fetch_openvpn_endpoints() -> List[OpenVPNEndpoint]: endpoints = [] @@ -59,26 +63,33 @@ def fetch_openvpn_endpoints() -> List[OpenVPNEndpoint]: for ep in j["gateways"]: ip = ep["ip_address"] # TODO(art): do we want to store this metadata somewhere? - #location = ep["location"] - #hostname = ep["host"] + # location = ep["location"] + # hostname = ep["host"] for t in ep["capabilities"]["transport"]: if t["type"] != "openvpn": continue for transport, port in itertools.product(t["protocols"], t["ports"]): - endpoints.append(OpenVPNEndpoint( - address=f"{ip}:{port}", - protocol="openvpn", - transport=transport - )) + endpoints.append( + OpenVPNEndpoint( + address=f"{ip}:{port}", protocol="openvpn", transport=transport + ) + ) return endpoints + def format_endpoint(provider_name: str, ep: OONIProbeVPNProviderEndpoint) -> str: return f"{ep.protocol}://{provider_name}.corp/?address={ep.address}&transport={ep.transport}" -def upsert_endpoints(db: Session, new_endpoints: List[OpenVPNEndpoint], provider: OONIProbeVPNProvider): - new_endpoints_map = {f'{ep["address"]}-{ep["protocol"]}-{ep["transport"]}': ep for ep in new_endpoints} + +def upsert_endpoints( + db: Session, new_endpoints: List[OpenVPNEndpoint], provider: OONIProbeVPNProvider +): + new_endpoints_map = { + f'{ep["address"]}-{ep["protocol"]}-{ep["transport"]}': ep + for ep in new_endpoints + } for endpoint in provider.endpoints: - key = f'{endpoint.address}-{endpoint.protocol}-{endpoint.transport}' + key = f"{endpoint.address}-{endpoint.protocol}-{endpoint.transport}" if key in new_endpoints_map: endpoint.date_updated = datetime.now(timezone.utc) new_endpoints_map.pop(key) @@ -86,11 +97,13 @@ def upsert_endpoints(db: Session, new_endpoints: List[OpenVPNEndpoint], provider db.delete(endpoint) for ep in new_endpoints_map.values(): - db.add(OONIProbeVPNProviderEndpoint( - date_created=datetime.now(timezone.utc), - date_updated=datetime.now(timezone.utc), - protocol=ep["protocol"], - address=ep["address"], - transport=ep["transport"], - provider=provider - )) \ No newline at end of file + db.add( + OONIProbeVPNProviderEndpoint( + date_created=datetime.now(timezone.utc), + date_updated=datetime.now(timezone.utc), + protocol=ep["protocol"], + address=ep["address"], + transport=ep["transport"], + provider=provider, + ) + ) diff --git a/ooniapi/services/ooniprobe/tests/conftest.py b/ooniapi/services/ooniprobe/tests/conftest.py index bb4def25..f9a36fe3 100644 --- a/ooniapi/services/ooniprobe/tests/conftest.py +++ b/ooniapi/services/ooniprobe/tests/conftest.py @@ -67,7 +67,10 @@ def client_with_bad_settings(): client = TestClient(app) yield client + JWT_ENCRYPTION_KEY = "super_secure" + + @pytest.fixture def client(alembic_migration): app.dependency_overrides[get_settings] = make_override_get_settings( @@ -79,6 +82,7 @@ def client(alembic_migration): client = TestClient(app) yield client + @pytest.fixture def jwt_encryption_key(): - return JWT_ENCRYPTION_KEY \ No newline at end of file + return JWT_ENCRYPTION_KEY diff --git a/ooniapi/services/ooniprobe/tests/test_models.py b/ooniapi/services/ooniprobe/tests/test_models.py index 092238b6..89375e53 100644 --- a/ooniapi/services/ooniprobe/tests/test_models.py +++ b/ooniapi/services/ooniprobe/tests/test_models.py @@ -8,6 +8,7 @@ "openvpn://riseup.corp/?address=51.15.187.53:1194&transport=udp", ] + def test_create_providers(db, alembic_migration): provider = OONIProbeVPNProvider( provider_name="riseupvpn", @@ -15,33 +16,39 @@ def test_create_providers(db, alembic_migration): date_updated=datetime.now(timezone.utc), openvpn_cert="OPENVPN_CERT", openvpn_ca="OPENVPN_CA", - openvpn_key="OPENVPN_KEY" + openvpn_key="OPENVPN_KEY", ) db.add(provider) - db.add(OONIProbeVPNProviderEndpoint( - date_created=datetime.now(timezone.utc)-timedelta(hours=1), - date_updated=datetime.now(timezone.utc)-timedelta(hours=1), - protocol="openvpn", - address="51.15.187.53:1194", - transport="tcp", - provider=provider - )) - db.add(OONIProbeVPNProviderEndpoint( - date_created=datetime.now(timezone.utc)-timedelta(hours=1), - date_updated=datetime.now(timezone.utc)-timedelta(hours=1), - protocol="openvpn", - address="51.15.187.53:1194", - transport="udp", - provider=provider - )) - db.add(OONIProbeVPNProviderEndpoint( - date_created=datetime.now(timezone.utc)-timedelta(hours=1), - date_updated=datetime.now(timezone.utc)-timedelta(hours=1), - protocol="openvpn", - address="1.1.1.1:1194", - transport="udp", - provider=provider - )) + db.add( + OONIProbeVPNProviderEndpoint( + date_created=datetime.now(timezone.utc) - timedelta(hours=1), + date_updated=datetime.now(timezone.utc) - timedelta(hours=1), + protocol="openvpn", + address="51.15.187.53:1194", + transport="tcp", + provider=provider, + ) + ) + db.add( + OONIProbeVPNProviderEndpoint( + date_created=datetime.now(timezone.utc) - timedelta(hours=1), + date_updated=datetime.now(timezone.utc) - timedelta(hours=1), + protocol="openvpn", + address="51.15.187.53:1194", + transport="udp", + provider=provider, + ) + ) + db.add( + OONIProbeVPNProviderEndpoint( + date_created=datetime.now(timezone.utc) - timedelta(hours=1), + date_updated=datetime.now(timezone.utc) - timedelta(hours=1), + protocol="openvpn", + address="1.1.1.1:1194", + transport="udp", + provider=provider, + ) + ) db.commit() all_endpoints = db.query(OONIProbeVPNProviderEndpoint).all() @@ -55,30 +62,25 @@ def test_create_providers(db, alembic_migration): assert endpoint.provider.provider_name == "riseupvpn" assert addresses == set(["51.15.187.53:1194", "1.1.1.1:1194"]) - provider = db.query(OONIProbeVPNProvider).filter( - OONIProbeVPNProvider.provider_name == "riseupvpn", - OONIProbeVPNProvider.date_updated - > datetime.now(timezone.utc) - - timedelta(days=7), - ).one() + provider = ( + db.query(OONIProbeVPNProvider) + .filter( + OONIProbeVPNProvider.provider_name == "riseupvpn", + OONIProbeVPNProvider.date_updated + > datetime.now(timezone.utc) - timedelta(days=7), + ) + .one() + ) assert len(provider.endpoints) == 3 new_endpoints = [ OpenVPNEndpoint( - address="51.15.187.53:1194", - protocol="openvpn", - transport="udp" - ), - OpenVPNEndpoint( - address="51.15.187.53:1194", - protocol="openvpn", - transport="tcp" + address="51.15.187.53:1194", protocol="openvpn", transport="udp" ), OpenVPNEndpoint( - address="3.2.1.3:1194", - protocol="openvpn", - transport="udp" + address="51.15.187.53:1194", protocol="openvpn", transport="tcp" ), + OpenVPNEndpoint(address="3.2.1.3:1194", protocol="openvpn", transport="udp"), ] upsert_endpoints(db, new_endpoints, provider) diff --git a/ooniapi/services/ooniprobe/tests/test_probe_auth.py b/ooniapi/services/ooniprobe/tests/test_probe_auth.py index 65d3b419..8fe65ac3 100644 --- a/ooniapi/services/ooniprobe/tests/test_probe_auth.py +++ b/ooniapi/services/ooniprobe/tests/test_probe_auth.py @@ -15,11 +15,13 @@ def test_register_then_login(client, jwt_encryption_key): assert "client_id" in c assert len(c["client_id"]) == 132 - tok = auth.decode_jwt(c["client_id"], audience="probe_login", key = jwt_encryption_key) + tok = auth.decode_jwt( + c["client_id"], audience="probe_login", key=jwt_encryption_key + ) client_id = c["client_id"] c = postj(client, "/api/v1/login", username=client_id, password=pwd) - tok = auth.decode_jwt(c["token"], audience="probe_token", key = jwt_encryption_key) + tok = auth.decode_jwt(c["token"], audience="probe_token", key=jwt_encryption_key) assert tok["registration_time"] is not None # Login with a bogus client id emulating probes before 2022 @@ -28,54 +30,57 @@ def test_register_then_login(client, jwt_encryption_key): r = client.post("/api/v1/login", json=j) assert r.status_code == 200 token = r.json()["token"] - tok = auth.decode_jwt(token, audience="probe_token", key = jwt_encryption_key) + tok = auth.decode_jwt(token, audience="probe_token", key=jwt_encryption_key) assert tok["registration_time"] is None # we don't know the reg. time # Expect failed login resp = client.post("/api/v1/login", json=dict()) assert resp.status_code == 401 -def test_update(client : TestClient, jwt_encryption_key): + +def test_update(client: TestClient, jwt_encryption_key): # Update will just say ok to anything you send, no matter - # the data + # the data - # We can use whatever string for the client_id path parameter, + # We can use whatever string for the client_id path parameter, # but we use the login token to make sure that the returned token # works properly with the update endpoint c = _register(client) client_id = c["client_id"] c = postj(client, "/api/v1/login", username=client_id, password="some_pswd") - token = c['token'] + token = c["token"] data = _get_update_data() - resp = client.put(f"/api/v1/update/{token}",json=data) + resp = client.put(f"/api/v1/update/{token}", json=data) assert resp.status_code == 200 json = resp.json() assert "status" in json - assert json['status'] == "ok" + assert json["status"] == "ok" def _get_update_data() -> Dict[str, str]: return { - "probe_cc": "IT", - "probe_asn": "AS1234", - "platform": "android", - "software_name": "ooni-testing", - "software_version": "0.0.1", - "supported_tests": "web_connectivity", - "network_type": "wifi", - "available_bandwidth": "100", - "language": "en", - "token": "XXXX-TESTING", - "password": "testingPassword", - } + "probe_cc": "IT", + "probe_asn": "AS1234", + "platform": "android", + "software_name": "ooni-testing", + "software_version": "0.0.1", + "supported_tests": "web_connectivity", + "network_type": "wifi", + "available_bandwidth": "100", + "language": "en", + "token": "XXXX-TESTING", + "password": "testingPassword", + } + def postj(client, url, **kw): response = client.post(url, json=kw) assert response.status_code == 200, f"Error: {response.content}" - assert response.headers.get('content-type') == 'application/json' + assert response.headers.get("content-type") == "application/json" return response.json() + def _register(client): pwd = "HLdywVhzVCNqLvHCfmnMhIXqGmUFMTuYjmuGZhNlRTeIyvxeQTnjVJsiRkutHCSw" j = { @@ -87,4 +92,4 @@ def _register(client): "software_version": "0.1.0-dev", "supported_tests": ["web_connectivity"], } - return postj(client, "/api/v1/register", **j) \ No newline at end of file + return postj(client, "/api/v1/register", **j)