From 85b0256e73cb4cc5b5936c49c86936da5023c7e6 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Tue, 24 Sep 2024 23:27:50 +0530 Subject: [PATCH] rework description --- .../experimental/testset/transforms/engine.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/ragas/experimental/testset/transforms/engine.py b/src/ragas/experimental/testset/transforms/engine.py index 37a941a06..b0d6e4a64 100644 --- a/src/ragas/experimental/testset/transforms/engine.py +++ b/src/ragas/experimental/testset/transforms/engine.py @@ -43,6 +43,14 @@ async def run_coroutines(coroutines: t.List[t.Coroutine], desc: str, max_workers logger.error(f"unable to apply transformation: {e}") +def get_desc(transform: BaseGraphTransformations | Parallel): + if isinstance(transform, Parallel): + transform_names = [t.__class__.__name__ for t in transform.transformations] + return f"Applying [{', '.join(transform_names)}] transformations in parallel" + else: + return f"Applying {transform.__class__.__name__}" + + @dataclass class TransformerEngine: _nest_asyncio_applied: bool = False @@ -74,25 +82,19 @@ def apply( # if Sequences, apply each transformation sequentially if isinstance(transforms, t.List): for transform in transforms: - desc = f"Applying {transform.__class__.__name__}" asyncio.run( run_coroutines( transform.generate_execution_plan(kg), - desc, + get_desc(transform), run_config.max_workers, ) ) # if Parallel, collect inside it and run it all elif isinstance(transforms, Parallel): - transform_names = [t.__class__.__name__ for t in transforms.transformations] - print("names:", transform_names) - desc = ( - f"Applying [{', '.join(transform_names)}] transformations in parallel" - ) asyncio.run( run_coroutines( transforms.generate_execution_plan(kg), - desc, + get_desc(transforms), run_config.max_workers, ) )