Skip to content

Commit

Permalink
Added random rotation training feature in train.sh
Browse files (browse the repository at this point in the history)
  • Loading branch information
Theodore Zhao committed Dec 12, 2024
1 parent 400d65e commit e6036b3
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
1 change: 1 addition & 0 deletions assets/scripts/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ CUDA_VISIBLE_DEVICES=0 mpirun -n 1 python entry.py train \
MODEL.DECODER.SPATIAL.ENABLED True \
MODEL.DECODER.GROUNDING.ENABLED True \
LOADER.SAMPLE_PROB prop \
BioMed.INPUT.RANDOM_ROTATE True \
FIND_UNUSED_PARAMETERS True \
ATTENTION_ARCH.SPATIAL_MEMORIES 32 \
MODEL.DECODER.SPATIAL.MAX_ITER 0 \
Expand Down
1 change: 1 addition & 0 deletions configs/biomed_seg_lang_v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ BioMed:
COLOR_AUG_SSD: False
SIZE_DIVISIBILITY: 32
RANDOM_FLIP: "none"
RANDOM_ROTATE: False
MASK_FORMAT: "polygon"
MIN_AREA: 30
FORMAT: "RGB"
Expand Down
15 changes: 14 additions & 1 deletion datasets/dataset_mappers/biomed_dataset_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(
max_token_num,
tokenizer,
binary_classes: bool,
rotate: bool,
):
"""
NOTE: this interface is experimental.
Expand Down Expand Up @@ -153,6 +154,7 @@ def __init__(
self.max_token_num = max_token_num

self.binary_classes = binary_classes
self.rotate = rotate

@classmethod
def from_config(cls, cfg, is_train=True):
Expand Down Expand Up @@ -188,7 +190,8 @@ def from_config(cls, cfg, is_train=True):
"retrieval": retrieval,
"max_token_num": max_token_num,
"tokenizer": tokenizer,
"binary_classes": cfg['MODEL']['ENCODER']['BINARY_CLASSES']
"binary_classes": cfg['MODEL']['ENCODER']['BINARY_CLASSES'],
"rotate": cfg['INPUT']['RANDOM_ROTATE'],
}
return ret

Expand All @@ -213,6 +216,12 @@ def __call__(self, dataset_dict):
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
image_shape = image.shape[:2] # h, w

rotate_time = 0
if self.is_train and self.rotate and random.random() < 0.5:
rotate_time = random.randint(1, 3)
if rotate_time > 0:
image = np.rot90(image, rotate_time)

# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
# Therefore it's important to use torch.Tensor.
Expand Down Expand Up @@ -252,6 +261,8 @@ def __call__(self, dataset_dict):
m = 1 * (m > 0)
m = m.astype(np.uint8) # convert to np.uint8
m = transforms.apply_segmentation(255*m[:,:,None])[:,:,0]
if rotate_time > 0:
m = np.rot90(m, rotate_time)
masks_grd += [m]
rand_id = random.randint(0, len(ann['sentences'])-1)
texts_grd.append(ann['sentences'][rand_id]['raw'].lower())
Expand Down Expand Up @@ -320,6 +331,8 @@ def __call__(self, dataset_dict):

m = m.astype(np.uint8) # convert to np.uint8
m = transforms.apply_segmentation(m[:,:,None])[:,:,0]
if rotate_time > 0:
m = np.rot90(m, rotate_time)
masks_grd += [m]
# random select a sentence of a single annotation.
rand_index = random.randint(0, len(ann['sentences'])-1)
Expand Down

0 comments on commit e6036b3

Please sign in to comment.