NOTE(review): two points to confirm before applying this patch:
  1) In __call__, image_shape is taken from image.shape[:2] *before* the new
     np.rot90 augmentation runs; for rotate_time in {1, 3} height and width swap,
     so any downstream use of image_shape may be stale — verify.
  2) np.rot90 returns a non-contiguous view; confirm the later tensor conversion
     copies the data (e.g. np.ascontiguousarray / np.stack), or add .copy().

diff --git a/assets/scripts/train.sh b/assets/scripts/train.sh
index ea74d99..4d00770 100755
--- a/assets/scripts/train.sh
+++ b/assets/scripts/train.sh
@@ -31,6 +31,7 @@ CUDA_VISIBLE_DEVICES=0 mpirun -n 1 python entry.py train \
         MODEL.DECODER.SPATIAL.ENABLED True \
         MODEL.DECODER.GROUNDING.ENABLED True \
         LOADER.SAMPLE_PROB prop \
+        BioMed.INPUT.RANDOM_ROTATE True \
         FIND_UNUSED_PARAMETERS True \
         ATTENTION_ARCH.SPATIAL_MEMORIES 32 \
         MODEL.DECODER.SPATIAL.MAX_ITER 0 \
diff --git a/configs/biomed_seg_lang_v1.yaml b/configs/biomed_seg_lang_v1.yaml
index 5d1fb78..f790fb4 100755
--- a/configs/biomed_seg_lang_v1.yaml
+++ b/configs/biomed_seg_lang_v1.yaml
@@ -289,6 +289,7 @@ BioMed:
     COLOR_AUG_SSD: False
     SIZE_DIVISIBILITY: 32
     RANDOM_FLIP: "none"
+    RANDOM_ROTATE: False
     MASK_FORMAT: "polygon"
     MIN_AREA: 30
     FORMAT: "RGB"
diff --git a/datasets/dataset_mappers/biomed_dataset_mapper.py b/datasets/dataset_mappers/biomed_dataset_mapper.py
index 36b4d44..9f101cc 100755
--- a/datasets/dataset_mappers/biomed_dataset_mapper.py
+++ b/datasets/dataset_mappers/biomed_dataset_mapper.py
@@ -121,6 +121,7 @@ def __init__(
         max_token_num,
         tokenizer,
         binary_classes: bool,
+        rotate: bool,
     ):
         """
         NOTE: this interface is experimental.
@@ -153,6 +154,7 @@ def __init__(
 
         self.max_token_num = max_token_num
         self.binary_classes = binary_classes
+        self.rotate = rotate
 
     @classmethod
     def from_config(cls, cfg, is_train=True):
@@ -188,7 +190,8 @@ def from_config(cls, cfg, is_train=True):
             "retrieval": retrieval,
             "max_token_num": max_token_num,
             "tokenizer": tokenizer,
-            "binary_classes": cfg['MODEL']['ENCODER']['BINARY_CLASSES']
+            "binary_classes": cfg['MODEL']['ENCODER']['BINARY_CLASSES'],
+            "rotate": cfg['INPUT']['RANDOM_ROTATE'],
         }
         return ret
 
@@ -213,6 +216,12 @@
         image, transforms = T.apply_transform_gens(self.tfm_gens, image)
         image_shape = image.shape[:2]  # h, w
 
+        rotate_time = 0
+        if self.is_train and self.rotate and random.random() < 0.5:
+            rotate_time = random.randint(1, 3)
+        if rotate_time > 0:
+            image = np.rot90(image, rotate_time)
+
         # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
         # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
         # Therefore it's important to use torch.Tensor.
@@ -252,6 +261,8 @@
                 m = 1 * (m > 0)
                 m = m.astype(np.uint8)  # convert to np.uint8
                 m = transforms.apply_segmentation(255*m[:,:,None])[:,:,0]
+                if rotate_time > 0:
+                    m = np.rot90(m, rotate_time)
                 masks_grd += [m]
                 rand_id = random.randint(0, len(ann['sentences'])-1)
                 texts_grd.append(ann['sentences'][rand_id]['raw'].lower())
@@ -320,6 +331,8 @@
 
                 m = m.astype(np.uint8)  # convert to np.uint8
                 m = transforms.apply_segmentation(m[:,:,None])[:,:,0]
+                if rotate_time > 0:
+                    m = np.rot90(m, rotate_time)
                 masks_grd += [m]
                 # random select a sentence of a single annotation.
                 rand_index = random.randint(0, len(ann['sentences'])-1)