Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[schematic-230] propagate user ids as attribute to spans #1568

Merged
merged 12 commits into from
Jan 17, 2025
59 changes: 46 additions & 13 deletions schematic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
SERVICE_VERSION,
Resource,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace import TracerProvider, SpanProcessor
from opentelemetry.trace import Span, SpanContext, get_current_span
from opentelemetry.sdk.trace.export import BatchSpanProcessor, Span
from opentelemetry.sdk.trace.sampling import ALWAYS_OFF
from synapseclient import Synapse
Expand All @@ -26,7 +27,6 @@
from schematic.configuration.configuration import CONFIG
from schematic.loader import LOADER
from schematic.version import __version__
from schematic_api.api.security_controller import info_from_bearer_auth
from dotenv import load_dotenv

Synapse.allow_client_caching(False)
Expand All @@ -36,6 +36,47 @@
load_dotenv()


class AttributePropagatingSpanProcessor(SpanProcessor):
def __init__(self, attributes_to_propagate) -> None:
self.attributes_to_propagate = attributes_to_propagate

def on_start(self, span: Span, parent_context: SpanContext) -> None:
"""Propagates attributes from the parent span to the child span.
Arguments:
span: The child span to which the attributes should be propagated.
parent_context: The context of the parent span.
Returns:
None
"""
parent_span = get_current_span()
if parent_span is not None and parent_span.is_recording():
for attribute in self.attributes_to_propagate:
# Check if the attribute exists in the parent span's attributes
attribute_val = parent_span.attributes.get(attribute)
if attribute_val:
# Propagate the attribute to the current span
span.set_attribute(attribute, attribute_val)

def on_end(self, span: Span) -> None:
"""Propagates attributes from the child span back to the parent span"""
parent_span = get_current_span()
if parent_span is not None and parent_span.is_recording():
for attribute in self.attributes_to_propagate:
child_val = span.attributes.get(attribute)
parent_val = parent_span.attributes.get(attribute)
if child_val and not parent_val:
# Propagate the attribute back to parent span
parent_span.set_attribute(attribute, child_val)

def shutdown(self) -> None:
"""No-op method that does nothing when the span processor is shut down."""
pass

def force_flush(self, timeout_millis: int = 30000) -> None:
"""No-op method that does nothing when the span processor is forced to flush."""
pass


def create_telemetry_session() -> requests.Session:
"""
Create a requests session with authorization enabled if environment variables are set.
Expand Down Expand Up @@ -96,6 +137,9 @@ def set_up_tracing(session: requests.Session) -> None:
)

if tracing_export == "otlp":
# Add the custom AttributePropagatingSpanProcessor to propagate attributes
attribute_propagator = AttributePropagatingSpanProcessor(["user.id"])
trace.get_tracer_provider().add_span_processor(attribute_propagator)
exporter = OTLPSpanExporter(session=session)
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(exporter))
else:
Expand Down Expand Up @@ -139,17 +183,6 @@ def request_hook(span: Span, environ: Dict) -> None:
"""
if not span or not span.is_recording():
return
try:
if auth_header := environ.get("HTTP_AUTHORIZATION", None):
split_headers = auth_header.split(" ")
if len(split_headers) > 1:
token = auth_header.split(" ")[1]
user_info = info_from_bearer_auth(token)
if user_info:
span.set_attribute("user.id", user_info.get("sub"))
except Exception:
linglp marked this conversation as resolved.
Show resolved Hide resolved
logger.exception("Failed to set user info in span")

try:
if (request := environ.get("werkzeug.request", None)) and isinstance(
request, Request
Expand Down
17 changes: 7 additions & 10 deletions schematic/store/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
from schematic.utils.schema_utils import get_class_label_from_display_name
from schematic.utils.validate_utils import comma_separated_list_regex, rule_in_rule_list


logger = logging.getLogger("Synapse storage")

tracer = trace.get_tracer("Schematic")
Expand Down Expand Up @@ -321,9 +322,6 @@ def __init__(
Consider necessity of adding "columns" and "where_clauses" params to the constructor. Currently with how `query_fileview` is implemented, these params are not needed at this step but could be useful in the future if the need for more scoped querys expands.
"""
self.syn = self.login(synapse_cache_path, access_token)
current_span = trace.get_current_span()
if current_span.is_recording():
current_span.set_attribute("user.id", self.syn.credentials.owner_id)
self.project_scope = project_scope
self.storageFileview = CONFIG.synapse_master_fileview_id
self.manifest = CONFIG.synapse_manifest_basename
Expand Down Expand Up @@ -499,7 +497,6 @@ def login(
Returns:
synapseclient.Synapse: A Synapse object that is logged in
"""
# If no token is provided, try retrieving access token from environment
if not access_token:
access_token = os.getenv("SYNAPSE_ACCESS_TOKEN")

Expand All @@ -513,9 +510,6 @@ def login(
cache_client=False,
)
syn.login(authToken=access_token, silent=True)
current_span = trace.get_current_span()
if current_span.is_recording():
current_span.set_attribute("user.id", syn.credentials.owner_id)
except SynapseHTTPError as exc:
raise ValueError(
"No access to resources. Please make sure that your token is correct"
Expand All @@ -530,9 +524,12 @@ def login(
cache_client=False,
)
syn.login(silent=True)
current_span = trace.get_current_span()
if current_span.is_recording():
current_span.set_attribute("user.id", syn.credentials.owner_id)

# set user id attribute
current_span = trace.get_current_span()
if current_span.is_recording():
current_span.set_attribute("user.id", syn.credentials.owner_id)

return syn

def missing_entity_handler(method):
Expand Down
2 changes: 1 addition & 1 deletion schematic_api/api/openapi/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ components:
type: http
scheme: bearer
bearerFormat: JWT
x-bearerInfoFunc: schematic_api.api.security_controller.info_from_bearer_auth
x-bearerInfoFunc: schematic_api.security_controller.info_from_bearer_auth

# TO DO: refactor query parameters and remove access_token
paths:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"security controller"
# pylint: disable=line-too-long
import logging
from typing import Dict, Union

from jwt import PyJWKClient, decode
from jwt.exceptions import PyJWTError
from synapseclient import Synapse
from synapseclient import Synapse # type: ignore

from schematic.configuration.configuration import CONFIG

Expand All @@ -15,11 +17,12 @@
skip_checks=True,
)
jwks_client = PyJWKClient(
uri=syn.authEndpoint + "/oauth2/jwks", headers=syn._generate_headers()
uri=syn.authEndpoint + "/oauth2/jwks",
headers=syn._generate_headers(), # pylint: disable=W0212
)


def info_from_bearer_auth(token: str) -> Dict[str, Union[str, int]]:
def info_from_bearer_auth(token: str) -> Union[Dict[str, Union[str, int]], None]:
"""
Authenticate user using bearer token. The token claims are decoded and returned.

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_security_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from cryptography.hazmat.primitives.asymmetric import rsa
from pytest import LogCaptureFixture

from schematic_api.api.security_controller import info_from_bearer_auth
from schematic.schematic_api.security_controller import info_from_bearer_auth


class TestSecurityController:
Expand Down
Loading