Skip to content

Commit

Permalink
lint fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Oindrilla Chatterjee <[email protected]>
  • Loading branch information
oindrillac committed Jun 26, 2024
1 parent 18642d6 commit cea9e16
Showing 1 changed file with 55 additions and 15 deletions.
70 changes: 55 additions & 15 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from typing import Any, Dict, Union

Check warning on line 3 in src/instructlab/sdg/llmblock.py

View workflow job for this annotation

GitHub Actions / lint

W0611: Unused Union imported from typing (unused-import)
import re

# Third Party
Expand All @@ -8,7 +9,6 @@
# Local
from .block import Block
from .logger_config import setup_logger
from typing import Any, Dict, Union

logger = setup_logger(__name__)

Expand Down Expand Up @@ -110,22 +110,47 @@ def generate(self, samples, **gen_kwargs) -> Dataset:


class ConditionalLLMBlock(LLMBlock):
    """LLMBlock variant that picks a prompt template per sample.

    The template applied to each sample is selected by the value of that
    sample's ``selector_column_name`` column; parsing of the model output
    is controlled by ``parser_name``.
    """

    def __init__(
        self,
        block_name,
        config_paths,
        client,
        model_id,
        output_cols,
        selector_column_name,
        parser_name,
        model_prompt="{prompt}",
        **batch_kwargs,
    ) -> None:
        """Initialize the block.

        Args:
            block_name: Name of this block.
            config_paths: Sequence of ``(config_path, config_key)`` pairs.
                The first pair's path supplies the base config passed to
                the parent ``LLMBlock``.
            client: Completions client used by ``_generate``.
            model_id: Model identifier passed through to the parent.
            output_cols: Output column names; ``output_cols[0]`` receives
                the parsed sections for the multi-line parser.
            selector_column_name: Sample column whose value selects which
                prompt template to use.
            parser_name: Either ``"default"`` or
                ``"multi-line-logical-section"``.
            model_prompt: Outer prompt wrapper with a ``{prompt}`` slot.
            **batch_kwargs: Forwarded to the parent ``LLMBlock``.
        """
        super().__init__(
            block_name,
            config_paths[0][0],
            client,
            model_id,
            output_cols,
            model_prompt=model_prompt,
            **batch_kwargs,
        )
        self.selector_column_name = selector_column_name
        self.prompt_template = {}
        self.parser_name = parser_name
        if len(config_paths) == 1 and config_paths[0][1] == "All":
            # A single ("path", "All") entry means one shared template:
            # prompt_template stays a plain string instead of a dict.
            self.prompt_template = self.prompt_struct.format(**self.block_config)
        else:
            for config, config_key in config_paths:
                self.prompt_template[config_key] = self.prompt_struct.format(
                    **self._load_config(config)
                )

    def _parse(self, generated_string):
        """Parse model output according to the configured parser.

        Returns the parent's parse result for ``"default"``, a one-entry
        dict of extracted sections for ``"multi-line-logical-section"``,
        and None for any unrecognized parser name.
        """
        # Guard-clause style: no "elif" after a return (pylint R1705,
        # flagged by CI on the previous version of this method).
        if self.parser_name == "default":
            return super()._parse(generated_string)
        if self.parser_name == "multi-line-logical-section":
            return {
                self.output_cols[0]: self.extract_multiline_logical_section(
                    generated_string
                )
            }
        # Unknown parser names deliberately yield None (unchanged behavior;
        # previously this was the implicit fall-through).
        return None

    def extract_multiline_logical_section(self, text):
        """
        Extracts multi-line points from the text without the point numbers.

        Args:
            text (str): The input text containing
                ``## Logical Section N:`` markers.

        Returns:
            list: A list of multi-line points without the point numbers.
        """
        # DOTALL lets each section capture span newlines; the lookahead
        # stops at the next section header or end of text.
        pattern = re.compile(
            r"## Logical Section \d+: (.*?)(?=## Logical Section \d+:|$)", re.DOTALL
        )
        return pattern.findall(text)

    def _generate(self, samples, **gen_kwargs) -> str:
        """Build one prompt per sample and request completions.

        NOTE(review): the ``-> str`` annotation looks inaccurate (a list of
        strings is returned) but is kept to match the existing interface —
        confirm against the parent class before changing.
        """
        if isinstance(self.prompt_template, dict):
            # Per-key templates: select by the sample's selector column.
            prompts = [
                self.model_prompt.format(
                    prompt=self.prompt_template[sample[self.selector_column_name]]
                    .format(**sample)
                    .strip()
                )
                for sample in samples
            ]
        else:
            # Single shared template for every sample.
            prompts = [
                self.model_prompt.format(
                    prompt=self.prompt_template.format(**sample).strip()
                )
                for sample in samples
            ]
        response = self.client.completions.create(
            prompt=prompts, **{**self.defaults, **gen_kwargs}
        )
        return [choice.text.strip() for choice in response.choices]

    def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
        """Validate a sample against the (possibly per-key) prompt template.

        NOTE(review): CI flags W0221 (parameter count differs from
        ``Block._validate``); the signature is kept as-is for backward
        compatibility with existing callers.
        """
        if isinstance(prompt_template, dict):
            prompt_template = prompt_template[input_dict[self.selector_column_name]]
        return super()._validate(prompt_template, input_dict)

0 comments on commit cea9e16

Please sign in to comment.