From cea9e167966cbeda11edd446690521ce40f762c7 Mon Sep 17 00:00:00 2001 From: Oindrilla Chatterjee Date: Wed, 26 Jun 2024 15:25:58 -0400 Subject: [PATCH] lint fixes Signed-off-by: Oindrilla Chatterjee --- src/instructlab/sdg/llmblock.py | 70 ++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 15 deletions(-) diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/llmblock.py index d7832ad2..07753505 100644 --- a/src/instructlab/sdg/llmblock.py +++ b/src/instructlab/sdg/llmblock.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from typing import Any, Dict, Union import re # Third Party @@ -8,7 +9,6 @@ # Local from .block import Block from .logger_config import setup_logger -from typing import Any, Dict, Union logger = setup_logger(__name__) @@ -110,22 +110,47 @@ def generate(self, samples, **gen_kwargs) -> Dataset: class ConditionalLLMBlock(LLMBlock): - def __init__(self, block_name, config_paths, client, model_id, output_cols, selector_column_name, parser_name, model_prompt="{prompt}", **batch_kwargs) -> None: - super().__init__(block_name, config_paths[0][0], client, model_id, output_cols, model_prompt=model_prompt, **batch_kwargs) + def __init__( + self, + block_name, + config_paths, + client, + model_id, + output_cols, + selector_column_name, + parser_name, + model_prompt="{prompt}", + **batch_kwargs, + ) -> None: + super().__init__( + block_name, + config_paths[0][0], + client, + model_id, + output_cols, + model_prompt=model_prompt, + **batch_kwargs, + ) self.selector_column_name = selector_column_name self.prompt_template = {} self.parser_name = parser_name - if len(config_paths) == 1 and config_paths[0][1] == 'All': + if len(config_paths) == 1 and config_paths[0][1] == "All": self.prompt_template = self.prompt_struct.format(**self.block_config) else: - for (config, config_key) in config_paths: - self.prompt_template[config_key] = self.prompt_struct.format(**self._load_config(config)) + for config, config_key in config_paths: + self.prompt_template[config_key] = self.prompt_struct.format( + **self._load_config(config) + ) def _parse(self, generated_string): - if self.parser_name == 'default': + if self.parser_name == "default": return super()._parse(generated_string) - elif self.parser_name == 'multi-line-logical-section': - return {self.output_cols[0]: self.extract_multiline_logical_section(generated_string)} + elif self.parser_name == "multi-line-logical-section": + return { + self.output_cols[0]: self.extract_multiline_logical_section( + generated_string + ) + } def extract_multiline_logical_section(self, text): """ @@ -137,21 +162,36 @@ def extract_multiline_logical_section(self, text): Returns: list: A list of multi-line points without the point numbers. """ - pattern = re.compile(r'## Logical Section \d+: (.*?)(?=## Logical Section \d+:|$)', re.DOTALL) + pattern = re.compile( + r"## Logical Section \d+: (.*?)(?=## Logical Section \d+:|$)", re.DOTALL + ) sections = pattern.findall(text) return sections def _generate(self, samples, **gen_kwargs) -> str: if isinstance(self.prompt_template, dict): - prompts = [self.model_prompt.format(prompt=self.prompt_template[sample[self.selector_column_name]].format(**sample).strip()) for sample in samples] + prompts = [ + self.model_prompt.format( + prompt=self.prompt_template[sample[self.selector_column_name]] + .format(**sample) + .strip() + ) + for sample in samples + ] else: - prompts = [self.model_prompt.format(prompt=self.prompt_template.format(**sample).strip()) for sample in samples] - response = self.client.completions.create(prompt=prompts, **{**self.defaults, **gen_kwargs}) + prompts = [ + self.model_prompt.format( + prompt=self.prompt_template.format(**sample).strip() + ) + for sample in samples + ] + response = self.client.completions.create( + prompt=prompts, **{**self.defaults, **gen_kwargs} + ) return [choice.text.strip() for choice in response.choices] - def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool: if isinstance(prompt_template, dict): prompt_template = prompt_template[input_dict[self.selector_column_name]] - return super()._validate(prompt_template, input_dict) \ No newline at end of file + return super()._validate(prompt_template, input_dict)