Skip to content

Commit

Permalink
Extend support to transformers v4.31 (#395)
Browse files: browse the repository at this point in the history
* Extend support to transformers v4.31

Fall back, with a warning, to `load_state_dict(strict=False)` for backwards
compatibility.

* Fix strict=False for load_state_dict

* Skip attempt at improved backwards compatibility

* Add a more informative warning before reraising error

* Format

* CI: Add future note for restoring requirements
  • Loading branch information
adrianeboyd authored Jul 31, 2023
1 parent 6b36857 commit e87c4e3
Showing 4 changed files with 28 additions and 7 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -108,5 +108,9 @@ jobs:
- name: Test backwards compatibility for v1.1 models
if: matrix.python_version == '3.9'
run: |
python -m pip install "transformers<4.31"
python -m pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.4.0/en_core_web_trf-3.4.0-py3-none-any.whl --no-deps
python -c "import spacy; nlp = spacy.load('en_core_web_trf'); doc = nlp('test')"
# NOTE: this step installs "transformers<4.31"; if any following steps are
# added in the future, update the requirements again at the end of this step:
# python -m pip install -U -r requirements.txt
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
spacy>=3.5.0,<4.0.0
numpy>=1.15.0
transformers>=3.4.0,<4.31.0
transformers>=3.4.0,<4.32.0
torch>=1.8.0
srsly>=2.4.0,<3.0.0
dataclasses>=0.6,<1.0; python_version < "3.7"
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ install_requires =
spacy>=3.5.0,<4.0.0
numpy>=1.15.0; python_version < "3.9"
numpy>=1.19.0; python_version >= "3.9"
transformers>=3.4.0,<4.31.0
transformers>=3.4.0,<4.32.0
torch>=1.8.0
srsly>=2.4.0,<3.0.0
dataclasses>=0.6,<1.0; python_version < "3.7"
27 changes: 22 additions & 5 deletions spacy_transformers/layers/hf_shim.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from pathlib import Path
import srsly
import torch
import warnings
from thinc.api import get_torch_default_device
from spacy.util import SimpleFrozenDict

@@ -24,9 +25,9 @@ def __init__(
optimizer: Any = None,
mixed_precision: bool = False,
grad_scaler_config: dict = {},
config_cls = AutoConfig,
model_cls = AutoModel,
tokenizer_cls = AutoTokenizer,
config_cls=AutoConfig,
model_cls=AutoModel,
tokenizer_cls=AutoTokenizer,
):
self._hfmodel = model
self.config_cls = config_cls
@@ -97,7 +98,9 @@ def from_bytes(self, bytes_data):
tok_kwargs = tok_dict.pop("kwargs", {})
for x, x_bytes in tok_dict.items():
Path(temp_dir / x).write_bytes(x_bytes)
tokenizer = self.tokenizer_cls.from_pretrained(str(temp_dir.absolute()), **tok_kwargs)
tokenizer = self.tokenizer_cls.from_pretrained(
str(temp_dir.absolute()), **tok_kwargs
)
vocab_file_contents = None
if hasattr(tokenizer, "vocab_file"):
vocab_file_name = tokenizer.vocab_files_names["vocab_file"]
@@ -117,7 +120,21 @@ def from_bytes(self, bytes_data):
filelike = BytesIO(msg["state"])
filelike.seek(0)
device = get_torch_default_device()
self._model.load_state_dict(torch.load(filelike, map_location=device))
try:
self._model.load_state_dict(torch.load(filelike, map_location=device))
except RuntimeError as ex:
warn_msg = (
"Error loading saved torch model. If the error is related "
"to unexpected key(s) in state_dict, a possible workaround "
"is to load this model with 'transformers<4.31'. "
"Alternatively, download a newer compatible model or "
"retrain your custom model with the current "
"transformers and spacy-transformers versions. For more "
"details and available updates, run: python -m spacy "
"validate"
)
warnings.warn(warn_msg)
raise ex
self._model.to(device)
else:
self._hfmodel = HFObjects(

0 comments on commit e87c4e3

Please sign in to comment.