diff --git a/src/promptflow/promptflow/_core/tracer.py b/src/promptflow/promptflow/_core/tracer.py index 5322cb6617f..51ae4533a1e 100644 --- a/src/promptflow/promptflow/_core/tracer.py +++ b/src/promptflow/promptflow/_core/tracer.py @@ -20,6 +20,7 @@ from promptflow._core.operation_context import OperationContext from promptflow._utils.dataclass_serializer import serialize from promptflow._utils.multimedia_utils import default_json_encoder +from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func from promptflow.contracts.tool import ConnectionType from promptflow.contracts.trace import Trace, TraceType @@ -263,6 +264,22 @@ def enrich_span_with_trace(span, trace): logging.warning(f"Failed to enrich span with trace: {e}") +def enrich_span_with_prompt_info(span, func, kwargs): + try: + # Assume there is at most one prompt template parameter in the function; + # if there are multiple, we use the first one by default. + prompt_tpl_param_name = get_prompt_param_name_from_func(func) + if prompt_tpl_param_name is not None: + prompt_tpl = kwargs.get(prompt_tpl_param_name) + prompt_vars = { + key: kwargs.get(key) for key in get_inputs_for_prompt_template(prompt_tpl) if key in kwargs + } + prompt_info = {"prompt.template": prompt_tpl, "prompt.variables": serialize_attribute(prompt_vars)} + span.set_attributes(prompt_info) + except Exception as e: + logging.warning(f"Failed to enrich span with prompt info: {e}") + + def enrich_span_with_input(span, input): try: serialized_input = serialize_attribute(input) @@ -350,6 +367,7 @@ async def wrapped(*args, **kwargs): span_name = get_node_name_from_context() if trace_type == TraceType.TOOL else trace.name with open_telemetry_tracer.start_as_current_span(span_name) as span: enrich_span_with_trace(span, trace) + enrich_span_with_prompt_info(span, func, kwargs) # Should not extract these codes to a separate function here.
# We directly call func instead of calling Tracer.invoke, @@ -400,6 +418,7 @@ def wrapped(*args, **kwargs): span_name = get_node_name_from_context() if trace_type == TraceType.TOOL else trace.name with open_telemetry_tracer.start_as_current_span(span_name) as span: enrich_span_with_trace(span, trace) + enrich_span_with_prompt_info(span, func, kwargs) # Should not extract these codes to a separate function here. # We directly call func instead of calling Tracer.invoke, diff --git a/src/promptflow/tests/executor/e2etests/test_traces.py b/src/promptflow/tests/executor/e2etests/test_traces.py index d8dc891b06c..30b47dcc1ee 100644 --- a/src/promptflow/tests/executor/e2etests/test_traces.py +++ b/src/promptflow/tests/executor/e2etests/test_traces.py @@ -8,12 +8,13 @@ from promptflow._core.tracer import TraceType, trace from promptflow._utils.dataclass_serializer import serialize +from promptflow._utils.tool_utils import get_inputs_for_prompt_template from promptflow.contracts.run_info import Status from promptflow.executor import FlowExecutor from promptflow.executor._result import LineResult from ..process_utils import execute_function_in_subprocess -from ..utils import get_flow_sample_inputs, get_yaml_file, prepare_memory_exporter +from ..utils import get_flow_folder, get_flow_sample_inputs, get_yaml_file, prepare_memory_exporter, load_content OPEN_AI_FUNCTION_NAMES = [ "openai.resources.chat.completions.Completions.create", @@ -28,6 +29,11 @@ "__computed__.cumulative_token_count.total", ] +SHOULD_INCLUDE_PROMPT_FUNCTION_NAMES = [ + "render_template_jinja2", + "AzureOpenAI.chat", +] + def get_chat_input(stream): return { @@ -380,6 +386,48 @@ def validate_openai_tokens(self, span_list): for token_name in TOKEN_NAMES: assert span.attributes[token_name] == token_dict[span_id][token_name] + @pytest.mark.parametrize( + "flow_file, inputs, prompt_tpl_file", + [ + ("llm_tool", {"topic": "Hello", "stream": False}, "joke.jinja2"), + # Add back this test case after changing the 
interface of render_template_jinja2 + # ("prompt_tools", {"text": "test"}, "summarize_text_content_prompt.jinja2"), + ] + ) + def test_otel_trace_with_prompt( + self, + dev_connections, + flow_file, + inputs, + prompt_tpl_file, + ): + execute_function_in_subprocess( + self.assert_otel_traces_with_prompt, dev_connections, flow_file, inputs, prompt_tpl_file + ) + + def assert_otel_traces_with_prompt(self, dev_connections, flow_file, inputs, prompt_tpl_file): + memory_exporter = prepare_memory_exporter() + + executor = FlowExecutor.create(get_yaml_file(flow_file), dev_connections) + line_run_id = str(uuid.uuid4()) + resp = executor.exec_line(inputs, run_id=line_run_id) + assert isinstance(resp, LineResult) + assert isinstance(resp.output, dict) + + prompt_tpl = load_content(get_flow_folder(flow_file) / prompt_tpl_file) + prompt_vars = list(get_inputs_for_prompt_template(prompt_tpl).keys()) + span_list = memory_exporter.get_finished_spans() + for span in span_list: + assert span.status.status_code == StatusCode.OK + assert isinstance(span.name, str) + if span.attributes.get("function", "") in SHOULD_INCLUDE_PROMPT_FUNCTION_NAMES: + assert "prompt.template" in span.attributes + assert span.attributes["prompt.template"] == prompt_tpl + assert "prompt.variables" in span.attributes + for var in prompt_vars: + if var in inputs: + assert var in span.attributes["prompt.variables"] + def test_flow_with_traced_function(self): execute_function_in_subprocess(self.assert_otel_traces_run_flow_then_traced_function)