Skip to content

Commit

Permalink
Merge branch 'main' into wantsui/fix-should-skip
Browse files Browse the repository at this point in the history
  • Loading branch information
wantsui authored Nov 14, 2024
2 parents 75ca673 + bd0097f commit ebfb726
Show file tree
Hide file tree
Showing 12 changed files with 166 additions and 17 deletions.
3 changes: 2 additions & 1 deletion .circleci/config.templ.yml
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ commands:
RIOT_RUN_RECOMPILE_REQS: "<< pipeline.parameters.riot_run_latest >>"
DD_CIVISIBILITY_AGENTLESS_ENABLED: true
no_output_timeout: 5m
attempts: 2
command: |
ulimit -c unlimited
./scripts/run-test-suite '<<parameters.pattern>>' <<pipeline.parameters.coverage>> 1
Expand Down Expand Up @@ -521,7 +522,7 @@ jobs:

appsec_integrations:
<<: *machine_executor
parallelism: 7
parallelism: 13
steps:
- run_test:
pattern: 'appsec_integrations'
Expand Down
36 changes: 36 additions & 0 deletions .github/workflows/check_old_target_branch.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Check for Old Target Branch

# Runs on every pull request and blocks PRs that target branches too old
# to receive backports (major versions 0/1, or minor versions 2.0-2.12).
on:
  pull_request:

jobs:
  check_target_branch:
    name: "Check for old target branch"
    runs-on: ubuntu-latest
    permissions:
      # Required so the warning step can comment on the pull request.
      pull-requests: write
    steps:
      - name: Check if target branch is too old to backport
        id: check-branch
        run: |
          # Define regex for branches with major version 0 or 1, or versions from 2.0 to 2.12
          old_branch_regex="^(0|1)(\\.|$)|^2\\.([0-9]|1[0-2])(\\.|$)"
          target_branch="${{ github.event.pull_request.base.ref }}"
          # Export the result via GITHUB_ENV so later steps can gate on it.
          if [[ "$target_branch" =~ $old_branch_regex ]]; then
            echo "Old target branch detected: $target_branch"
            echo "old_branch=true" >> $GITHUB_ENV
          else
            echo "old_branch=false" >> $GITHUB_ENV
          fi
      - name: Old branch warning on PR
        if: env.old_branch == 'true'
        uses: thollander/actions-comment-pull-request@v2
        with:
          message: |
            🚫 **This target branch is too old or unsupported. Please update the target branch to continue.**
      - name: Fail the job if branch is old
        if: env.old_branch == 'true'
        run: exit 1
9 changes: 9 additions & 0 deletions ddtrace/contrib/celery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ def run(self):
.. py:data:: ddtrace.config.celery['distributed_tracing']
Whether or not to pass distributed tracing headers to Celery workers.
Note: this flag applies to both Celery workers and callers separately.
On the caller: enabling propagation causes the caller and worker to
share a single trace while disabling causes them to be separate.
On the worker: enabling propagation causes context to propagate across
tasks, such as when Task A queues work for Task B, or if Task A retries.
Disabling propagation causes each celery.run task to be in its own
separate trace.
Can also be enabled with the ``DD_CELERY_DISTRIBUTED_TRACING`` environment variable.
Expand Down
3 changes: 2 additions & 1 deletion ddtrace/contrib/internal/celery/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@


# Celery Context key
CTX_KEY = "__dd_task_span"
SPAN_KEY = "__dd_task_span"
CTX_KEY = "__dd_task_context"

# Span names
PRODUCER_ROOT_SPAN = "celery.apply"
Expand Down
24 changes: 19 additions & 5 deletions ddtrace/contrib/internal/celery/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from ddtrace.contrib import trace_utils
from ddtrace.contrib.internal.celery import constants as c
from ddtrace.contrib.internal.celery.utils import attach_span
from ddtrace.contrib.internal.celery.utils import attach_span_context
from ddtrace.contrib.internal.celery.utils import detach_span
from ddtrace.contrib.internal.celery.utils import retrieve_span
from ddtrace.contrib.internal.celery.utils import retrieve_span_context
from ddtrace.contrib.internal.celery.utils import retrieve_task_id
from ddtrace.contrib.internal.celery.utils import set_tags_from_context
from ddtrace.ext import SpanKind
Expand Down Expand Up @@ -43,6 +45,7 @@ def trace_prerun(*args, **kwargs):
return

request_headers = task.request.get("headers", {})
request_headers = request_headers or retrieve_span_context(task, task_id)
trace_utils.activate_distributed_headers(pin.tracer, int_config=config.celery, request_headers=request_headers)

# propagate the `Span` in the current task Context
Expand All @@ -65,6 +68,8 @@ def trace_prerun(*args, **kwargs):

span.set_tag(SPAN_MEASURED_KEY)
attach_span(task, task_id, span)
if config.celery["distributed_tracing"]:
attach_span_context(task, task_id, span)


def trace_postrun(*args, **kwargs):
Expand Down Expand Up @@ -110,6 +115,13 @@ def trace_before_publish(*args, **kwargs):
if pin is None:
return

# If Task A calls Task B, and Task A raises an exception, then Task B may have no parent when apply is called.
# In these cases, we don't use the "current context" of attached span/tracer, for context, we use
# the attached distributed context.
if config.celery["distributed_tracing"]:
request_headers = retrieve_span_context(task, task_id, is_publish=False)
trace_utils.activate_distributed_headers(pin.tracer, int_config=config.celery, request_headers=request_headers)

# apply some tags here because most of the data is not available
# in the task_after_publish signal
service = config.celery["producer_service_name"]
Expand Down Expand Up @@ -145,11 +157,13 @@ def trace_before_publish(*args, **kwargs):
trace_headers = {}
propagator.inject(span.context, trace_headers)

# put distributed trace headers where celery will propagate them
task_headers = kwargs.get("headers") or {}
task_headers.setdefault("headers", {})
task_headers["headers"].update(trace_headers)
kwargs["headers"] = task_headers
kwargs.setdefault("headers", {})

# This is a hack for other versions, such as https://github.com/celery/celery/issues/4875
# We always use the double ["headers"]["headers"] because it works both before and
# after the changes made in celery
kwargs["headers"].setdefault("headers", {})
kwargs["headers"]["headers"].update(trace_headers)


def trace_after_publish(*args, **kwargs):
Expand Down
38 changes: 34 additions & 4 deletions ddtrace/contrib/internal/celery/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@

from ddtrace._trace.span import Span
from ddtrace.contrib.trace_utils import set_flattened_tags
from ddtrace.propagation.http import HTTPPropagator

from .constants import CTX_KEY
from .constants import SPAN_KEY


propagator = HTTPPropagator


TAG_KEYS = frozenset(
Expand Down Expand Up @@ -66,6 +71,31 @@ def set_tags_from_context(span: Span, context: Dict[str, Any]) -> None:
set_flattened_tags(span, context_tags)


def attach_span_context(task, task_id, span, is_publish=False):
    """Store the distributed-tracing headers for ``span`` on the ``Task``
    instance so they can later be re-activated via ``retrieve_span_context``.
    """
    # Serialize the span's context into propagation headers.
    headers = {}
    propagator.inject(span.context, headers)

    # Lazily create the per-task storage dict on first use; celery will
    # propagate these headers to downstream work.
    storage = getattr(task, CTX_KEY, None)
    if storage is None:
        storage = {}
        setattr(task, CTX_KEY, storage)

    storage[(task_id, is_publish, "distributed_context")] = headers


def retrieve_span_context(task, task_id, is_publish=False):
    """Helper to retrieve the distributed tracing headers previously
    attached to a `Task` instance by `attach_span_context`.

    Returns the stored header dict, or ``None`` when no context was
    attached for this ``(task_id, is_publish)`` pair.
    """
    context_dict = getattr(task, CTX_KEY, None)
    if context_dict is None:
        return

    # DEV: See note in `attach_span` for key info
    return context_dict.get((task_id, is_publish, "distributed_context"))


def attach_span(task, task_id, span, is_publish=False):
"""Helper to propagate a `Span` for the given `Task` instance. This
function uses a `WeakValueDictionary` that stores a Datadog Span using
Expand All @@ -85,10 +115,10 @@ def attach_span(task, task_id, span, is_publish=False):
NOTE: We cannot test for this well yet, because we do not run a celery worker,
and cannot run `task.apply_async()`
"""
weak_dict = getattr(task, CTX_KEY, None)
weak_dict = getattr(task, SPAN_KEY, None)
if weak_dict is None:
weak_dict = WeakValueDictionary()
setattr(task, CTX_KEY, weak_dict)
setattr(task, SPAN_KEY, weak_dict)

weak_dict[(task_id, is_publish)] = span

Expand All @@ -97,7 +127,7 @@ def detach_span(task, task_id, is_publish=False):
"""Helper to remove a `Span` in a Celery task when it's propagated.
This function handles tasks where the `Span` is not attached.
"""
weak_dict = getattr(task, CTX_KEY, None)
weak_dict = getattr(task, SPAN_KEY, None)
if weak_dict is None:
return

Expand All @@ -112,7 +142,7 @@ def retrieve_span(task, task_id, is_publish=False):
"""Helper to retrieve an active `Span` stored in a `Task`
instance
"""
weak_dict = getattr(task, CTX_KEY, None)
weak_dict = getattr(task, SPAN_KEY, None)
if weak_dict is None:
return
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
fixes:
- |
celery: This fix resolves two issues with context propagation in celery:
1. Invalid span parentage when task A calls task B async and task A errors out, causing A's queuing of B, and B itself to not be parented under A.
2. Invalid context propagation from client to workers, and across retries, causing multiple traces instead of a single trace.
2 changes: 2 additions & 0 deletions riotfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2842,6 +2842,7 @@ def select_pys(min_version=MIN_PYTHON_VERSION, max_version=MAX_PYTHON_VERSION):
env={
"DD_PROFILING_ENABLE_ASSERTS": "1",
"DD_PROFILING__FORCE_LEGACY_EXPORTER": "1",
"CPUCOUNT": "12",
},
pkgs={
"gunicorn": latest,
Expand Down Expand Up @@ -2979,6 +2980,7 @@ def select_pys(min_version=MIN_PYTHON_VERSION, max_version=MAX_PYTHON_VERSION):
"DD_PROFILING_EXPORT_LIBDD_ENABLED": "1",
# Enable pytest v2 plugin to handle pytest-cpp items in the test suite
"_DD_CIVISIBILITY_USE_PYTEST_V2": "1",
"CPUCOUNT": "12",
},
pkgs={
"gunicorn": latest,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

from tests.utils import flaky


@pytest.mark.subprocess()
def test_ddtrace_iast_flask_patch():
Expand Down Expand Up @@ -146,6 +148,7 @@ def _uninstall_watchdog_and_reload():
del sys.modules["tests.appsec.iast.fixtures.entrypoint.views"]


@flaky(1736035200)
@pytest.mark.subprocess(check_logs=False)
def test_ddtrace_iast_flask_app_create_app_patch_all_enable_iast_propagation():
import dis
Expand Down
3 changes: 2 additions & 1 deletion tests/contrib/celery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@

@pytest.fixture(scope="session")
def celery_config():
return {"broker_url": BROKER_URL, "result_backend": BACKEND_URL}
return {"broker_url": BROKER_URL, "result_backend": BACKEND_URL, "task_concurrency": 5}


@pytest.fixture
def celery_worker_parameters():
    """Extra keyword arguments applied to the embedded test celery worker."""
    # See https://github.com/celery/celery/issues/3642#issuecomment-457773294
    params = {"perform_ping_check": False}
    params["concurrency"] = 2
    return params


Expand Down
35 changes: 35 additions & 0 deletions tests/contrib/celery/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,41 @@ def run(self):
assert span.get_tag("span.kind") == "consumer"
assert span.error == 0

def test_task_chain_same_trace(self):
    """Regression test: with distributed tracing enabled, a task that
    asynchronously queues another task (and then raises, triggering a
    retry) should emit one single trace covering both tasks and retries.
    """

    @self.app.task(max_retries=1, default_retry_delay=1)
    def fn_b(user, force_logout=False):
        raise ValueError("Foo")

    self.celery_worker.reload()  # Reload after each task or we get an unregistered error

    @self.app.task(bind=True, max_retries=1, autoretry_for=(Exception,), default_retry_delay=1)
    def fn_a(self, user, force_logout=False):
        # Queue fn_b, then fail so fn_a is retried once via autoretry_for.
        fn_b.apply_async(args=[user], kwargs={"force_logout": force_logout})
        raise ValueError("foo")

    self.celery_worker.reload()  # Reload after each task or we get an unregistered error

    traces = None
    try:
        with self.override_config("celery", dict(distributed_tracing=True)):
            t = fn_a.apply_async(args=["user"], kwargs={"force_logout": True})
            # We wait 10 seconds so all tasks finish. While it'd be nice to block
            # until all tasks complete, celery doesn't offer an option. Using get()
            # causes a deadlock, since in test-mode we only have one worker.
            import time

            time.sleep(10)
            t.get()
    except Exception:
        # fn_a is expected to raise; the assertions below only inspect the
        # traces that were emitted.
        pass

    traces = self.pop_traces()
    # The below tests we have 1 trace with 8 spans, which is the shape generated
    assert len(traces) > 0
    assert sum([1 for trace in traces for span in trace]) == 8
    # Every span across every emitted trace must share the same trace_id.
    trace_id = traces[0][0].trace_id
    assert all(trace_id == span.trace_id for trace in traces for span in trace)

@mock.patch("kombu.messaging.Producer.publish", mock.Mock(side_effect=ValueError))
def test_fn_task_apply_async_soft_exception(self):
# If the underlying library runs into an exception that doesn't crash the app
Expand Down
21 changes: 16 additions & 5 deletions tests/profiling_v2/test_gunicorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def debug_print(*args):

def _run_gunicorn(*args):
cmd = (
["ddtrace-run", "gunicorn", "--bind", "127.0.0.1:7643", "--chdir", os.path.dirname(__file__)]
["ddtrace-run", "gunicorn", "--bind", "127.0.0.1:7644", "--chdir", os.path.dirname(__file__)]
+ list(args)
+ ["tests.profiling.gunicorn-app:app"]
)
Expand Down Expand Up @@ -58,32 +58,43 @@ def _test_gunicorn(gunicorn, tmp_path, monkeypatch, *args):
filename = str(tmp_path / "gunicorn.pprof")
monkeypatch.setenv("DD_PROFILING_OUTPUT_PPROF", filename)

debug_print("Creating gunicorn workers")
# DEV: We only start 1 worker to simplify the test
proc = gunicorn("-w", "1", *args)
# Wait for the workers to start
time.sleep(3)
time.sleep(5)

if proc.poll() is not None:
pytest.fail("Gunicorn failed to start")

debug_print("Making request to gunicorn server")
try:
with urllib.request.urlopen("http://127.0.0.1:7643") as f:
with urllib.request.urlopen("http://127.0.0.1:7644", timeout=5) as f:
status_code = f.getcode()
assert status_code == 200, status_code
response = f.read().decode()
debug_print(response)

except Exception as e:
pytest.fail("Failed to make request to gunicorn server %s" % e)
finally:
# Need to terminate the process to get the output and release the port
proc.terminate()

debug_print("Reading gunicorn worker output to get PIDs")
output = proc.stdout.read().decode()
worker_pids = _get_worker_pids(output)
debug_print("Gunicorn worker PIDs: %s" % worker_pids)

for line in output.splitlines():
debug_print(line)

assert len(worker_pids) == 1, output
assert proc.wait() == 0, output

debug_print("Waiting for gunicorn process to terminate")
try:
assert proc.wait(timeout=5) == 0, output
except subprocess.TimeoutExpired:
pytest.fail("Failed to terminate gunicorn process ", output)
assert "module 'threading' has no attribute '_active'" not in output, output

for pid in worker_pids:
Expand Down

0 comments on commit ebfb726

Please sign in to comment.