feat: Add generic Wandb integration
cnbeining committed May 2, 2023
1 parent 142954b commit 5bf9d33
Showing 6 changed files with 50 additions and 26 deletions.
38 changes: 21 additions & 17 deletions examples/int4_finetuning/LLaMA_lora_int4.ipynb
@@ -57,6 +57,10 @@
"source": [
"from xturing.datasets.instruction_dataset import InstructionDataset\n",
"from xturing.models import BaseModel\n",
"from pytorch_lightning.loggers import WandbLogger\n",
"\n",
"# Initializes WandB integration \n",
"wandb_logger = WandbLogger()\n",
"\n",
"instruction_dataset = InstructionDataset(\"../llama/alpaca_data\")\n",
"# Initializes the model\n",
@@ -65,46 +69,46 @@
},
{
"cell_type": "markdown",
"source": [
"## 3. Start the finetuning"
],
"metadata": {
"collapsed": false
}
},
"source": [
"## 3. Start the finetuning"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Finetuned the model\n",
"model.finetune(dataset=instruction_dataset)"
],
"metadata": {
"collapsed": false
}
"model.finetune(dataset=instruction_dataset, logger=wandb_logger)"
]
},
{
"cell_type": "markdown",
"source": [
"## 4. Generate an output text with the fine-tuned model"
],
"metadata": {
"collapsed": false
}
},
"source": [
"## 4. Generate an output text with the fine-tuned model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Once the model has been finetuned, you can start doing inferences\n",
"output = model.generate(texts=[\"Why LLM models are becoming so important?\"])\n",
"print(\"Generated output by the model: {}\".format(output))"
],
"metadata": {
"collapsed": false
}
]
}
],
"metadata": {
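Taken together, the notebook change amounts to three new lines: import the logger, instantiate it, and pass it to finetune. A minimal sketch of the updated flow, assuming the dataset path from the notebook and the llama_lora_int4 model name used later in this commit:

```python
# Minimal sketch of the notebook flow after this commit.
# Dataset path and model name are illustrative.
from pytorch_lightning.loggers import WandbLogger

from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

wandb_logger = WandbLogger()  # initializes the W&B integration

instruction_dataset = InstructionDataset("../llama/alpaca_data")
model = BaseModel.create("llama_lora_int4")

# The logger is forwarded through finetune() down to the Lightning Trainer.
model.finetune(dataset=instruction_dataset, logger=wandb_logger)
```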
1 change: 1 addition & 0 deletions pyproject.toml
@@ -59,6 +59,7 @@ dependencies = [
"pydantic >= 1.10.0",
"rouge-score >= 0.1.2",
"accelerate",
"wandb",
]

[project.scripts]
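Since wandb is now an unconditional dependency, users who attach a WandbLogger will need Weights & Biases credentials available at runtime. A hypothetical one-time setup, not part of this commit:

```python
# Hypothetical setup step; not part of this commit.
import wandb

wandb.login()  # or export the WANDB_API_KEY environment variable beforehand
```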
16 changes: 11 additions & 5 deletions src/xturing/models/causal.py
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Iterable, List, Optional, Union
from typing import Iterable, List, Optional, Union, Type

import torch
from torch.utils.data import DataLoader
@@ -18,6 +18,7 @@
from xturing.trainers.base import BaseTrainer
from xturing.trainers.lightning_trainer import LightningTrainer
from xturing.utils.logging import configure_logger
from pytorch_lightning.loggers import Logger

logger = configure_logger(__name__)

@@ -63,7 +64,8 @@ def _make_collate_fn(self, dataset: Union[TextDataset, InstructionDataset]):
dataset.meta,
)

def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset],
logger: Union[Logger, Iterable[Logger], bool] = True):
return BaseTrainer.create(
LightningTrainer.config_name,
self.engine,
@@ -73,14 +75,16 @@ def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
int(self.finetuning_args.batch_size),
float(self.finetuning_args.learning_rate),
self.finetuning_args.optimizer_name,
logger=logger,
)

def finetune(self, dataset: Union[TextDataset, InstructionDataset]):
def finetune(self, dataset: Union[TextDataset, InstructionDataset],
logger: Union[Logger, Iterable[Logger], bool] = True):
assert dataset.config_name in [
"text_dataset",
"instruction_dataset",
], "Please make sure the dataset_type is text_dataset or instruction_dataset"
trainer = self._make_trainer(dataset)
trainer = self._make_trainer(dataset, logger)
trainer.fit()

def evaluate(self, dataset: Union[TextDataset, InstructionDataset]):
@@ -188,7 +192,8 @@ class CausalLoraModel(CausalModel):
def __init__(self, engine: str, weights_path: Optional[str] = None):
super().__init__(engine, weights_path)

def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset],
logger: Union[Logger, Iterable[Logger], bool] = True):
return BaseTrainer.create(
LightningTrainer.config_name,
self.engine,
@@ -200,6 +205,7 @@ def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
self.finetuning_args.optimizer_name,
True,
True,
logger=logger,
)


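The new keyword is typed Union[Logger, Iterable[Logger], bool] and defaults to True, so existing callers keep the previous behaviour while new callers can pass a specific logger or disable logging. A hedged usage sketch (the model name and dataset path are placeholders, not taken from this commit):

```python
# Illustrative only: model name and dataset path are placeholders.
from pytorch_lightning.loggers import WandbLogger

from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

dataset = InstructionDataset("./alpaca_data")
model = BaseModel.create("llama_lora")

# Log metrics to Weights & Biases ...
model.finetune(dataset=dataset, logger=WandbLogger(project="xturing"))

# ... or keep Lightning's default logger (same as before this commit) ...
model.finetune(dataset=dataset)

# ... or disable experiment logging entirely.
model.finetune(dataset=dataset, logger=False)
```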
7 changes: 5 additions & 2 deletions src/xturing/models/llama.py
@@ -1,4 +1,5 @@
from typing import List, Optional, Union
from typing import Iterable, List, Optional, Union
from pytorch_lightning.loggers import Logger

from xturing.engines.llama_engine import (
LLamaEngine,
@@ -50,7 +51,8 @@ def __init__(self, weights_path: Optional[str] = None):
class LlamaLoraInt4(CausalLoraInt8Model):
config_name: str = "llama_lora_int4"

def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset],
logger: Union[Logger, Iterable[Logger], bool] = True):
return BaseTrainer.create(
LightningTrainer.config_name,
self.engine,
@@ -63,6 +65,7 @@ def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
True,
True,
lora_type=32,
logger=logger,
)

def __init__(self, weights_path: Optional[str] = None):
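Because the parameter also accepts an iterable of loggers, several backends can be attached at once. A sketch for the int4 LoRA model, with the project name and paths as placeholders:

```python
# Sketch: attach more than one logger at once (project name and paths are placeholders).
from pytorch_lightning.loggers import CSVLogger, WandbLogger

from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

dataset = InstructionDataset("../llama/alpaca_data")
model = BaseModel.create("llama_lora_int4")

model.finetune(
    dataset=dataset,
    logger=[WandbLogger(project="xturing-int4"), CSVLogger(save_dir="logs/")],
)
```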
7 changes: 6 additions & 1 deletion src/xturing/models/stable_diffusion.py
@@ -10,7 +10,12 @@ class StableDiffusion:
def __init__(self, weights_path: str):
pass

def finetune(self, dataset: Text2ImageDataset):
def finetune(self, dataset: Text2ImageDataset, logger=True):
"""Finetune Stable Diffusion model on a given dataset.
Args:
dataset (Text2ImageDataset): Dataset to finetune on.
logger (bool, optional): To be setup with a Pytorch Lightning logger when implemented."""
pass

def generate(
7 changes: 6 additions & 1 deletion src/xturing/trainers/lightning_trainer.py
@@ -3,13 +3,14 @@
import tempfile
import uuid
from pathlib import Path
from typing import Optional, Union
from typing import Iterable, Optional, Union, Type

import pytorch_lightning as pl
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam
from pytorch_lightning import callbacks
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.loggers import Logger

from xturing.config import DEFAULT_DEVICE, IS_INTERACTIVE
from xturing.datasets.base import BaseDataset
@@ -101,6 +102,7 @@ def __init__(
use_deepspeed: bool = False,
max_training_time_in_secs: Optional[int] = None,
lora_type: int = 16,
logger: Union[Logger, Iterable[Logger], bool] = True,
):
self.lightning_model = TuringLightningModule(
model_engine=model_engine,
@@ -145,6 +147,7 @@ def __init__(
callbacks=training_callbacks,
enable_checkpointing=False,
log_every_n_steps=50,
logger=logger,
)
elif not use_lora and not use_deepspeed:
self.trainer = Trainer(
@@ -154,6 +157,7 @@
callbacks=training_callbacks,
enable_checkpointing=True,
log_every_n_steps=50,
logger=logger,
)
else:
training_callbacks = [
@@ -179,6 +183,7 @@
callbacks=training_callbacks,
enable_checkpointing=True,
log_every_n_steps=50,
logger=logger,
)

def fit(self):
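Inside the trainer the argument is simply forwarded unchanged to each pl.Trainer branch, so it carries Lightning's usual semantics. For reference, an illustrative sketch at the Lightning level (not xturing code):

```python
# What the forwarded value means to pytorch_lightning.Trainer (illustrative).
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

pl.Trainer(logger=True)                            # Lightning's default logger
pl.Trainer(logger=False)                           # disable experiment logging
pl.Trainer(logger=WandbLogger(project="xturing"))  # send metrics to W&B
```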
