diff --git a/skorch/llm/classifier.py b/skorch/llm/classifier.py
index 23ea8ad98..ecf5c54f1 100644
--- a/skorch/llm/classifier.py
+++ b/skorch/llm/classifier.py
@@ -243,7 +243,7 @@ def generate_logits(self, *, label_id, **kwargs):
         recorded_logits = []
         logits_cached = self.get_cache(kwargs)
         while logits_cached is not None:
-            if label_id[0] == self.tokenizer.eos_token_id:
+            if not label_id or label_id[0] == self.tokenizer.eos_token_id:
                 # don't extend with eos_token -- it is already there at the end,
                 # we don't need it twice
                 break
diff --git a/skorch/tests/llm/test_llm_classifier.py b/skorch/tests/llm/test_llm_classifier.py
index 701a36b82..8ced76134 100644
--- a/skorch/tests/llm/test_llm_classifier.py
+++ b/skorch/tests/llm/test_llm_classifier.py
@@ -264,6 +264,23 @@ def test_caching_is_faster(self, classifier_cls):
         # at least 1/3 faster
         assert cached_time < 0.1 * uncached_time
 
+    def test_caching_works_shared_label_prefix_without_eos(self, classifier_cls):
+        clf = classifier_cls('gpt2')
+
+        # carefully chosen class labels so that one label has the other label
+        # as its prefix: '11111' = '11' + '111'. For models that tokenize
+        # single digits independently, this case is especially relevant.
+        X = np.array(["Hey there", "No thank you"])
+        y = ['11', '11111']
+
+        clf.fit(X, y)
+
+        y_pred_1 = clf.predict(X)
+        y_pred_2 = clf.predict(X)
+
+        # does not raise and gives the same results
+        np.testing.assert_array_equal(y_pred_1, y_pred_2)
+
     def test_custom_prompt(self, model, tokenizer, classifier_cls, X):
         prompt = "Please classify my text:\n{text}\n\nLabels: {labels}\n\n"
         clf = classifier_cls(
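
For reviewers, a minimal standalone sketch of the failure mode this patch fixes (not part of the patch; the token ids below are hypothetical stand-ins, only the eos id is GPT-2's real one). When one label's token ids are a strict prefix of another's, the caching loop can consume the whole shared prefix, leaving an empty label_id on which the old `label_id[0]` raised IndexError:

    eos_token_id = 50256          # GPT-2's actual eos token id

    label_short = [11]            # hypothetical token ids for '11'
    label_long = [11, 111]        # hypothetical ids for '11' + '111'

    # The caching loop walks the cached prefix and strips it from label_id.
    # When one label is a token-level prefix of another, the cache can cover
    # the entire remaining label, leaving an empty list:
    label_id = label_long[len(label_long):]   # -> []

    # Before the fix, label_id[0] raised IndexError on this empty list; the
    # added `not label_id` check now breaks out of the loop first.
    if not label_id or label_id[0] == eos_token_id:
        print("break: nothing left to extend the cache with")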