changed the private methods (convert_token_to_id, convert_id_to_token) to public for CehrBertTokenizer

ChaoPang committed Sep 6, 2024
1 parent 68ffb58 commit 9aba6b2
Showing 2 changed files with 37 additions and 25 deletions.
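
Taken together, the two files rename the lookup helpers so that they are part of the tokenizer's public surface. Below is a minimal sketch of the resulting call pattern, not taken from the repository: the directory path and the concept token are placeholders, and it assumes a CehrBertTokenizer has already been trained and saved.

from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer

# Load a previously saved tokenizer (placeholder path).
tokenizer = CehrBertTokenizer.from_pretrained("path/to/saved/tokenizer")

# Public after this commit: unseen tokens fall back to the OOV index,
# and unknown ids decode to the out-of-vocabulary token.
token_id = tokenizer.convert_token_to_id("some_concept_id")
token = tokenizer.convert_id_to_token(token_id)
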
src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py (60 changes: 36 additions & 24 deletions)
@@ -14,7 +14,7 @@
from transformers.tokenization_utils_base import PushToHubMixin

from cehrbert.models.hf_models.tokenization_utils import (
- _agg_helper,
+ agg_helper,
agg_statistics,
map_statistics,
)
@@ -32,22 +32,34 @@


def load_json_file(json_file):
"""
Loads a JSON file and returns the parsed JSON object.
Args:
json_file (str): The path to the JSON file.
Returns:
dict: The parsed JSON object.
Raises:
RuntimeError: If the JSON file cannot be read.
"""
try:
with open(json_file, "r", encoding="utf-8") as reader:
file_contents = reader.read()
parsed_json = json.loads(file_contents)
return parsed_json
- except Exception as e:
- raise RuntimeError(f"Can't load the json file at {json_file} due to {e}")
+ except RuntimeError as e:
+ raise RuntimeError(f"Can't load the json file at {json_file}") from e


class CehrBertTokenizer(PushToHubMixin):

def __init__(
- self,
- tokenizer: Tokenizer,
- lab_stats: List[Dict[str, Any]],
- concept_name_mapping: Dict[str, str],
+ self,
+ tokenizer: Tokenizer,
+ lab_stats: List[Dict[str, Any]],
+ concept_name_mapping: Dict[str, str],
):
self._tokenizer = tokenizer
self._lab_stats = lab_stats
@@ -107,12 +119,12 @@ def encode(self, concept_ids: Sequence[str]) -> Sequence[int]:
def decode(self, concept_token_ids: List[int]) -> List[str]:
return self._tokenizer.decode(concept_token_ids).split(" ")

- def _convert_token_to_id(self, token):
+ def convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
token_id = self._tokenizer.token_to_id(token)
return token_id if token_id else self._oov_token_index

- def _convert_id_to_token(self, index):
+ def convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = self._tokenizer.id_to_token(index)
return token if token else OUT_OF_VOCABULARY_TOKEN
@@ -123,10 +135,10 @@ def convert_tokens_to_string(self, tokens):
return out_string

def save_pretrained(
- self,
- save_directory: Union[str, os.PathLike],
- push_to_hub: bool = False,
- **kwargs,
+ self,
+ save_directory: Union[str, os.PathLike],
+ push_to_hub: bool = False,
+ **kwargs,
):
"""
Save the Cehrbert tokenizer.
@@ -137,7 +149,7 @@ def save_pretrained(
Args:
save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved.
push_to_hub (`bool`, *optional*, defaults to `False`):
- Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ Whether to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
@@ -174,9 +186,9 @@ def save_pretrained(

@classmethod
def from_pretrained(
- cls,
- pretrained_model_name_or_path: Union[str, os.PathLike],
- **kwargs,
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ **kwargs,
):
"""
Load the CehrBert tokenizer.
@@ -224,11 +236,11 @@ def from_pretrained(

@classmethod
def train_tokenizer(
- cls,
- dataset: Union[Dataset, DatasetDict],
- feature_names: List[str],
- concept_name_mapping: Dict[str, str],
- data_args: DataTrainingArguments,
+ cls,
+ dataset: Union[Dataset, DatasetDict],
+ feature_names: List[str],
+ concept_name_mapping: Dict[str, str],
+ data_args: DataTrainingArguments,
):
"""
Train a huggingface word level tokenizer.
@@ -292,14 +304,14 @@ def batched_generator():

if data_args.streaming:
parts = dataset.map(
- partial(_agg_helper, map_func=map_statistics),
+ partial(agg_helper, map_func=map_statistics),
batched=True,
batch_size=data_args.preprocessing_batch_size,
remove_columns=dataset.column_names,
)
else:
parts = dataset.map(
- partial(_agg_helper, map_func=map_statistics),
+ partial(agg_helper, map_func=map_statistics),
batched=True,
batch_size=data_args.preprocessing_batch_size,
remove_columns=dataset.column_names,
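
The rename from _agg_helper to agg_helper is mirrored at these call sites in train_tokenizer, where per-batch statistics are collected with datasets.map before being merged. Below is a rough, self-contained sketch of that batched-map pattern using stand-in functions; the real agg_helper and map_statistics come from cehrbert.models.hf_models.tokenization_utils, and their exact signatures and outputs are not shown in this diff.

from functools import partial
import pickle

from datasets import Dataset

def toy_map_statistics(batch):
    # Stand-in for map_statistics: compute some per-batch statistic (here, concept counts).
    counts = {}
    for concept_ids in batch["concept_ids"]:
        for concept in concept_ids:
            counts[concept] = counts.get(concept, 0) + 1
    return counts

def toy_agg_helper(batch, map_func):
    # Stand-in for agg_helper: serialize the per-batch result so dataset.map can store it
    # as a single new column while the original columns are removed.
    return {"data": [pickle.dumps(map_func(batch))]}

dataset = Dataset.from_dict({"concept_ids": [["c1", "c2"], ["c2", "c3"]]})
parts = dataset.map(
    partial(toy_agg_helper, map_func=toy_map_statistics),
    batched=True,
    batch_size=1000,
    remove_columns=dataset.column_names,
)
per_batch_stats = [pickle.loads(row["data"]) for row in parts]
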
@@ -62,7 +62,7 @@ def test_oov_token(self):

def test_convert_id_to_token_oov(self):
# Test decoding an out-of-vocabulary token ID
- decoded = self.tokenizer._convert_id_to_token(99)  # Assuming 99 is not in the index
+ decoded = self.tokenizer.convert_id_to_token(99)  # Assuming 99 is not in the index
self.assertEqual(decoded, OUT_OF_VOCABULARY_TOKEN)
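
Continuing the earlier sketch, the behaviour this test exercises can be checked directly through the now-public methods. This assumes the tokenizer loaded in that sketch, that OUT_OF_VOCABULARY_TOKEN is importable from the same module (the code above references it), and that the id and string used here are genuinely absent from the vocabulary.

from cehrbert.models.hf_models.tokenization_hf_cehrbert import OUT_OF_VOCABULARY_TOKEN

# Unknown ids decode to the OOV token; unseen strings encode to the OOV index.
assert tokenizer.convert_id_to_token(99) == OUT_OF_VOCABULARY_TOKEN
oov_index = tokenizer.convert_token_to_id("definitely-not-a-real-concept")
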


