Merge pull request #254 from stochasticai/dev
release 0.1.8
MarcosRiveraMartinez authored Sep 6, 2023
2 parents 4c71825 + 8021076 commit fbeea1a
Showing 9 changed files with 192 additions and 45 deletions.
14 changes: 7 additions & 7 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "xturing"
version = "0.1.7"
version = "0.1.8"
description = "Fine-tuning, evaluation and data generation for LLMs"

authors = [
@@ -43,12 +43,12 @@ keywords = [
dependencies = [
"torch >= 1.9.0",
"pytorch-lightning",
"transformers==4.28.1",
"datasets",
"evaluate",
"bitsandbytes==0.37.2",
"transformers==4.31.0",
"datasets==2.14.5",
"evaluate==0.4.0",
"bitsandbytes==0.41.1",
"sentencepiece",
"deepspeed",
"deepspeed==0.9.5",
"gradio",
"click",
"wget",
@@ -58,7 +58,7 @@ dependencies = [
"openai >= 0.27.0",
"pydantic >= 1.10.0",
"rouge-score >= 0.1.2",
"accelerate",
"accelerate==0.22.0",
"wandb",
]
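
The release replaces several floating dependency specifiers with exact pins (transformers, datasets, evaluate, bitsandbytes, deepspeed, accelerate). Below is a minimal sketch, not part of the commit, for checking whether a local environment already matches the new pins; the package names and versions are copied from the diff above.

# Sketch: verify installed package versions against the 0.1.8 pins.
from importlib.metadata import PackageNotFoundError, version

EXPECTED_PINS = {
    "transformers": "4.31.0",
    "datasets": "2.14.5",
    "evaluate": "0.4.0",
    "bitsandbytes": "0.41.1",
    "deepspeed": "0.9.5",
    "accelerate": "0.22.0",
}

for package, pinned in EXPECTED_PINS.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        installed = "not installed"
    status = "OK" if installed == pinned else "MISMATCH"
    print(f"{package}: installed={installed}, pinned={pinned} [{status}]")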

2 changes: 1 addition & 1 deletion src/xturing/__about__.py
@@ -1 +1 @@
__version__ = "0.1.7"
__version__ = "0.1.8"
49 changes: 48 additions & 1 deletion src/xturing/config/finetuning_config.yaml
@@ -32,6 +32,13 @@ bloom_lora_int8:
batch_size: 8
max_length: 256

bloom_int8:
learning_rate: 1e-4
weight_decay: 0.01
num_train_epochs: 3
batch_size: 8
max_length: 256

cerebras:
learning_rate: 5e-5
weight_decay: 0.01
@@ -50,6 +57,13 @@ cerebras_lora_int8:
batch_size: 8
max_length: 256

cerebras_int8:
learning_rate: 1e-4
weight_decay: 0.01
num_train_epochs: 3
batch_size: 8
max_length: 256

distilgpt2:
learning_rate: 1e-3
weight_decay: 0.01
@@ -115,6 +129,13 @@ galactica_lora_int8:
batch_size: 8
max_length: 256

galactica_int8:
learning_rate: 1e-4
weight_decay: 0.01
num_train_epochs: 3
batch_size: 8
max_length: 256

generic:
learning_rate: 1e-4
weight_decay: 0.01
@@ -169,6 +190,13 @@ gptj_lora_int8:
batch_size: 8
max_length: 256

gptj_int8:
learning_rate: 1e-4
weight_decay: 0.01
num_train_epochs: 3
batch_size: 8
max_length: 256

gpt2:
learning_rate: 1e-3
weight_decay: 0.01
@@ -187,13 +215,18 @@ gpt2_lora_int8:
num_train_epochs: 3
batch_size: 16

gpt2_int8:
learning_rate: 3e-3
weight_decay: 0.01
num_train_epochs: 3
batch_size: 16

llama:
learning_rate: 5e-5
weight_decay: 0.01
num_train_epochs: 3
optimizer_name: cpu_adam


llama_lora:
learning_rate: 1e-4
weight_decay: 0.01
@@ -207,6 +240,13 @@ llama_lora_int8:
batch_size: 8
max_length: 256

llama_int8:
learning_rate: 1e-4
weight_decay: 0.01
num_train_epochs: 3
batch_size: 8
max_length: 256

llama_lora_kbit:
learning_rate: 3e-4
num_train_epochs: 3
@@ -275,3 +315,10 @@ opt_lora_int8:
num_train_epochs: 3
batch_size: 8
max_length: 256

opt_int8:
learning_rate: 1e-4
weight_decay: 0.01
num_train_epochs: 3
batch_size: 8
max_length: 256
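
The new *_int8 sections give the plain INT8 (non-LoRA) model variants their own finetuning defaults. Below is a minimal sketch of how such a key is typically consumed through the model registry, assuming the matching model is registered under that name; the "opt_int8" key mirrors the new YAML section, and the dataset path is hypothetical.

# Sketch: fine-tune an INT8 variant using the defaults added above.
from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

dataset = InstructionDataset("./alpaca_data")   # hypothetical local dataset path
model = BaseModel.create("opt_int8")            # key matches the new YAML section
print(model.finetuning_config())                # defaults from finetuning_config.yaml
model.finetune(dataset=dataset)
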
37 changes: 37 additions & 0 deletions src/xturing/config/generation_config.yaml
@@ -25,6 +25,11 @@ bloom_lora_int8:
max_new_tokens: 256
do_sample: false

# Greedy search
bloom_int8:
max_new_tokens: 256
do_sample: false

# Contrastive search
cerebras:
penalty_alpha: 0.6
@@ -44,6 +49,11 @@ cerebras_lora_int8:
max_new_tokens: 256
do_sample: false

# Greedy search
cerebras_int8:
max_new_tokens: 256
do_sample: false

# Top-p sampling
distilgpt2:
do_sample: true
@@ -102,6 +112,11 @@ galactica_lora_int8:
max_new_tokens: 256
do_sample: false

# Greedy search
galactica_int8:
max_new_tokens: 256
do_sample: false

# Greedy search
generic:
max_new_tokens: 256
@@ -146,6 +161,11 @@ gptj_lora_int8:
max_new_tokens: 256
do_sample: false

# Greedy search
gptj_int8:
max_new_tokens: 256
do_sample: false

# Top-p sampling
gpt2:
do_sample: true
@@ -167,6 +187,13 @@ gpt2_lora_int8:
top_p: 0.92
max_new_tokens: 256

# Top-p sampling
gpt2_int8:
do_sample: true
top_k: 0
top_p: 0.92
max_new_tokens: 256

# Contrastive search
llama:
penalty_alpha: 0.6
@@ -186,6 +213,11 @@ llama_lora_int8:
max_new_tokens: 256
do_sample: false

# Greedy search
llama_int8:
max_new_tokens: 256
do_sample: false

# Greedy search
llama_lora_kbit:
max_new_tokens: 256
@@ -238,3 +270,8 @@ opt_lora:
opt_lora_int8:
max_new_tokens: 256
do_sample: false

# Greedy search
opt_int8:
max_new_tokens: 256
do_sample: false
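
All of the new *_int8 entries default to greedy decoding (do_sample: false), except gpt2_int8, which keeps the top-p sampling used by the other GPT-2 variants. Below is a minimal sketch of overriding these defaults at generation time, assuming a generation_config() accessor analogous to finetuning_config() in causal.py further down; the model key and attribute access are assumptions, not confirmed by this diff.

# Sketch: switch an INT8 model from greedy decoding to top-p sampling.
from xturing.models import BaseModel

model = BaseModel.create("llama_int8")           # key matches the new YAML section
generation_config = model.generation_config()    # assumed accessor
generation_config.do_sample = True               # override the greedy default
generation_config.top_p = 0.92
generation_config.max_new_tokens = 128

outputs = model.generate(texts=["What does INT8 quantization change at inference time?"])
print(outputs)
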
12 changes: 11 additions & 1 deletion src/xturing/engines/__init__.py
@@ -44,7 +44,13 @@
GPTJLoraEngine,
GPTJLoraInt8Engine,
)
from xturing.engines.llama2_engine import LLama2Engine
from xturing.engines.llama2_engine import (
LLama2Engine,
LLama2Int8Engine,
LLama2LoraEngine,
LLama2LoraInt8Engine,
LLama2LoraKbitEngine,
)
from xturing.engines.llama_engine import (
LLamaEngine,
LLamaInt8Engine,
@@ -97,6 +103,10 @@
BaseEngine.add_to_registry(LlamaLoraInt8Engine.config_name, LlamaLoraInt8Engine)
BaseEngine.add_to_registry(LlamaLoraKbitEngine.config_name, LlamaLoraKbitEngine)
BaseEngine.add_to_registry(LLama2Engine.config_name, LLama2Engine)
BaseEngine.add_to_registry(LLama2Int8Engine.config_name, LLama2Int8Engine)
BaseEngine.add_to_registry(LLama2LoraEngine.config_name, LLama2LoraEngine)
BaseEngine.add_to_registry(LLama2LoraInt8Engine.config_name, LLama2LoraInt8Engine)
BaseEngine.add_to_registry(LLama2LoraKbitEngine.config_name, LLama2LoraKbitEngine)
BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)
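
The four new calls follow the existing registration pattern: each engine class exposes a config_name string, and add_to_registry maps that string to the class so it can be resolved by name later. Below is a minimal, self-contained sketch of that pattern, not xturing's actual implementation; the key string is illustrative.

# Sketch: the name -> class registry pattern used by the add_to_registry calls above.
from typing import Dict, Type


class BaseEngine:
    registry: Dict[str, Type["BaseEngine"]] = {}
    config_name: str = "base_engine"

    @classmethod
    def add_to_registry(cls, name: str, engine_cls: Type["BaseEngine"]) -> None:
        cls.registry[name] = engine_cls

    @classmethod
    def create(cls, name: str, **kwargs) -> "BaseEngine":
        return cls.registry[name](**kwargs)


class LLama2Int8Engine(BaseEngine):
    config_name = "llama2_int8_engine"  # illustrative key, not confirmed by this diff


BaseEngine.add_to_registry(LLama2Int8Engine.config_name, LLama2Int8Engine)
engine = BaseEngine.create("llama2_int8_engine")  # resolves to LLama2Int8Engine
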
3 changes: 1 addition & 2 deletions src/xturing/engines/generic_engine.py
@@ -64,7 +64,7 @@ def __init__(


class GenericLoraKbitEngine(CausalLoraKbitEngine):
config_name: str = "generic+lora_kbit_engine"
config_name: str = "generic_lora_kbit_engine"

def __init__(
self,
@@ -75,7 +75,6 @@ def __init__(
super().__init__(
model_name=model_name,
weights_path=weights_path,
load_4bit=True,
target_modules=target_modules,
)

12 changes: 11 additions & 1 deletion src/xturing/models/__init__.py
@@ -36,7 +36,13 @@
LlamaLoraInt8,
LlamaLoraKbit,
)
from xturing.models.llama2 import Llama2
from xturing.models.llama2 import (
Llama2,
Llama2Int8,
Llama2Lora,
Llama2LoraInt8,
Llama2LoraKbit,
)
from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
from xturing.models.stable_diffusion import StableDiffusion

@@ -78,6 +84,10 @@
BaseModel.add_to_registry(LlamaLoraInt8.config_name, LlamaLoraInt8)
BaseModel.add_to_registry(LlamaLoraKbit.config_name, LlamaLoraKbit)
BaseModel.add_to_registry(Llama2.config_name, Llama2)
BaseModel.add_to_registry(Llama2Int8.config_name, Llama2Int8)
BaseModel.add_to_registry(Llama2Lora.config_name, Llama2Lora)
BaseModel.add_to_registry(Llama2LoraInt8.config_name, Llama2LoraInt8)
BaseModel.add_to_registry(Llama2LoraKbit.config_name, Llama2LoraKbit)
BaseModel.add_to_registry(OPT.config_name, OPT)
BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
BaseModel.add_to_registry(OPTLora.config_name, OPTLora)
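
With these registrations, the INT8, LoRA, LoRA-INT8, and LoRA-kbit variants of LLaMA 2 can be created through the model registry alongside the plain Llama2 model. Below is a minimal sketch, assuming the Llama2 config_name strings follow the same naming scheme as the existing LLaMA keys (e.g. "llama2_int8"); the exact strings are not shown in this diff.

# Sketch: create, use, and save one of the newly registered LLaMA 2 variants.
from xturing.models import BaseModel

model = BaseModel.create("llama2_int8")  # assumed registry key, by analogy with "llama_int8"
model.generate(texts=["Summarize the benefits of LoRA fine-tuning."])
model.save("./llama2_int8_weights")      # hypothetical output directory
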
53 changes: 21 additions & 32 deletions src/xturing/models/causal.py
@@ -1,10 +1,8 @@
import json
from pathlib import Path

from typing import Iterable, List, Optional, Tuple, Type, Union
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from pytorch_lightning.loggers import Logger
from torch.utils.data import DataLoader
from tqdm import tqdm
@@ -21,15 +19,7 @@
from xturing.trainers.base import BaseTrainer
from xturing.trainers.lightning_trainer import LightningTrainer
from xturing.utils.logging import configure_logger
from xturing.utils.metrics import get_accuracy
from xturing.utils.prompt import (
OpenAIChatMessage,
OpenAICreateChatPrompt,
OpenAICreatePrompt,
Prompt,
chat_prompt_to_text,
is_chat_prompt,
)
from xturing.utils.prompt import OpenAICreateChatPrompt, OpenAICreatePrompt, Prompt
from xturing.utils.utils import _filter_args, _index_samples

TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
@@ -44,6 +34,7 @@ def __init__(
weights_path: Optional[str] = None,
model_name: Optional[str] = None,
target_modules: Optional[List[str]] = None,
transfer_to_device: Optional[bool] = True,
**kwargs,
):
arguments = dict(
@@ -82,6 +73,8 @@ def __init__(
logger.debug(f"Finetuning parameters: {self.finetuning_args}")
logger.debug(f"Generation parameters: {self.generation_args}")

self.transfer_to_device = transfer_to_device

def finetuning_config(self):
return self.finetuning_args

@@ -163,7 +156,9 @@ def generate(
batch_size: Optional[int] = 1,
):
self.engine.model.eval()
self.engine.model = self.engine.model.to(DEFAULT_DEVICE)

if self.transfer_to_device:
self.engine.model = self.engine.model.to(DEFAULT_DEVICE)

outputs = []

@@ -239,18 +234,9 @@ def _model_call(
def completion_query(
self, prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt]
):
# actual_prompt = chat_prompt_to_text(prompt)
actual_prompt = prompt
logger.info(prompt)
text_out = self.generate(texts=[actual_prompt])

# parse results
# result = {
# "text": text_out,
# "tokens": None,
# "logprobs": None,
# }

return text_out, actual_prompt

def check_sampled_text(
@@ -314,8 +300,6 @@ def evaluate(
dataset: Union[TextDataset, InstructionDataset],
batch_size: Optional[int] = 1,
):
# outputs = self.eval_all_samples(dataset)
# return get_accuracy(outputs)
collate_fn = self._make_collate_fn(dataset)
dataloader = DataLoader(
dataset,
@@ -338,7 +322,11 @@ def __init__(
):
assert_not_cpu_int8()
super().__init__(
engine, weights_path=weights_path, model_name=model_name, **kwargs
engine,
weights_path=weights_path,
model_name=model_name,
transfer_to_device=False,
**kwargs,
)


@@ -400,18 +388,19 @@ def __init__(

class CausalLoraKbitModel(CausalLoraModel):
def __init__(
self,
engine: str,
weights_path: Optional[str] = None,
model_name: Optional[str] = None,
target_modules: Optional[List[str]] = None,
**kwargs,
):
self,
engine: str,
weights_path: Optional[str] = None,
model_name: Optional[str] = None,
target_modules: Optional[List[str]] = None,
**kwargs,
):
assert_not_cpu_int8()
super().__init__(
engine,
weights_path=weights_path,
model_name=model_name,
target_modules=target_modules,
transfer_to_device=False,
**kwargs,
)
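
The new transfer_to_device flag lets quantized variants opt out of the explicit .to(DEFAULT_DEVICE) move in generate(): the INT8 and LoRA-kbit constructors pass transfer_to_device=False, presumably because their bitsandbytes-quantized weights are already mapped to a device when the engine loads them. Below is a minimal standalone sketch of the guard, with class and attribute names simplified; it is not the full implementation.

# Sketch: skip the device transfer for models whose weights are already placed.
import torch

DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CausalModelSketch:
    def __init__(self, model: torch.nn.Module, transfer_to_device: bool = True):
        self.model = model
        # INT8 / kbit subclasses pass transfer_to_device=False.
        self.transfer_to_device = transfer_to_device

    def generate(self, texts):
        self.model.eval()
        if self.transfer_to_device:
            # Only full-precision models are moved here; quantized weights stay put.
            self.model = self.model.to(DEFAULT_DEVICE)
        # ... tokenize `texts`, run the model, and decode outputs here ...
        return []
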
55 changes: 55 additions & 0 deletions tests/xturing/models/test_generic_models.py
@@ -0,0 +1,55 @@
import tempfile
from pathlib import Path

from xturing.models import (
GenericInt8Model,
GenericLoraInt8Model,
GenericLoraKbitModel,
GenericLoraModel,
GenericModel,
)


def test_generic_model():
saving_path = Path(tempfile.gettempdir()) / "test_xturing_generic"
model = GenericModel("distilgpt2")
model.save(str(saving_path))

model2 = GenericModel(str(saving_path))
model2.generate(texts=["Why are the LLM so important?"])


def test_generic_model_int8():
saving_path = Path(tempfile.gettempdir()) / "test_xturing_generic_int8"
model = GenericInt8Model("distilgpt2")
model.save(str(saving_path))

model2 = GenericInt8Model(str(saving_path))
model2.generate(texts=["Why are the LLM so important?"])


def test_generic_model_lora():
saving_path = Path(tempfile.gettempdir()) / "test_xturing_generic_lora"
model = GenericLoraModel("distilgpt2")
model.save(str(saving_path))

model2 = GenericLoraModel(str(saving_path))
model2.generate(texts=["Why are the LLM so important?"])


def test_generic_model_int8_lora():
saving_path = Path(tempfile.gettempdir()) / "test_xturing_lora_int8"
model = GenericLoraInt8Model("distilgpt2")
model.save(str(saving_path))

model2 = GenericLoraInt8Model(str(saving_path))
model2.generate(texts=["Why are the LLM so important?"])


def test_generic_model_lora_kbit():
saving_path = Path(tempfile.gettempdir()) / "test_xturing_lora_kbit"
model = GenericLoraKbitModel("distilgpt2")
model.save(str(saving_path))

model2 = GenericLoraKbitModel(str(saving_path))
model2.generate(texts=["Why are the LLM so important?"])
