KeyError: 'llama_lora_int4' #233

Closed
adastra9257 opened this issue Jul 19, 2023 · 3 comments

Comments

@adastra9257

I am learning to fine-tune LLaMA in INT4 with xTuring, using the LLaMA_lora_int4.ipynb notebook from the example folder. When I run it, I get the following error:
KeyError: 'llama_lora_int4'
I have no idea why this error occurred. Can anyone help me? Thank you!

OS: Ubuntu 22.04

This code was executed in JupyterLab:

```python
from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel
# from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import TensorBoardLogger

# Initializes WandB integration
# wandb_logger = WandbLogger()
tensorboard_logger = TensorBoardLogger(save_dir="logs/")

instruction_dataset = InstructionDataset("./alpaca_data")
# Initializes the model
model = BaseModel.create("llama_lora_int4")
```

This is the error log:

```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[2], line 12
     10 instruction_dataset = InstructionDataset("./alpaca_data")
     11 # Initializes the model
---> 12 model = BaseModel.create("llama_lora_int4")

File ~/miniconda3/envs/finetune/lib/python3.10/site-packages/xturing/registry.py:14, in BaseParent.create(cls, class_key, *args, **kwargs)
     12 @classmethod
     13 def create(cls, class_key, *args, **kwargs):
---> 14     return cls.registry[class_key](*args, **kwargs)

KeyError: 'llama_lora_int4'
```
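The traceback shows that `BaseModel.create` simply indexes a dictionary (`cls.registry[class_key]`), so any key that the installed xturing version never registered raises a bare `KeyError`. The sketch below reproduces that lookup pattern with a hypothetical stand-in class; `Registry`, `add_to_registry`, `create_checked` and `DummyKbitModel` are illustrative names, not xturing's API, and `"llama_lora_kbit"` is only an assumed example of a registered key.

```python
from typing import Any, Callable, Dict


class Registry:
    """Hypothetical stand-in for xturing's BaseParent: maps string keys to classes."""

    registry: Dict[str, Callable[..., Any]] = {}

    @classmethod
    def add_to_registry(cls, key: str, obj: Callable[..., Any]) -> None:
        cls.registry[key] = obj

    @classmethod
    def create(cls, class_key: str, *args: Any, **kwargs: Any) -> Any:
        # Same lookup as in the traceback: an unknown key raises a bare KeyError.
        return cls.registry[class_key](*args, **kwargs)

    @classmethod
    def create_checked(cls, class_key: str, *args: Any, **kwargs: Any) -> Any:
        # Friendlier variant: report which keys are actually registered.
        if class_key not in cls.registry:
            available = ", ".join(sorted(cls.registry))
            raise KeyError(f"{class_key!r} is not registered; available keys: {available}")
        return cls.registry[class_key](*args, **kwargs)


class DummyKbitModel:
    """Placeholder model class used only for this illustration."""

    def __init__(self, path: str = "some/model-path") -> None:
        self.path = path


Registry.add_to_registry("llama_lora_kbit", DummyKbitModel)

model = Registry.create("llama_lora_kbit")
print(type(model).__name__)  # DummyKbitModel

try:
    Registry.create("llama_lora_int4")
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'llama_lora_int4'
```

Assuming `BaseModel` exposes the same `registry` dict seen in the traceback, running `print(sorted(BaseModel.registry))` in the notebook should list the keys your installed xturing version actually accepts.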

These are the dependencies installed in the environment, as listed by pip freeze:

absl-py==1.4.0
accelerate==0.21.0
ai21==1.2.1
aiofiles==23.1.0
aiohttp==3.8.4
aiosignal==1.3.1
altair==5.0.1
anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1688651106312/work/dist
appdirs==1.4.4
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work
argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1649500328244/work
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work
async-lru @ file:///home/conda/feedstock_root/build_artifacts/async-lru_1688997201545/work
async-timeout==4.0.2
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1683424013410/work
Babel @ file:///home/conda/feedstock_root/build_artifacts/babel_1677767029043/work
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backoff==2.2.1
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1687772187254/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1680888073205/work
bitsandbytes==0.37.2
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1648883617327/work
cachetools==5.3.1
certifi==2023.5.7
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1636046050867/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1688813409104/work
click==8.1.6
cmake==3.26.4
cohere==4.16.0
comm==0.1.3
contourpy==1.1.0
cycler==0.11.0
datasets==2.13.1
debugpy @ file:///home/builder/ci_310/debugpy_1640789504635/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
deepspeed==0.10.0
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
dill==0.3.6
docker-pycreds==0.4.0
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
evaluate==0.4.0
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1688381075899/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work
fastapi==0.100.0
fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1684761244589/work/dist
ffmpy==0.3.1
filelock==3.12.2
flit_core @ file:///home/conda/feedstock_root/build_artifacts/flit-core_1684084314667/work/source/flit_core
fonttools==4.41.0
frozenlist==1.4.0
fsspec==2023.6.0
gitdb==4.0.10
GitPython==3.1.32
google-auth==2.22.0
google-auth-oauthlib==1.0.0
gradio==3.37.0
gradio_client==0.2.10
grpcio==1.56.0
h11==0.14.0
hjson==3.1.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.16.4
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1688754491823/work
importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1689017639396/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1655369107642/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1685727741709/work
ipywidgets==8.0.7
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
joblib==1.3.1
json5 @ file:///home/conda/feedstock_root/build_artifacts/json5_1688248289187/work
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1669810440410/work
jsonschema-specifications==2023.7.1
jupyter-events @ file:///home/conda/feedstock_root/build_artifacts/jupyter_events_1673559782596/work
jupyter-lsp @ file:///home/conda/feedstock_root/build_artifacts/jupyter-lsp-meta_1685453365113/work/jupyter-lsp
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1687700988094/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1686775611663/work
jupyter_server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1687869799272/work
jupyter_server_terminals @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_terminals_1673491454549/work
jupyterlab @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_1689253413907/work
jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work
jupyterlab-widgets==3.0.8
jupyterlab_server @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_server_1686659921555/work
kiwisolver==1.4.4
lightning-utilities==0.9.0
linkify-it-py==2.0.2
lit==16.0.6
Markdown==3.4.3
markdown-it-py==2.2.0
MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
matplotlib==3.7.2
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
mdit-py-plugins==0.3.3
mdurl==0.1.2
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1686313613819/work/dist
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1684790896106/work
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1689733131629/work
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1688996247388/work
nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work
networkx==3.1
ninja==1.11.1
nltk==3.8.1
notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1682360583588/work
numpy==1.25.1
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
oauthlib==3.2.2
openai==0.27.8
orjson==3.9.2
overrides @ file:///home/conda/feedstock_root/build_artifacts/overrides_1666057828264/work
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1681337016113/work
pandas==2.0.3
pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
pathtools==0.1.2
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow==10.0.0
pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1633981968097/work
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1689538620473/work
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1689032443210/work
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1688565951714/work
protobuf==4.23.4
psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
py-cpuinfo==9.0.0
pyarrow==12.0.1
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==1.10.11
pydub==0.25.1
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1681904169130/work
pyparsing==3.0.9
pyrsistent @ file:///home/builder/ci_310/pyrsistent_1640807196327/work
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
python-json-logger @ file:///home/conda/feedstock_root/build_artifacts/python-json-logger_1677079630776/work
python-multipart==0.0.6
pytorch-lightning==2.0.5
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1680088766131/work
PyYAML==6.0.1
pyzmq @ file:///croot/pyzmq_1686601365461/work
referencing==0.30.0
regex==2023.6.3
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1684774241324/work
requests-oauthlib==1.3.1
responses==0.18.0
rfc3339-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3339-validator_1638811747357/work
rfc3986-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3986-validator_1598024191506/work
rouge-score==0.1.2
rpds-py==0.9.2
rsa==4.9
semantic-version==2.10.0
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1682601222253/work
sentencepiece==0.1.99
sentry-sdk==1.28.1
setproctitle==1.3.2
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
smmap==5.0.0
sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1662051266223/work
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
starlette==0.27.0
sympy==1.12
tensorboard==2.13.0
tensorboard-data-server==0.7.1
terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1670253674810/work
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work
tokenizers==0.13.3
tomli @ file:///home/conda/feedstock_root/build_artifacts/tomli_1644342247877/work
toolz==0.12.0
torch==2.0.1
torchmetrics==1.0.1
tornado @ file:///opt/conda/conda-bld/tornado_1662061693373/work
tqdm==4.65.0
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work
transformers==4.28.1
triton==2.0.0
typing-utils @ file:///home/conda/feedstock_root/build_artifacts/typing_utils_1622899189314/work
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1688315532570/work
tzdata==2023.3
uc-micro-py==1.0.2
urllib3==1.26.16
uvicorn==0.23.1
wandb==0.15.5
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
webencodings==0.5.1
websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1687789148259/work
websockets==11.0.3
Werkzeug==2.3.6
wget==3.2
widgetsnbextension==4.0.8
xturing==0.1.6
xxhash==3.2.0
yarl==1.9.2
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1689374466814/work
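
Because the set of registered model keys depends on the installed xturing version (xturing==0.1.6 in the list above), a shorter version report of just the relevant packages can be easier to compare than the full pip freeze output. The sketch below uses only the standard library's `importlib.metadata`; the helper name `package_versions` and the choice of packages are assumptions for illustration.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def package_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each package, or None if it is not installed."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


# Packages most relevant to this issue (list chosen for illustration).
for pkg, ver in package_versions(
    ["xturing", "transformers", "torch", "bitsandbytes", "accelerate"]
).items():
    print(f"{pkg}=={ver}" if ver else f"{pkg}: not installed")
```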
@tushar2407
Contributor

Hey @adastra9257, you would need to use the k-bit model instead: GenericLoraKbitModel, passing it the LLaMA model path.

@adastra9257
Author

Hello @tushar2407, thank you for your reply! I couldn't find any documentation for that class. I tried to modify my code myself, but it still fails to run. Could you please explain how to change my code so that it works?

@tushar2407
Contributor

tushar2407 commented Jul 25, 2023

Hey! Sure. Below is the working code.

```python
# Make the necessary imports
from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import GenericLoraKbitModel, LlamaLoraKbit
from pytorch_lightning.loggers import WandbLogger

# Initialize the WandB integration
wandb_logger = WandbLogger()

# Load your desired dataset
instruction_dataset = InstructionDataset("../llama/alpaca_data")

# Initialize the model from a Hugging Face model path...
model = GenericLoraKbitModel("aleksickx/llama-7b-hf")

# ...OR use the LLaMA-specific k-bit LoRA class instead (pick one):
# model = LlamaLoraKbit()

# Fine-tune the model on your desired dataset
model.finetune(dataset=instruction_dataset, logger=wandb_logger)

# Save the fine-tuned model
model.save("./finetuned_model")
```

Hope this helps!
