Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix datapath #14

2 changes: 1 addition & 1 deletion llm/alignment/ppo/data/alpaca.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class AlpacaDataset(RawDataset):
NAME: str = "alpaca"
ALIASES: tuple[str, ...] = ("stanford-alpaca",)

def __init__(self, path: str | None = None) -> None:
def __init__(self, path: str | None = None, *args, **kwargs) -> None:
self.data = load_dataset(path or "tatsu-lab/alpaca", split="train")

def __getitem__(self, index: int) -> RawSample:
Expand Down
2 changes: 2 additions & 0 deletions llm/alignment/ppo/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ def __init__( # pylint: disable=too-many-branches
raise TypeError(
f"Dataset `{name}` attributes should be a float or a dict, " f"got {type(attributes).__name__}.",
)
kwargs["use_rm_server"] = use_rm_server

proportion = kwargs.pop("proportion", 1.0)
if isinstance(proportion, Fraction):
if not (proportion < 0 and proportion.denominator == 1):
Expand Down
8 changes: 6 additions & 2 deletions llm/alignment/ppo/data/jsondata.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@
class JsonDataset(RawDataset):
NAME: str = "Jsonfile"

def __init__(self, path: str | None = None) -> None:
def __init__(self, path: str | None = None, *args, **kwargs) -> None:
self.data = load_dataset("json", data_files=path, split="train")
self.use_rm_server = kwargs.pop("use_rm_server", False)
assert "src" in self.data.column_names, "'src' should be included in jsonfile"
if self.use_rm_server:
assert "tgt" in self.data.column_names, "'tgt' should be included in jsonfile when using rm server"

def __getitem__(self, index: int) -> RawSample:
data = self.data[index]
if "tgt" in data:
if self.use_rm_server:
rawdata = RawSample(
input=data["src"],
answer=data["tgt"],
Expand Down
2 changes: 1 addition & 1 deletion llm/alignment/ppo/data/safe_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class SafeRLHFDataset(RawDataset):
SPLIT: ClassVar[str]
PATH: ClassVar[str]

def __init__(self, path: str | None = None) -> None:
def __init__(self, path: str | None = None, *args, **kwargs) -> None:
self.data = load_dataset(path or self.PATH, split=self.SPLIT)

def __getitem__(self, index: int) -> RawSample:
Expand Down
2 changes: 1 addition & 1 deletion llm/alignment/ppo/run_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def main():
reward_tokenizer,
reward_critic_tokenizer if training_args.rl_algorithm == "ppo" else None,
]:
if tokenizer.pad_token_id is None:
if tokenizer and tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id

if training_args.should_load_dataset:
Expand Down
Loading