
Commit

fix: gpt2 generation, datasets
Unified generation; added dataset and GPT-2 fixes
StochasticRomanAgeev committed Mar 21, 2023
1 parent c109041 commit 9a04aea
Showing 11 changed files with 78 additions and 64 deletions.
3 changes: 1 addition & 2 deletions requirements-dev.txt
@@ -1,3 +1,2 @@
 pre-commit
-torch
-transformers
+pytest
1 change: 1 addition & 0 deletions src/turing/datasets/__init__.py
@@ -0,0 +1 @@
+from .base import InstructionDataset, Text2ImageDataset, TextDataset
12 changes: 9 additions & 3 deletions src/turing/datasets/instruction_dataset.py
@@ -1,16 +1,22 @@
 from pathlib import Path
 from typing import Union
 
+from datasets import Dataset as HFDataset
 from datasets import load_from_disk
 from torch.utils.data import Dataset
 
 
 class InstructionDataset(Dataset):
     config_name: str = "instruction_dataset"
 
-    def __init__(self, path: Union[str, Path]):
-        assert Path(path).exists(), "path does not exist"
-        self.data = load_from_disk(path)
+    def __init__(self, path: Union[str, Path, HFDataset, dict]):
+        if isinstance(path, HFDataset):
+            self.data = path
+        elif isinstance(path, dict):
+            self.data = {"train": HFDataset.from_dict(path)}
+        else:
+            assert Path(path).exists(), "path does not exist"
+            self.data = load_from_disk(path)
         self._validate()
 
     def _validate(self):
12 changes: 9 additions & 3 deletions src/turing/datasets/text_dataset.py
@@ -1,16 +1,22 @@
 from pathlib import Path
 from typing import Union
 
+from datasets import Dataset as HFDataset
 from datasets import load_from_disk
 from torch.utils.data import Dataset
 
 
 class TextDataset(Dataset):
     config_name: str = "text_dataset"
 
-    def __init__(self, path: Union[str, Path]):
-        assert Path(path).exists(), "path does not exist"
-        self.data = load_from_disk(path)
+    def __init__(self, path: Union[str, Path, HFDataset, dict]):
+        if isinstance(path, HFDataset):
+            self.data = path
+        elif isinstance(path, dict):
+            self.data = {"train": HFDataset.from_dict(path)}
+        else:
+            assert Path(path).exists(), "path does not exist"
+            self.data = load_from_disk(path)
         self._validate()
 
     def _validate(self):
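
Both dataset classes now accept an in-memory Hugging Face Dataset or a plain dict in addition to an on-disk path. A minimal usage sketch of the new constructor paths; the column names below are assumptions for illustration, not taken from this diff, and the real schema is whatever _validate() expects:

from datasets import Dataset as HFDataset

from turing.datasets import InstructionDataset, TextDataset

# From a plain dict: wrapped internally as {"train": HFDataset.from_dict(...)}.
# The "text" column name here is an assumption.
text_ds = TextDataset({"text": ["first prompt", "second prompt"]})

# From an already-built HF dataset: stored as-is.
instruction_ds = InstructionDataset(
    HFDataset.from_dict(
        {"instruction": ["Summarize:"], "text": ["some input"], "target": ["some output"]}
    )
)

# From a path created with datasets' save_to_disk, as before.
# text_ds = TextDataset("path/to/saved_dataset")
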
6 changes: 3 additions & 3 deletions src/turing/engines/llama_engine.py
@@ -4,7 +4,7 @@
 import evaluate
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, LlamaForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from turing.config import DEFAULT_DTYPE
 
@@ -14,7 +14,7 @@ class LLamaEngine:
 
     def __init__(self, weights_path: Optional[Union[str, Path]] = None):
         if weights_path is None:
-            self.model = LlamaForCausalLM.from_pretrained(
+            self.model = AutoModelForCausalLM.from_pretrained(
                 "decapoda-research/llama-7b-hf", torch_dtype=DEFAULT_DTYPE
             )
             self.tokenizer = AutoTokenizer.from_pretrained(
@@ -24,7 +24,7 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
             assert Path(
                 weights_path
             ).is_dir(), "The weights path should be a existing directory"
-            self.model = LlamaForCausalLM.from_pretrained(
+            self.model = AutoModelForCausalLM.from_pretrained(
                 weights_path, torch_dtype=DEFAULT_DTYPE
             )
             self.tokenizer = AutoTokenizer.from_pretrained(weights_path)
1 change: 1 addition & 0 deletions src/turing/models/__init__.py
@@ -0,0 +1 @@
+from .base import GPT2, GPTJ, Llama, StableDiffusion
1 change: 1 addition & 0 deletions src/turing/models/base.py
@@ -10,4 +10,5 @@ class BaseModel(BaseParent):
         GPTJ.config_name: GPTJ,
         Llama.config_name: Llama,
         StableDiffusion.config_name: StableDiffusion,
+        GPT2.config_name: GPT2,
     }
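
With this entry in place, GPT-2 is reachable through the same registry lookup the other models use. A hypothetical sketch, assuming BaseModel exposes the same create-by-config_name factory that BaseEngine.create and BaseTrainer.create show elsewhere in this commit:

from turing.models.base import BaseModel

# Assumption: BaseModel.create dispatches on config_name like the other
# BaseParent registries in this codebase; "gpt2" is the key added above.
model = BaseModel.create("gpt2")
print(model.generate(texts="Hello, my name is"))
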
98 changes: 49 additions & 49 deletions src/turing/models/gpt2.py
@@ -1,10 +1,11 @@
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Iterable, List, Optional, Union
 
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from turing.config import DEFAULT_DEVICE
 from turing.datasets.instruction_dataset import InstructionDataset
 from turing.datasets.text_dataset import TextDataset
 from turing.engines.base import BaseEngine
@@ -13,62 +14,67 @@
 
 
 class GPT2:
     config_name = "gpt2"
 
     def __init__(self, weights_path: Optional[str] = None):
         self.engine = BaseEngine.create("gpt2_engine", weights_path)
 
-        self.collate_fn = None
-        self.trainer = None
+    def _make_collate_fn(self, dataset: Union[TextDataset, InstructionDataset]):
+        return BasePreprocessor.create(dataset.config_name, self.engine.tokenizer, 512)
 
     def finetune(self, dataset: Union[TextDataset, InstructionDataset]):
         assert dataset.config_name in [
             "text_dataset",
             "instruction_dataset",
         ], "Please make sure the dataset_type is text_dataset or instruction_dataset"
-        self.collate_fn = BasePreprocessor.create(
-            dataset.config_name, self.engine.tokenizer, 512
-        )
-        self.trainer = BaseTrainer.create(
-            "lightning_trainer", self.engine, dataset, self.collate_fn
+        collate_fn = self._make_collate_fn(dataset)
+        trainer = BaseTrainer.create(
+            "lightning_trainer", self.engine, dataset, collate_fn
         )
-        self.trainer.fit()
+        trainer.fit()
 
     def evaluate(self, dataset: Union[TextDataset, InstructionDataset]):
         pass
 
+    def _generate_from_iterable(self, data_iterator: Iterable, do_tokenization=False):
+        outputs = []
+
+        for i, batch in enumerate(tqdm(data_iterator)):
+            if do_tokenization:
+                inputs = self.engine.tokenizer(batch, return_tensors="pt")
+                input_ids = inputs.input_ids.to(DEFAULT_DEVICE)
+            else:
+                input_ids = batch["input_ids"].to(DEFAULT_DEVICE)
+            with torch.no_grad():
+                with torch.autocast("cuda"):
+                    output = self.engine.model.generate(
+                        input_ids=input_ids, do_sample=False, max_new_tokens=300
+                    )
+
+            output = self.engine.tokenizer.decode(output[0], skip_special_tokens=False)
+            outputs.append(output)
+
+        return outputs
+
     def generate(
         self,
         *,
         texts: Optional[Union[List[str], str]] = None,
         dataset: Optional[Union[TextDataset, InstructionDataset]] = None,
     ):
-        device = (
-            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-        )
         self.engine.model.eval()
 
+        outputs = []
+
         if texts is not None:
-            texts = [texts] if isinstance(texts) == str else texts
-
-            outputs = []
-            for text in tqdm(texts):
-                inputs = self.engine.tokenizer(text, return_tensors="pt")
-                input_ids = inputs.input_ids.to(device)
-                with torch.no_grad():
-                    with torch.autocast("cuda"):
-                        output = self.engine.model.generate(
-                            input_ids=input_ids, do_sample=False, max_new_tokens=300
-                        )
-
-                output = self.engine.tokenizer.decode(
-                    output[0], skip_special_tokens=False
-                )
-                outputs.append(output)
-
-        elif dataset is not None:
-            collate_fn = (
-                BasePreprocessor("text_dataset")(self.engine.tokenizer, 512)
-                if isinstance(dataset) == TextDataset
-                else BasePreprocessor("instruction_dataset")(self.engine.tokenizer, 512)
+            flattened_texts = [texts] if isinstance(texts, str) else texts
+
+            outputs.extend(
+                self._generate_from_iterable(flattened_texts, do_tokenization=True)
+            )
+
+        if dataset is not None:
+            collate_fn = self._make_collate_fn(dataset)
             dataloader = DataLoader(
                 dataset,
                 batch_size=1,
@@ -77,21 +83,15 @@ def generate(
                 collate_fn=collate_fn,
             )
 
-            outputs = []
-            for i, batch in enumerate(tqdm(dataloader)):
-                input_ids = batch["input_ids"].to(self.device)
-                with torch.no_grad():
-                    with torch.autocast("cuda"):
-                        output = self.engine.model.generate(
-                            input_ids=input_ids, do_sample=False, max_new_tokens=300
-                        )
-
-                output = self.engine.tokenizer.decode(
-                    output[0], skip_special_tokens=False
-                )
-                outputs.append(output)
-        else:
-            raise ("Make sure texts or dataset is not None")
+            outputs.extend(
+                self._generate_from_iterable(dataloader, do_tokenization=False)
+            )
+
+        if texts is None and dataset is None:
+            assert False, "Make sure texts or dataset is not None"
+
+        if isinstance(texts, str) and dataset is None:
+            return outputs[0]
 
         return outputs
 
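
The net effect is a single decode path: a bare string, a list of strings, or a dataset all funnel through _generate_from_iterable, and a bare string comes back as one decoded string. A usage sketch under assumed defaults (greedy decoding, max_new_tokens=300, a CUDA device for autocast, default GPT-2 weights); the prompts and column name are illustrative:

from turing.datasets import TextDataset
from turing.models import GPT2

model = GPT2()

# Single string in, single decoded string out (outputs[0] is returned).
single = model.generate(texts="The capital of France is")

# List of strings in, list of decoded strings out.
batch = model.generate(texts=["Once upon a time", "In a distant future"])

# Dataset in: batched through a DataLoader with the collate_fn built by
# _make_collate_fn, then decoded by the same helper.
dataset = TextDataset({"text": ["The quick brown fox"]})  # column name is an assumption
preds = model.generate(dataset=dataset)
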
2 changes: 1 addition & 1 deletion src/turing/models/gptj.py
@@ -48,7 +48,7 @@ def generate(
         self.engine.model.eval()
 
         if texts is not None:
-            texts = [texts] if isinstance(texts) == str else texts
+            texts = [texts] if isinstance(texts, str) else texts
 
             outputs = []
             for text in tqdm(texts):
4 changes: 2 additions & 2 deletions src/turing/models/llama.py
@@ -49,7 +49,7 @@ def generate(
         self.engine.model.eval()
 
         if texts is not None:
-            texts = [texts] if isinstance(texts) == str else texts
+            texts = [texts] if isinstance(texts, str) else texts
 
             outputs = []
             for text in tqdm(texts):
@@ -68,7 +68,7 @@
 
         elif dataset is not None:
             collate_fn = (
-                BasePreprocessor("text_dataset")(self.engine.tokenizer, 512)
+                BasePreprocessor.create(dataset.con, self.engine.tokenizer, 512)
                 if isinstance(dataset) == TextDataset
                 else BasePreprocessor("instruction_dataset")(self.engine.tokenizer, 512)
             )
2 changes: 1 addition & 1 deletion src/turing/trainers/lightning_trainer.py
@@ -55,7 +55,7 @@ def validation_step(self, batch, batch_idx):
 
 
 class LightningTrainer:
-    config_name: TuringLightningModule
+    config_name = "lightning_trainer"
 
     def __init__(
         self,
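
The old line annotated config_name with a type instead of assigning the registry key, so a lookup by name could not resolve this trainer; with the string in place, the call made in GPT2.finetune above resolves as expected. A sketch of the path that exercises it, assuming TextDataset accepts the illustrative column name used here:

from turing.datasets import TextDataset
from turing.models import GPT2

# GPT2.finetune (shown above) builds the trainer via
# BaseTrainer.create("lightning_trainer", ...), which now matches this config_name.
model = GPT2()
model.finetune(TextDataset({"text": ["tiny example"]}))  # column name is an assumption
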

