Skip to content

Commit

Permalink
Fixed GPU training + added extra stat print for train set.
Browse files Browse the repository at this point in the history
  • Loading branch information
vladd-bit committed Feb 7, 2024
1 parent 89d9128 commit e6e99cb
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 17 deletions.
12 changes: 8 additions & 4 deletions medcat/rel_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,12 @@ def train(self, export_data_path="", train_csv_path="", test_csv_path="", checkp
train_dataset_size = len(train_rel_data)
batch_size = train_dataset_size if train_dataset_size < self.batch_size else self.batch_size
train_dataloader = DataLoader(train_rel_data, batch_size=batch_size, shuffle=self.config.train.shuffle_data,
num_workers=0, collate_fn=self.padding_seq, pin_memory=self.config.general.pin_memory)
num_workers=0, collate_fn=self.padding_seq, pin_memory=self.config.general.pin_memory, pin_memory_device=self.device)

test_dataset_size = len(test_rel_data)
test_batch_size = test_dataset_size if test_dataset_size < self.batch_size else self.batch_size
test_dataloader = DataLoader(test_rel_data, batch_size=test_batch_size, shuffle=self.config.train.shuffle_data,
num_workers=0, collate_fn=self.padding_seq, pin_memory=self.config.general.pin_memory)
num_workers=0, collate_fn=self.padding_seq, pin_memory=self.config.general.pin_memory, pin_memory_device=self.device)

criterion = nn.CrossEntropyLoss(ignore_index=-1)

Expand Down Expand Up @@ -320,6 +320,10 @@ def train(self, export_data_path="", train_csv_path="", test_csv_path="", checkp

end_time = datetime.now().time()

print("======================== TRAIN SET TEST RESULTS ========================")
train_results = self.evaluate_results(train_dataloader, self.pad_id)

print("======================== TEST SET TEST RESULTS ========================")
results = self.evaluate_results(test_dataloader, self.pad_id)

f1_per_epoch.append(results['f1'])
Expand Down Expand Up @@ -419,8 +423,8 @@ def evaluate_results(self, data_loader, pad_id):
for i, data in enumerate(data_loader):
with torch.no_grad():
token_ids, e1_e2_start, labels, _, _, _ = data
attention_mask = (token_ids != pad_id).float()
token_type_ids = torch.zeros((token_ids.shape[0], token_ids.shape[1])).long()
attention_mask = (token_ids != pad_id).float().to(self.device)
token_type_ids = torch.zeros((token_ids.shape[0], token_ids.shape[1])).long().to(self.device)

labels = labels.to(self.device)

Expand Down
22 changes: 10 additions & 12 deletions medcat/utils/relation_extraction/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def output2logits(self, pooled_output, sequence_output, input_ids, e1_e2_start):
new_pooled_output = torch.cat((pooled_output, *seq_tags), dim=1)
new_pooled_output = torch.squeeze(new_pooled_output, dim=1)
else:
e1e2_output =[]
e1e2_output = []
temp_e1 = []
temp_e2 = []

Expand All @@ -96,7 +96,7 @@ def output2logits(self, pooled_output, sequence_output, input_ids, e1_e2_start):

classification_logits = self.classification_layer(self.drop_out(new_pooled_output))

return classification_logits
return classification_logits.to(self.relcat_config.general.device)

def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None,
Expand All @@ -106,20 +106,18 @@ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, posi
else:
raise ValueError("You have to specify input_ids")

device = input_ids.device

if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
attention_mask = torch.ones(input_shape, device=self.relcat_config.general.device)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(input_shape, device=device)
encoder_attention_mask = torch.ones(input_shape, device=self.relcat_config.general.device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.relcat_config.general.device)

input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
encoder_attention_mask = encoder_attention_mask.to(device)
input_ids = input_ids.to(self.relcat_config.general.device)
attention_mask = attention_mask.to(self.relcat_config.general.device)
encoder_attention_mask = encoder_attention_mask.to(self.relcat_config.general.device)

self.bert_model = self.bert_model.to(device)
self.bert_model = self.bert_model.to(self.relcat_config.general.device)

model_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states,
Expand All @@ -131,4 +129,4 @@ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, posi

classification_logits = self.output2logits(pooled_output, sequence_output, input_ids, e1_e2_start)

return model_output, classification_logits.to(device)
return model_output, classification_logits.to(self.relcat_config.general.device)
2 changes: 1 addition & 1 deletion medcat/utils/relation_extraction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def save_state(model, optimizer, scheduler, epoch, best_f1, path="./", model_nam
def load_state(model, optimizer, scheduler, path="./", model_name="BERT", file_prefix="train", load_best=False, device=torch.device("cpu"), config: ConfigRelCAT = ConfigRelCAT()):

model_name = model_name.replace("/", "_")
print("Attempting to load RelCAT model on device: ", device.type)
print("Attempting to load RelCAT model on device: ", device)
checkpoint_path = os.path.join(path, file_prefix + "_checkpoint_%s.dat" % model_name)
best_path = os.path.join(path, file_prefix + "_model_best_%s.dat" % model_name)
start_epoch, best_f1, checkpoint = 0, 0, None
Expand Down

0 comments on commit e6e99cb

Please sign in to comment.