Commit: merge main branch
KepingYan committed Jan 15, 2024
2 parents 6cedffa + 352b64e commit 88e66c4
Showing 86 changed files with 2,502 additions and 1,513 deletions.
@@ -11,7 +11,7 @@ def update_finetune_config(base_model):
     # available base models are:
     #
     # Mistral-7B-v0.1
-    # Llama-2-7b
+    # Llama-2-7b
     # pythia-1.4b
     # pythia-2.8b
     # pythia-70m
4 changes: 2 additions & 2 deletions .github/workflows/config/update_inference_config.py
@@ -16,8 +16,8 @@ def get_parser():
     parser = argparse.ArgumentParser(description="Adjust Inference Config File")
     parser.add_argument("--config_file", type=str, required=True)
     parser.add_argument("--output_file", type=str, required=True)
-    parser.add_argument("--deepspeed", action='store_true')
-    parser.add_argument("--ipex", action='store_true')
+    parser.add_argument("--deepspeed", action="store_true")
+    parser.add_argument("--ipex", action="store_true")
     return parser

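As a quick sanity check of the reworked flags, here is a minimal sketch of driving this parser from Python (the file names passed in are illustrative, not from the repo):

    import argparse

    def get_parser():
        parser = argparse.ArgumentParser(description="Adjust Inference Config File")
        parser.add_argument("--config_file", type=str, required=True)
        parser.add_argument("--output_file", type=str, required=True)
        parser.add_argument("--deepspeed", action="store_true")
        parser.add_argument("--ipex", action="store_true")
        return parser

    args = get_parser().parse_args(
        ["--config_file", "inference.yaml", "--output_file", "out.yaml", "--deepspeed"]
    )
    assert args.deepspeed and not args.ipex  # store_true flags default to False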
28 changes: 28 additions & 0 deletions .github/workflows/workflow_lint.yml
@@ -0,0 +1,28 @@
name: Lint

on:
  workflow_call:
    inputs:
      ci_type:
        type: string
        default: 'pr'

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-lt
  cancel-in-progress: true

jobs:
  lint:
    name: lint check
    runs-on: ubuntu-latest

    defaults:
      run:
        shell: bash

    steps:
      - name: Checkout
        uses: actions/checkout@v2

      - name: Run Lint
        run: ./format.sh -a
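Two details worth noting here: the `ci_type` input defaults to 'pr' so callers can tag what kind of run this is, and the `concurrency` group is keyed on the PR number (falling back to the ref), which means a new push cancels any lint run still in flight for the same PR.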
2 changes: 2 additions & 0 deletions .github/workflows/workflow_orders_on_merge.yml
@@ -16,6 +16,8 @@ on:
       - 'pyproject.toml'

 jobs:
+  call-lint:
+    uses: ./.github/workflows/workflow_lint.yml
   call-inference:
     uses: ./.github/workflows/workflow_inference.yml
4 changes: 4 additions & 0 deletions .github/workflows/workflow_orders_on_pr.yml
@@ -16,9 +16,13 @@ on:
       - 'pyproject.toml'

 jobs:
+  call-lint:
+    uses: ./.github/workflows/workflow_lint.yml
   call-inference:
+    needs: call-lint
     uses: ./.github/workflows/workflow_inference.yml

   call-finetune:
+    needs: call-lint
     uses: ./.github/workflows/workflow_finetune.yml
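Note the contrast with the merge workflow above: on merge, lint runs alongside inference and finetune, while on PRs the `needs: call-lint` edges make both downstream jobs wait for (and, per GitHub Actions' default behavior, require) a passing lint job, so formatting failures fail fast.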
30 changes: 30 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,30 @@
ci:
  autoupdate_schedule: monthly

repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.0.289
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix, --ignore=E402, --ignore=E501, --ignore=E731]

  # Black needs to be run after ruff with --fix
  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: "v0.950"
    hooks:
      - id: mypy
        exclude: tests
        additional_dependencies:
          - mypy-extensions
          - pydantic==1.10.0
          - types-cachetools
          - types-filelock
          - types-PyYAML
          - types-redis
          - types-requests
          - types-paramiko
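Locally these hooks are typically enabled once with `pre-commit install` and applied to the whole tree with `pre-commit run --all-files`; the top-level `ci:` block additionally asks pre-commit.ci to refresh the hook revisions monthly.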
14 changes: 7 additions & 7 deletions common/__init__.py
@@ -1,12 +1,12 @@
 from .config import Config, parse_args, parse_config
-from .torch_config import TorchConfig
-from .logging import get_logger, logger
-from .load import *
-from .init import init
 from .common import import_all_module
+from .logging import logger
+from .load import *  # noqa: F403 # unable to detect undefined names
 from . import agentenv
+from .torch_config import TorchConfig  # noqa: F401
 from typing import Dict, Any
 import sys

-@load_check_decorator
+
+@load_check_decorator  # noqa: F405 # may be undefined, or defined from star imports
 def get_agentenv(config: Dict[str, Any]):
     logger.info(f"{sys._getframe().f_code.co_name} config: {config}")
     agentenv_type = config.get("type", None)
2 changes: 1 addition & 1 deletion common/agentenv/__init__.py
@@ -6,4 +6,4 @@
 basedir = os.path.dirname(realpath)
 import_all_module(basedir, "common.agentenv")

-__all__ = ["AgentEnv"]
+__all__ = ["AgentEnv"]
3 changes: 2 additions & 1 deletion common/agentenv/agentenv.py
@@ -1,13 +1,14 @@
 class Meta(type):
     def __init__(cls, name, bases, namespace, **kwargs):
         super().__init__(name, bases, namespace, **kwargs)
-        if not hasattr(cls, 'registory'):
+        if not hasattr(cls, "registory"):
             # this is the base class
             cls.registory = {}
         else:
             # this is the subclass
             cls.registory[name] = cls

+
 class AgentEnv(metaclass=Meta):
     def __init__(self, config):
         self.config = config
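To make the registration flow concrete, here is a self-contained sketch of the pattern above (`EchoEnv` is a made-up subclass; note that the repo spells the attribute `registory`):

    class Meta(type):
        def __init__(cls, name, bases, namespace, **kwargs):
            super().__init__(name, bases, namespace, **kwargs)
            if not hasattr(cls, "registory"):
                cls.registory = {}  # created once, on the base class
            else:
                cls.registory[name] = cls  # each subclass registers itself by name

    class AgentEnv(metaclass=Meta):
        def __init__(self, config):
            self.config = config

    class EchoEnv(AgentEnv):  # hypothetical subclass
        pass

    print(AgentEnv.registory)  # {'EchoEnv': <class '__main__.EchoEnv'>}

This pairs with the `import_all_module` call in common/agentenv/__init__.py: importing every module in the package is what triggers subclass definitions, and therefore registration, without a hand-maintained import list.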
82 changes: 40 additions & 42 deletions common/agentenv/rlhf_env.py
@@ -1,4 +1,3 @@
-from typing import Any
 import gymnasium as gym

 import numpy as np
@@ -13,15 +12,15 @@


 def generate_response(
-    model: torch.nn.Module,
-    *,
-    input_ids: torch.tensor,
-    max_length:int,
-    eos_token_id: int
+    model: torch.nn.Module,
+    *,
+    input_ids: torch.tensor,
+    max_length: int,
+    eos_token_id: int,
 ):
     """Generate a response using the model."""
     generated_sequence = []
-    probs_list = []
+    # probs_list = []
     model_in = torch.clone(input_ids)
     with torch.no_grad():
         for i in range(max_length):
@@ -66,6 +65,7 @@ def generate_response(
         "n_generated_tokens": generated_tokens.shape[-1],
     }

+
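Most of `generate_response` is collapsed in this view. For orientation only, a rough sketch of what a token-by-token loop of this shape looks like, assuming an HF-style model whose output exposes `.logits` (this is not the repo's exact body):

    import torch

    def greedy_decode(model, *, input_ids, max_length, eos_token_id):
        generated = []
        model_in = torch.clone(input_ids)
        with torch.no_grad():
            for _ in range(max_length):
                logits = model(model_in).logits[:, -1, :]  # scores for the next token
                next_token = logits.argmax(dim=-1, keepdim=True)
                generated.append(next_token)
                model_in = torch.cat([model_in, next_token], dim=-1)
                if next_token.item() == eos_token_id:
                    break
        return torch.cat(generated, dim=-1)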
 def compute_approx_kl(
     logits: torch.Tensor,
     logits_base: torch.Tensor,
@@ -79,53 +79,52 @@ def compute_approx_kl(

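`compute_approx_kl`'s body is collapsed as well; as a plausible reference (not the repo's exact implementation), the KL between two per-token logit tensors can be estimated as E_p[log p - log q]:

    import torch
    import torch.nn.functional as F

    def approx_kl(logits: torch.Tensor, logits_base: torch.Tensor) -> torch.Tensor:
        # KL(p || q) = sum_x p(x) * (log p(x) - log q(x)), averaged over positions
        logp = F.log_softmax(logits, dim=-1)
        logq = F.log_softmax(logits_base, dim=-1)
        return (logp.exp() * (logp - logq)).sum(dim=-1).mean()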
 class RLHFEnv(gym.Env, AgentEnv):
-
     def __init__(self, config):
-
         self.config = config
         agentenv_config = config.get("config")

         # Prompt dataset
-        self.prompt_dataset = load_dataset(agentenv_config.get("datasets"))
+        self.prompt_dataset = load_dataset(agentenv_config.get("datasets"))
         self.dsize = len(self.prompt_dataset)

         # base tokenizer
         self.tokenizer = load_tokenizer(agentenv_config.get("tokenizer"))
         vocab_size = self.tokenizer.vocab_size
-        model_max_length = min(agentenv_config['model_max_length'], self.tokenizer.model_max_length)
+        model_max_length = min(agentenv_config["model_max_length"], self.tokenizer.model_max_length)

         # reward and sft model
         self.reward_model = load_model(agentenv_config.get("reward_model"))
         self.sft_model = load_model(agentenv_config.get("sft_model"))

         # the KL coefficient
         self.kl_coeff = agentenv_config["kl_coeff"]
         # The maximum length of the generated text
         self.max_generation_length = agentenv_config["max_generation_length"]

         # action space
-        self.action_space = sp.Dict({
-            "sequence": Repeated(sp.Discrete(vocab_size), max_len=model_max_length),
-            "attention_mask": Repeated(sp.Discrete(2), max_len=model_max_length),
-            "response_mask": Repeated(sp.Discrete(2), max_len=model_max_length),
-            "logits": Repeated(
-                sp.Box(0, 1, shape=(vocab_size,)), max_len=model_max_length
-            ),
-        })
+        self.action_space = sp.Dict(
+            {
+                "sequence": Repeated(sp.Discrete(vocab_size), max_len=model_max_length),
+                "attention_mask": Repeated(sp.Discrete(2), max_len=model_max_length),
+                "response_mask": Repeated(sp.Discrete(2), max_len=model_max_length),
+                "logits": Repeated(sp.Box(0, 1, shape=(vocab_size,)), max_len=model_max_length),
+            }
+        )

         # observation space
-        self.observation_space = sp.Dict({
-            "input_ids": Repeated(sp.Discrete(vocab_size), max_len=model_max_length),
-            "attention_mask": Repeated(sp.Discrete(2), max_len=model_max_length),
-        })
+        self.observation_space = sp.Dict(
+            {
+                "input_ids": Repeated(sp.Discrete(vocab_size), max_len=model_max_length),
+                "attention_mask": Repeated(sp.Discrete(2), max_len=model_max_length),
+            }
+        )

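`Repeated` here is RLlib's variable-length space (presumably `ray.rllib.utils.spaces.repeated.Repeated`, with `sp` being `gymnasium.spaces`; the imports sit outside this hunk), so both spaces hold token sequences of up to `model_max_length` elements. A tiny sampling sketch under that assumption:

    import gymnasium.spaces as sp
    from ray.rllib.utils.spaces.repeated import Repeated  # assumed import

    obs_space = sp.Dict(
        {
            "input_ids": Repeated(sp.Discrete(100), max_len=8),
            "attention_mask": Repeated(sp.Discrete(2), max_len=8),
        }
    )
    print(obs_space.sample())  # variable-length lists, at most 8 entries each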
     def reset(self, *, seed=None, options=None):
-
         if seed:
             np.random.seed(seed)
         index = np.random.randint(self.dsize)
-        if 'train' in self.prompt_dataset:
-            self.prompt_dataset = self.prompt_dataset['train']
+        if "train" in self.prompt_dataset:
+            self.prompt_dataset = self.prompt_dataset["train"]
         prompt = self.prompt_dataset[index]["prompt"]
         prompt_tokens = self.tokenizer(prompt, return_tensors="np")
         # remove the batch dimension since we can only do one sentence generation at a time
@@ -134,7 +133,6 @@ def reset(self, *, seed=None, options=None):
         return prompt_tokens, {}

     def step(self, action):
-
         sequence = action["sequence"]
         response_mask = action["response_mask"]
         attention_mask = action["attention_mask"]
@@ -147,28 +145,28 @@ def step(self, action):
         r_align = r_align[-1].item()

         # Compute the probs from the sft model for the same number of tokens
-        sequence = torch.tensor(sequence, dtype=torch.long)[None] # add batch dim
+        sequence = torch.tensor(sequence, dtype=torch.long)[None]  # add batch dim
         sft_output = generate_response(
-            self.sft_model,
-            input_ids=sequence,
-            max_length=n_response_tokens,
-            eos_token_id=self.tokenizer.eos_token_id
+            self.sft_model,
+            input_ids=sequence,
+            max_length=n_response_tokens,
+            eos_token_id=self.tokenizer.eos_token_id,
         )
-        logits = torch.tensor(logits, dtype=torch.float32)[None] # add batch dim

+        logits = torch.tensor(logits, dtype=torch.float32)[None]  # add batch dim
         # only compute kl on the response tokens
         r_kl = compute_approx_kl(
-            logits[:, -n_response_tokens:], # the inner term
-            sft_output["logits"][:, -n_response_tokens:] # the outer term
+            logits[:, -n_response_tokens:],  # the inner term
+            sft_output["logits"][:, -n_response_tokens:],  # the outer term
         ).item()

         reward = r_align - self.kl_coeff * r_kl

         info = {
-            "r_align": r_align,
-            "r_kl": r_kl,
-            "n_response_tokens": n_response_tokens
+            "r_align": r_align,
+            "r_kl": r_kl,
+            "n_response_tokens": n_response_tokens,
         }

         # Produce a random reward when we reach the goal.
-        return self.observation_space.sample(), reward, True, False, info
+        return self.observation_space.sample(), reward, True, False, info
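The scalar reward trades alignment against drift from the SFT model. With made-up numbers, r_align = 0.8, r_kl = 0.25, and kl_coeff = 0.1 give reward = 0.8 - 0.1 * 0.25 = 0.775, so the further the policy's token distribution strays from the reference model, the more the alignment reward is discounted.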
7 changes: 4 additions & 3 deletions common/common.py
@@ -4,8 +4,9 @@

 from .logging import logger

-def import_all_module(basedir, prefix = None):
-    all_py_files = glob.glob(basedir+"/*.py")
+
+def import_all_module(basedir, prefix=None):
+    all_py_files = glob.glob(basedir + "/*.py")
     modules = [os.path.basename(f) for f in all_py_files]

     for module in modules:
@@ -17,5 +18,5 @@ def import_all_module(basedir, prefix=None):
             module_name = f"{prefix}.{module}"
         try:
             importlib.import_module(module_name)
-        except Exception as e:
+        except Exception:
             logger.warning(f"import {module_name} error", exc_info=True)
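A sketch of how this helper is used, mirroring the call in common/agentenv/__init__.py shown earlier (the `plugins` package name is invented for illustration):

    # plugins/__init__.py (hypothetical package)
    import os

    from common.common import import_all_module

    basedir = os.path.dirname(os.path.realpath(__file__))
    import_all_module(basedir, "plugins")  # imports every plugins/*.py module

Because each module is imported purely for its side effects, any AgentEnv subclass defined in the package lands in the registry without a hand-maintained import list.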