GenericLoraModel load error #218

Closed
xupruvu opened this issue Jun 14, 2023 · 5 comments

Comments

@xupruvu

xupruvu commented Jun 14, 2023

I trained the model with the following code, and training works:

from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import GenericLoraModel

instruction_dataset = InstructionDataset("./alpaca_data")
model = GenericLoraModel("bigscience/bloom-1b1", target_modules=["query_key_value"])
model.finetune(dataset=instruction_dataset)

model.save("model")

But when I try to load the model, an error occurs with either BaseModel.load or GenericLoraModel.load:

# from xturing.models.base import BaseModel
from xturing.models import GenericLoraModel
# finetuned_model = BaseModel.load("/home/tl_admin/xturing_test/model")
finetuned_model = GenericLoraModel.load("/home/tl_admin/xturing_test/model")

It gives an error:

/home/tl_admin/.local/lib/python3.8/site-packages/xturing/models/base.py:21 in load              │
│                                                                                                  │
│   18 │   │   if path_weights_dir_or_model_name.is_dir() and exists_xturing_config_file(          │
│   19 │   │   │   path_weights_dir_or_model_name                                                  │
│   20 │   │   ):                                                                                  │
│ ❱ 21 │   │   │   return cls.load_from_local(weights_dir_or_model_name)                           │
│   22 │   │   else:                                                                               │
│   23 │   │   │   print("Loading model from xTuring hub")                                         │
│   24 │   │   │   return cls.load_from_hub(weights_dir_or_model_name)                             │
│                                                                                                  │
│ /home/tl_admin/.local/lib/python3.8/site-packages/xturing/models/base.py:56 in load_from_local   │
│                                                                                                  │
│   53 │   │   │   cls.registry.get(model_name) is not None                                        │
│   54 │   │   ), "The model_name {} is not valid".format(model_name)                              │
│   55 │   │                                                                                       │
│ ❱ 56 │   │   model = cls.create(model_name, weights_path=weights_dir_path)                       │
│   57 │   │                                                                                       │
│   58 │   │   return model                                                                        │
│   59                                                                                             │
│                                                                                                  │
│ /home/tl_admin/.local/lib/python3.8/site-packages/xturing/registry.py:14 in create               │
│                                                                                                  │
│   11 │                                                                                           │
│   12 │   @classmethod                                                                            │
│   13 │   def create(cls, class_key, *args, **kwargs):                                            │
│ ❱ 14 │   │   return cls.registry[class_key](*args, **kwargs)                                     │
│   15 │                                                                                           │
│   16 │   @classmethod                                                                            │
│   17 │   def __getitem__(cls, key):                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: __init__() missing 1 required positional argument: 'model_name'

@tushar2407
Contributor

Can you show the full error trace?

@xupruvu
Author

xupruvu commented Jun 15, 2023

Hi, @tushar2407,
I ran:

from xturing.models.base import BaseModel
from xturing.models import GenericLoraModel
finetuned_model = BaseModel.load("model")

It gives an error:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 3>:3                                                                              │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/xturing/models/base.py:21 in load                        │
│                                                                                                  │
│   18 │   │   if path_weights_dir_or_model_name.is_dir() and exists_xturing_config_file(          │
│   19 │   │   │   path_weights_dir_or_model_name                                                  │
│   20 │   │   ):                                                                                  │
│ ❱ 21 │   │   │   return cls.load_from_local(weights_dir_or_model_name)                           │
│   22 │   │   else:                                                                               │
│   23 │   │   │   print("Loading model from xTuring hub")                                         │
│   24 │   │   │   return cls.load_from_hub(weights_dir_or_model_name)                             │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/xturing/models/base.py:56 in load_from_local             │
│                                                                                                  │
│   53 │   │   │   cls.registry.get(model_name) is not None                                        │
│   54 │   │   ), "The model_name {} is not valid".format(model_name)                              │
│   55 │   │                                                                                       │
│ ❱ 56 │   │   model = cls.create(model_name, weights_path=weights_dir_path)                       │
│   57 │   │                                                                                       │
│   58 │   │   return model                                                                        │
│   59                                                                                             │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/xturing/registry.py:14 in create                         │
│                                                                                                  │
│   11 │                                                                                           │
│   12 │   @classmethod                                                                            │
│   13 │   def create(cls, class_key, *args, **kwargs):                                            │
│ ❱ 14 │   │   return cls.registry[class_key](*args, **kwargs)                                     │
│   15 │                                                                                           │
│   16 │   @classmethod                                                                            │
│   17 │   def __getitem__(cls, key):                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: GenericLoraModel.__init__() missing 1 required positional argument: 'model_name'

These are the only error messages.

@xupruvu
Author

xupruvu commented Jun 15, 2023

I found that the model_name field in xturing.json differs between a model saved from GenericLoraModel and one saved from a model created with BaseModel.create("bloom_lora").

The former is 'generic_lora' and the latter is 'bloom_lora'.

If I change 'generic_lora' to 'bloom_lora', the model loads.
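Until the fix is released, the workaround described above can be scripted. The following is a minimal Python sketch, assuming (as reported in this thread) that the saved directory contains an xturing.json file with a top-level "model_name" key; the helper name patch_model_name is hypothetical and not part of xTuring. Rewriting the name to 'bloom_lora' only makes sense when the underlying checkpoint really is BLOOM (here bigscience/bloom-1b1).

```python
import json
import tempfile
from pathlib import Path


def patch_model_name(model_dir, new_name, old_name="generic_lora"):
    """Swap the model_name stored in <model_dir>/xturing.json.

    Assumption: the config is a JSON object with a top-level "model_name"
    key, as described in this issue. Other keys are left untouched.
    Returns the previous model_name.
    """
    config_path = Path(model_dir) / "xturing.json"
    config = json.loads(config_path.read_text())
    previous = config.get("model_name")
    # Refuse to patch anything other than the expected generic name.
    if previous != old_name:
        raise ValueError(f"expected model_name {old_name!r}, found {previous!r}")
    config["model_name"] = new_name
    config_path.write_text(json.dumps(config, indent=2))
    return previous


if __name__ == "__main__":
    # Demo on a throwaway directory instead of a real saved model.
    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "xturing.json").write_text(json.dumps({"model_name": "generic_lora"}))
        print(patch_model_name(tmp, "bloom_lora"))  # generic_lora
```

With a real checkpoint, you would call patch_model_name("model", "bloom_lora") and then BaseModel.load("model"). Keep a backup of xturing.json, since the file is overwritten in place.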

@tushar2407
Contributor

Oh, right. Currently, GenericModel and its variants don't record which kind of model you want to load; they only take the path to the model in the Hugging Face library. Because of that, generic_lora and bloom_lora models have different classes but work exactly the same when given the same model path.
We are working on the issue with loading saved models; a PR is in the pipeline.
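The traceback points at a dispatch mismatch: load_from_local calls cls.create(model_name, weights_path=weights_dir_path), and create forwards only those arguments to the registered class. A class that, like GenericLoraModel, also needs the Hugging Face model name as a positional argument therefore cannot be constructed. The sketch below reproduces that mismatch in isolation; Registry, BloomLora, and GenericLora are hypothetical stand-ins, not xTuring's actual classes.

```python
class Registry:
    """Minimal stand-in for the dispatch shown in xturing/registry.py."""

    registry = {}

    @classmethod
    def register(cls, key, model_cls):
        cls.registry[key] = model_cls

    @classmethod
    def create(cls, class_key, *args, **kwargs):
        # Same call shape as registry.py line 14 in the traceback.
        return cls.registry[class_key](*args, **kwargs)


class BloomLora:
    """Hypothetical model-specific class: the base checkpoint is implied."""

    def __init__(self, weights_path=None):
        self.weights_path = weights_path


class GenericLora:
    """Hypothetical generic class: the Hugging Face model name is required."""

    def __init__(self, model_name, weights_path=None):
        self.model_name = model_name
        self.weights_path = weights_path


Registry.register("bloom_lora", BloomLora)
Registry.register("generic_lora", GenericLora)

# load_from_local only forwards weights_path, so the model-specific class loads.
bloom = Registry.create("bloom_lora", weights_path="model")

# The generic class never receives model_name, hence the reported TypeError:
# "... missing 1 required positional argument: 'model_name'".
try:
    Registry.create("generic_lora", weights_path="model")
except TypeError as err:
    print(err)
```

In this sketch, renaming the key to 'bloom_lora' succeeds for the same reason the manual edit to xturing.json does: the model-specific class can be built from weights_path alone.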

@tushar2407
Contributor

tushar2407 commented Jun 19, 2023

@xupruvu, the issue has been resolved and the fix merged into the dev branch. You should be able to carry out the usual tasks once it is released.

xupruvu closed this as completed Jun 19, 2023
xupruvu reopened this Jun 19, 2023
xupruvu closed this as completed Jun 19, 2023