fixed the hf_cehrbert_pretrain_runner integration test
ChaoPang committed Sep 6, 2024
1 parent 3622db4 commit 5f55329
Showing 3 changed files with 45 additions and 42 deletions.
@@ -12,12 +12,12 @@

class CehrBertDataCollator:
def __init__(
-self,
-tokenizer: CehrBertTokenizer,
-max_length: int,
-mlm_probability: float = 0.15,
-is_pretraining: bool = True,
-truncate_type: TruncationType = TruncationType.RANDOM_RIGHT_TRUNCATION,
+self,
+tokenizer: CehrBertTokenizer,
+max_length: int,
+mlm_probability: float = 0.15,
+is_pretraining: bool = True,
+truncate_type: TruncationType = TruncationType.RANDOM_RIGHT_TRUNCATION,
):
self.tokenizer = tokenizer
self.max_length = max_length
@@ -29,12 +29,12 @@ def __init__(
# Pre-compute these so we can use them later on
# We used VS for the historical data; currently we use the new [VS] token for newer data,
# so we need to check both cases.
-self.vs_token_id = tokenizer._convert_token_to_id("VS")
-if self.vs_token_id == tokenizer._oov_token_index:
-self.vs_token_id = tokenizer._convert_token_to_id("[VS]")
-self.ve_token_id = tokenizer._convert_token_to_id("VE")
-if self.ve_token_id == tokenizer._oov_token_index:
-self.ve_token_id = tokenizer._convert_token_to_id("[VE]")
+self.vs_token_id = tokenizer.convert_token_to_id("VS")
+if self.vs_token_id == tokenizer.oov_token_index:
+self.vs_token_id = tokenizer.convert_token_to_id("[VS]")
+self.ve_token_id = tokenizer.convert_token_to_id("VE")
+if self.ve_token_id == tokenizer.oov_token_index:
+self.ve_token_id = tokenizer.convert_token_to_id("[VE]")

@staticmethod
def _convert_to_tensor(features: Any) -> torch.Tensor:
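The hunk above switches the collator to the tokenizer's public accessors, convert_token_to_id and oov_token_index. A minimal sketch of the lookup pattern, assuming only those two accessors; the SimpleTokenizer class below is a stand-in for illustration, not part of the repository:

from typing import Dict


class SimpleTokenizer:
    """Stand-in tokenizer exposing the public accessors the collator now calls."""

    def __init__(self, vocab: Dict[str, int], oov_token: str = "[UNK]"):
        self._vocab = vocab
        self.oov_token_index = vocab[oov_token]

    def convert_token_to_id(self, token: str) -> int:
        # Unknown tokens map to the OOV index, which is what the fallback checks for.
        return self._vocab.get(token, self.oov_token_index)


def resolve_visit_token_id(tokenizer: SimpleTokenizer, legacy: str, bracketed: str) -> int:
    # Historical data used "VS"/"VE"; newer data uses "[VS]"/"[VE]",
    # so try the legacy spelling first and fall back to the bracketed one.
    token_id = tokenizer.convert_token_to_id(legacy)
    if token_id == tokenizer.oov_token_index:
        token_id = tokenizer.convert_token_to_id(bracketed)
    return token_id


tok = SimpleTokenizer({"[UNK]": 0, "[VS]": 1, "[VE]": 2})
print(resolve_visit_token_id(tok, "VS", "[VS]"))  # 1: "VS" is OOV, "[VS]" resolves
print(resolve_visit_token_id(tok, "VE", "[VE]"))  # 2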
@@ -205,9 +205,9 @@ def torch_mask_tokens(self, inputs: torch.Tensor, labels: torch.Tensor) -> Tuple

# 10% of the time, we replace masked input tokens with random word
indices_random = (
-torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
-& masked_indices
-& ~indices_replaced
+torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
+& masked_indices
+& ~indices_replaced
)
random_words = torch.randint(self.tokenizer.vocab_size, labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
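For context, the lines this hunk reindents follow the standard BERT masking recipe: of the positions selected for prediction, 80% become the mask token, 10% become a random token, and 10% stay unchanged. A self-contained sketch of that recipe, with illustrative token IDs and shapes rather than the collator's actual values:

import torch


def mask_tokens(inputs: torch.Tensor, mask_token_id: int, vocab_size: int,
                mlm_probability: float = 0.15):
    """Standard BERT-style MLM masking: 80% [MASK], 10% random token, 10% unchanged."""
    labels = inputs.clone()
    # Choose the positions the model must predict.
    masked_indices = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked_indices] = -100  # loss is only computed on masked positions

    # 80% of the chosen positions become the mask token.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = mask_token_id

    # 10% of the chosen positions (half of the remaining 20%) become a random token.
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The remaining 10% keep their original token.
    return inputs, labels


example = torch.randint(5, 100, (1, 16), dtype=torch.long)
masked, labels = mask_tokens(example.clone(), mask_token_id=4, vocab_size=100)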
@@ -233,8 +233,8 @@ def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]:
start_index = random.randint(0, seq_length - new_max_length)
end_index = min(seq_length, start_index + new_max_length)
elif self.truncate_type in (
-TruncationType.RANDOM_RIGHT_TRUNCATION,
-TruncationType.RANDOM_COMPLETE,
+TruncationType.RANDOM_RIGHT_TRUNCATION,
+TruncationType.RANDOM_COMPLETE,
):
# We randomly pick a [VS] token
starting_points = []
@@ -266,9 +266,9 @@ def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]:
new_record = collections.OrderedDict()
for k, v in record.items():
if (
-isinstance(v, list)
-or isinstance(v, np.ndarray)
-or (isinstance(v, torch.Tensor) and v.dim() > 0)
+isinstance(v, list)
+or isinstance(v, np.ndarray)
+or (isinstance(v, torch.Tensor) and v.dim() > 0)
):
if len(v) == seq_length:
new_record[k] = v[start_index:end_index]
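The generate_start_end_index hunks above only reflow the truncation branches; the underlying idea is to pick a window of at most new_max_length positions and slice every sequence-aligned field to that window. A rough sketch of the windowing, assuming plain Python lists and ignoring the [VS]-anchored variants shown in the diff:

import collections
import random
from typing import Any, Dict


def random_truncate(record: Dict[str, Any], seq_length: int, new_max_length: int) -> Dict[str, Any]:
    # When the sequence fits, keep everything; otherwise pick a random window.
    if seq_length <= new_max_length:
        start_index, end_index = 0, seq_length
    else:
        start_index = random.randint(0, seq_length - new_max_length)
        end_index = min(seq_length, start_index + new_max_length)

    new_record = collections.OrderedDict()
    for k, v in record.items():
        # Slice only the fields aligned with the token sequence;
        # scalars and other fields are copied through unchanged.
        if isinstance(v, list) and len(v) == seq_length:
            new_record[k] = v[start_index:end_index]
        else:
            new_record[k] = v
    return new_record


record = {"input_ids": list(range(10)), "ages": list(range(10)), "person_id": 42}
print(random_truncate(record, seq_length=10, new_max_length=4))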
42 changes: 22 additions & 20 deletions src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
@@ -31,7 +31,7 @@
LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"


-def load_json_file(json_file):
+def load_json_file(json_file) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
"""
Loads a JSON file and returns the parsed JSON object.
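The only change to load_json_file is the added return annotation. For reference, a minimal sketch of what such a loader typically looks like; the body below is assumed, not copied from the repository:

import json
from typing import Any, Dict, List, Union


def load_json_file(json_file: str) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
    """Load a JSON file and return the parsed object (a list of dicts or a dict)."""
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # Surface a single error type to callers, matching the raise-on-failure
        # style this commit adopts elsewhere.
        raise RuntimeError(f"Can't load the json file at {json_file}") from e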
@@ -56,10 +56,10 @@ def load_json_file(json_file):
class CehrBertTokenizer(PushToHubMixin):

def __init__(
-self,
-tokenizer: Tokenizer,
-lab_stats: List[Dict[str, Any]],
-concept_name_mapping: Dict[str, str],
+self,
+tokenizer: Tokenizer,
+lab_stats: List[Dict[str, Any]],
+concept_name_mapping: Dict[str, str],
):
self._tokenizer = tokenizer
self._lab_stats = lab_stats
@@ -139,10 +139,10 @@ def convert_tokens_to_string(self, tokens):
return out_string

def save_pretrained(
-self,
-save_directory: Union[str, os.PathLike],
-push_to_hub: bool = False,
-**kwargs,
+self,
+save_directory: Union[str, os.PathLike],
+push_to_hub: bool = False,
+**kwargs,
):
"""
Save the Cehrbert tokenizer.
@@ -190,9 +190,9 @@ def save_pretrained(

@classmethod
def from_pretrained(
-cls,
-pretrained_model_name_or_path: Union[str, os.PathLike],
-**kwargs,
+cls,
+pretrained_model_name_or_path: Union[str, os.PathLike],
+**kwargs,
):
"""
Load the CehrBert tokenizer.
@@ -216,21 +216,23 @@ def from_pretrained(
)

if not tokenizer_file:
-return None
+raise RuntimeError(f"tokenizer_file does not exist: {tokenizer_file}")

tokenizer = Tokenizer.from_file(tokenizer_file)

lab_stats_file = transformers.utils.hub.cached_file(
pretrained_model_name_or_path, LAB_STATS_FILE_NAME, **kwargs
)
if not lab_stats_file:
-return None
+raise RuntimeError(f"lab_stats_file does not exist: {lab_stats_file}")

concept_name_mapping_file = transformers.utils.hub.cached_file(
pretrained_model_name_or_path, CONCEPT_MAPPING_FILE_NAME, **kwargs
)
if not concept_name_mapping_file:
-return None
+raise RuntimeError(
+f"concept_name_mapping_file does not exist: {concept_name_mapping_file}"
+)

lab_stats = load_json_file(lab_stats_file)
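With this hunk, from_pretrained raises a RuntimeError when any of the three cached files cannot be resolved, instead of silently returning None, so callers can catch the error and fall back to training a new tokenizer. A hedged sketch of the caller-side pattern this enables, mirroring but not copying what the pretrain runner does; the function name and arguments here are illustrative:

import logging

from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer

LOG = logging.getLogger(__name__)


def load_or_train_tokenizer(tokenizer_path, dataset, feature_names, concept_name_mapping, data_args):
    # from_pretrained now raises instead of returning None when files are missing,
    # so a try/except is enough to decide whether a new tokenizer must be trained.
    try:
        return CehrBertTokenizer.from_pretrained(tokenizer_path)
    except (OSError, RuntimeError, FileNotFoundError) as e:
        LOG.warning("Failed to load the tokenizer from %s (%s); training a new one.", tokenizer_path, e)
        tokenizer = CehrBertTokenizer.train_tokenizer(dataset, feature_names, concept_name_mapping, data_args)
        tokenizer.save_pretrained(tokenizer_path)
        return tokenizer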

@@ -240,11 +242,11 @@

@classmethod
def train_tokenizer(
-cls,
-dataset: Union[Dataset, DatasetDict],
-feature_names: List[str],
-concept_name_mapping: Dict[str, str],
-data_args: DataTrainingArguments,
+cls,
+dataset: Union[Dataset, DatasetDict],
+feature_names: List[str],
+concept_name_mapping: Dict[str, str],
+data_args: DataTrainingArguments,
):
"""
Train a huggingface word level tokenizer.
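train_tokenizer itself is unchanged here apart from the signature reflow. As background, a word-level Hugging Face tokenizer of the kind the docstring describes can be trained with the tokenizers library roughly as follows; the special tokens and corpus are illustrative and this is not the project's actual training routine:

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.trainers import WordLevelTrainer


def train_word_level_tokenizer(corpus, special_tokens):
    # corpus: an iterable of space-delimited concept sequences.
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = WhitespaceSplit()
    trainer = WordLevelTrainer(special_tokens=special_tokens)
    tokenizer.train_from_iterator(corpus, trainer=trainer)
    return tokenizer


corpus = ["[VS] 320128 9202 [VE]", "[VS] 320128 [VE]"]
tokenizer = train_word_level_tokenizer(
    corpus, special_tokens=["[PAD]", "[UNK]", "[MASK]", "[VS]", "[VE]"]
)
print(tokenizer.encode("[VS] 320128 [VE]").ids)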
5 changes: 3 additions & 2 deletions src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
@@ -1,3 +1,4 @@
+import json
import os
from typing import Optional, Union

@@ -62,7 +63,7 @@ def load_and_create_tokenizer(
tokenizer_abspath = os.path.abspath(model_args.tokenizer_name_or_path)
try:
tokenizer = CehrBertTokenizer.from_pretrained(tokenizer_abspath)
-except RuntimeError as e:
+except (OSError, RuntimeError, FileNotFoundError, json.JSONDecodeError) as e:
LOG.warning(
"Failed to load the tokenizer from %s with the error "
"\n%s\nTried to create the tokenizer, however the dataset is not provided.",
@@ -104,7 +105,7 @@ def load_and_create_model(
try:
model_abspath = os.path.abspath(model_args.model_name_or_path)
model_config = AutoConfig.from_pretrained(model_abspath)
-except RuntimeError as e:
+except (OSError, ValueError, FileNotFoundError, json.JSONDecodeError) as e:
LOG.warning(e)
model_config = CehrBertConfig(
vocab_size=tokenizer.vocab_size,
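The runner now catches the exception types these loaders actually raise (missing files, malformed JSON, invalid or mismatched configs) rather than only RuntimeError, and falls back to building a fresh config. A rough sketch of that pattern; the CehrBertConfig import path and the single keyword argument shown are assumptions based on this diff, not a copy of the runner:

import json
import logging
import os

from transformers import AutoConfig

# Assumed import path for the project's config class.
from cehrbert.models.hf_models.config import CehrBertConfig

LOG = logging.getLogger(__name__)


def load_or_create_config(model_name_or_path: str, tokenizer):
    try:
        model_abspath = os.path.abspath(model_name_or_path)
        # Raises OSError / ValueError / json.JSONDecodeError when the path has no
        # usable config.json (missing file, malformed JSON, unknown model type).
        return AutoConfig.from_pretrained(model_abspath)
    except (OSError, ValueError, FileNotFoundError, json.JSONDecodeError) as e:
        LOG.warning(e)
        # Fall back to a config sized to the tokenizer's vocabulary.
        return CehrBertConfig(vocab_size=tokenizer.vocab_size)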
