Skip to content

Commit

Permalink
llama-cpp multi server support
Browse files Browse the repository at this point in the history
llama-cpp does not support batching, concurrent completions requests, or really anything to speed our processes up.

The only clear solution here is to create our own form of paralellism by supporting running multiple servers at once.

via a `--num-servers` flag from the cli, a user can spin up 2,3, or even 4 of the `mistral 7b instruct` models since they only take about 5GB of RAM.

This allows us to split our dataset into batches like we do with vllm and execute threads running each batch in parallel. Each server handles its own batch

Signed-off-by: Charlie Doern <[email protected]>
  • Loading branch information
cdoern committed Oct 21, 2024
1 parent 067e4c1 commit 7e0fa83
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 68 deletions.
272 changes: 218 additions & 54 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@
import os
import time

from concurrent.futures import ThreadPoolExecutor
from instructlab.sdg.checkpointing import Checkpointer


from typing import List
from tqdm import tqdm
from threading import Lock
import math

# Third Party
# instructlab - All of these need to go away (other than sdg) - issue #6
from datasets import Dataset
from datasets import Dataset, concatenate_datasets
from xdg_base_dirs import xdg_data_dirs, xdg_data_home
import openai

Expand Down Expand Up @@ -262,7 +271,7 @@ def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst):
# TODO - parameter removal needs to be done in sync with a CLI change.
# to be removed: logger, prompt_file_path, rouge_threshold, tls_*
def generate_data(
client: openai.OpenAI,
clients: list[openai.OpenAI],
logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
model_family: Optional[str] = None,
model_name: Optional[str] = None,
Expand Down Expand Up @@ -332,7 +341,7 @@ def generate_data(
model_family = MODEL_FAMILY_MERLINITE

ctx = _context_init(
client,
clients[0],
model_family,
model_name,
num_instructions_to_generate,
Expand All @@ -359,60 +368,79 @@ def generate_data(

generated_data = None
empty_sdg_leaf_nodes = []
for leaf_node in leaf_nodes.values():
is_knowledge = False
leaf_node_path = leaf_node[0]["taxonomy_path"].replace("->", "_")
samples = leaf_node_to_samples(leaf_node, server_ctx_size, chunk_word_count)

if not samples:
raise GenerateException("Error: No samples found in leaf node.")

if samples[0].get("document"):
pipe = knowledge_pipe
is_knowledge = True

elif samples[0].get("seed_context"):
pipe = grounded_skills_pipe

else:
pipe = freeform_skills_pipe

logger.debug("Samples: %s", samples)
ds = Dataset.from_list(samples)
logger.debug("Dataset: %s", ds)
new_generated_data = pipe.generate(ds, leaf_node_path)
if len(new_generated_data) == 0:
empty_sdg_leaf_nodes.append(leaf_node_path)
logger.warning("Empty dataset for qna node: %s", leaf_node_path)
continue
generated_data = (
[new_generated_data]
if generated_data is None
else generated_data + [new_generated_data]
)
logger.info("Generated %d samples", len(generated_data))
logger.debug("Generated data: %s", generated_data)

if is_knowledge:
# generate mmlubench data for the current leaf node
generate_eval_task_data(
mmlu_bench_pipe,
leaf_node_path,
ds,
output_dir,
date_suffix,
)

mixer.collect(leaf_node_path, new_generated_data, is_knowledge)
for leaf_node in leaf_nodes.values():
leaf_node_path = leaf_node[0]["taxonomy_path"].replace("->", "_")

Check warning on line 373 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / pylint

W0311: Bad indentation. Found 12 spaces, expected 8 (bad-indentation)
samples = leaf_node_to_samples(leaf_node, server_ctx_size, chunk_word_count)

Check warning on line 374 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / pylint

W0311: Bad indentation. Found 12 spaces, expected 8 (bad-indentation)

Check warning on line 375 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / pylint

C0303: Trailing whitespace (trailing-whitespace)
if not samples:

Check warning on line 376 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / pylint

W0311: Bad indentation. Found 12 spaces, expected 8 (bad-indentation)
raise GenerateException("Error: No samples found in leaf node.")

Check warning on line 377 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / pylint

W0311: Bad indentation. Found 16 spaces, expected 12 (bad-indentation)

is_knowledge = False

Check warning on line 379 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / pylint

W0311: Bad indentation. Found 12 spaces, expected 8 (bad-indentation)
is_g_skill = False

Check warning on line 380 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / pylint

W0311: Bad indentation. Found 12 spaces, expected 8 (bad-indentation)
is_f_skill = False

Check warning on line 381 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / pylint

W0311: Bad indentation. Found 12 spaces, expected 8 (bad-indentation)
if samples[0].get("document"):

Check warning on line 382 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / pylint

W0311: Bad indentation. Found 12 spaces, expected 8 (bad-indentation)
is_knowledge = True

Check warning on line 383 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / pylint

W0311: Bad indentation. Found 16 spaces, expected 12 (bad-indentation)

elif samples[0].get("seed_context"):
is_g_skill = True
pipe = grounded_skills_pipe

else:
is_f_skill = True
pipe = freeform_skills_pipe

logger.debug("Samples: %s", samples)

ds = Dataset.from_list(samples)

# if we have multiple servers using llama we need to
# 1. split ds into batches per server
# 2. execute a thread running each batch in its own pipeline
# 3. add the data back together
# 4. return that data.
if len(clients) > 1:
new_generated_data = generate_on_multiple_servers(
ds,
clients,
checkpoint_dir,
model_family,
model_name,
num_instructions_to_generate,
output_dir,
date_suffix,
is_knowledge,
is_g_skill,
is_f_skill,
pipeline,
leaf_node_path,
num_cpus,
)
else:
new_generated_data = pipe.generate(ds, leaf_node_path)
if len(new_generated_data) == 0:
empty_sdg_leaf_nodes.append(leaf_node_path)
logger.warning("Empty dataset for qna node: %s", leaf_node_path)
continue
generated_data = (
[new_generated_data]
if generated_data is None
else generated_data + [new_generated_data]
)
logger.info("Generated %d samples", len(generated_data))
logger.debug("Generated data: %s", generated_data)

if generated_data is None:
generated_data = []
if is_knowledge:
# generate mmlubench data for the current leaf node
generate_eval_task_data(
mmlu_bench_pipe,
leaf_node_path,
ds,
output_dir,
date_suffix,
)

_gen_train_data(
generated_data,
os.path.join(output_dir, output_file_train),
os.path.join(output_dir, output_file_messages),
)
mixer.collect(leaf_node_path, new_generated_data, is_knowledge)

mixer.generate()

Expand All @@ -424,3 +452,139 @@ def generate_data(
" ".join(empty_sdg_leaf_nodes)
)
)


def process_llama_server_batch(
ds,
client,
model_family,
model_name,
num_instructions_to_generate,
checkpoint_dir,
batch_size,
num_cpus,
output_dir,
date_suffix,
is_knowledge,
is_g_skill,
is_f_skill,
pipeline,
leaf_node_path,
lock,
thread,
):

logger.info(f"Running on client {client.base_url} ")
ctx = _context_init(
client,
model_family,
model_name,
num_instructions_to_generate,
checkpoint_dir,
1, # save_freq
batch_size=batch_size,
batch_num_workers=num_cpus,
)

knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(ctx, pipeline)
mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)
mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)
logger.debug("Dataset: %s", ds)
pipe = None
if is_knowledge:
pipe = knowledge_pipe
elif is_g_skill:
pipe = grounded_skills_pipe
elif is_f_skill:
pipe = freeform_skills_pipe

new_data = pipe.generate(ds, leaf_node_path, lock, thread)

if is_knowledge:
generate_eval_task_data(
mmlu_bench_pipe,
leaf_node_path,
ds,
output_dir,
date_suffix,
)
return new_data



def generate_on_multiple_servers(
ds,
clients,
checkpoint_dir,
model_family,
model_name,
num_instructions_to_generate,
output_dir,
date_suffix,
is_knowledge,
is_g_skill,
is_f_skill,
pipeline,
leaf_node_path,
num_cpus,

):
# num batches == num clients
total_size = len(ds)

batch_size = math.ceil(total_size / len(clients))

# Create a batch for each client using ds.select() and indices
batches = [
ds.select(range(i * batch_size, min((i + 1) * batch_size, total_size)))
for i in range(len(clients))
]
# batches will be the same as the number of clients.

checkpointer = Checkpointer(checkpoint_dir, 1)
output_splits = []
lock = Lock()

#tqdms = []
#for i in range(len(clients)):
#tdqm = tqdm(total=0, "Batch {i} Pipeline Progress", position=i, leave=True, lock_args=(True,))
logger.debug(f" batches {len(batches)}, clients {len(clients)}, {clients}")
# Using ThreadPoolExecutor to process each (ds, client) pair in parallel
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
process_llama_server_batch,
ds,
client,
model_family,
model_name,
num_instructions_to_generate,
checkpoint_dir,
0,
num_cpus,
output_dir,
date_suffix,
is_knowledge,
is_g_skill,
is_f_skill,
pipeline,
leaf_node_path,
lock,
thread,
)
for thread, (ds, client) in enumerate(zip(batches, clients))
]
for i, future in enumerate(futures):
if future.running():
print(f"Thread {i} is running")
elif future.done():
print(f"Thread {i} has completed")
elif future.cancelled():
print(f"Thread {i} was canceled")
for future in futures:
new_data = future.result()
output_splits.append(new_data)
checkpointer.checkpoint(new_data)
concatenate_datasets(output_splits)
logger.debug("Dataset: %s", output_splits)
return output_splits
17 changes: 10 additions & 7 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tqdm import tqdm
import openai

from threading import Lock
# Local
from .block import Block

Expand Down Expand Up @@ -156,7 +157,7 @@ def _gen_kwargs(self, gen_kwargs, **defaults):
gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
return gen_kwargs

def _generate(self, samples) -> list:
def _generate(self, samples, lock, thread) -> list:
prompts = [self._format_prompt(sample) for sample in samples]
logger.debug(f"STARTING GENERATION FOR LLMBlock USING PROMPTS: {prompts}")
if self.server_supports_batched:
Expand All @@ -166,20 +167,22 @@ def _generate(self, samples) -> list:
return [choice.text.strip() for choice in response.choices]

results = []
progress_bar = tqdm(
range(len(prompts)), desc=f"{self.block_name} Prompt Generation"
)
#progress_bar = tqdm(
# range(len(prompts)), desc=f"{self.block_name} Prompt Generation Thread {thread}", position=thread,
#)
for prompt in prompts:
logger.debug(f"CREATING COMPLETION FOR PROMPT: {prompt}")
for _ in range(self.gen_kwargs.get("n", 1)):
response = self.ctx.client.completions.create(
prompt=prompt, **self.gen_kwargs
)
results.append(response.choices[0].text.strip())
progress_bar.update(1)
print(f"COMPLETION FOR THREAD {thread} DONE")
# with lock:
# progress_bar.update(1)
return results

def generate(self, samples: Dataset) -> Dataset:
def generate(self, samples: Dataset, lock: Lock, thread: int) -> Dataset:
"""
Generate the output from the block. This method should first validate the input data,
then generate the output, and finally parse the generated output before returning it.
Expand Down Expand Up @@ -211,7 +214,7 @@ def generate(self, samples: Dataset) -> Dataset:

# generate the output

outputs = self._generate(samples)
outputs = self._generate(samples, lock, thread)
logger.debug("Generated outputs: %s", outputs)

num_parallel_samples = self.gen_kwargs.get("n", 1)
Expand Down
Loading

0 comments on commit 7e0fa83

Please sign in to comment.