Skip to content

Commit

Permalink
Adjust hidden state of training
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabio committed Nov 20, 2024
1 parent 26e6655 commit 9095549
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 14 deletions.
11 changes: 11 additions & 0 deletions config/A8_swin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,14 @@ _BASE_: A8.yaml
MODEL:
CLASS: OmniTransformer_v4
OMNIGRU_CHECKPOINT_URL: /home/user/models/A8_swin.pt

DATASET:
CLASS: 'Milly_multifeature_v6'

TRAIN:
OPT: "adamw" #adam adamw sgd rmsprop
LR: 0.0001
LINEAR_WARMUP_EPOCHS: 3 #default None
WEIGHT_DECAY: 0.05
SCHEDULER: 'cos'
EPOCHS: 30
5 changes: 5 additions & 0 deletions config/M2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ DATASET:
OUTPUT:
LOCATION: "/home/user/output"

TRAIN:
OPT: "adamw" #adam adamw sgd rmsprop
LR: 0.001
WEIGHT_DECAY: 0.0001

SKILLS:
- NAME: M2 - Apply Tourniquet
STEPS:
Expand Down
7 changes: 6 additions & 1 deletion config/M3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ DATASET:

OUTPUT:
LOCATION: "/home/user/output"


TRAIN:
OPT: "adamw" #adam adamw sgd rmsprop
LR: 0.001
WEIGHT_DECAY: 0.001

SKILLS:
- NAME: M3 - Apply pressure dressing
STEPS:
Expand Down
11 changes: 11 additions & 0 deletions config/M4_swin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,14 @@ _BASE_: M4.yaml
MODEL:
CLASS: OmniTransformer_v3
OMNIGRU_CHECKPOINT_URL: /home/user/models/M4_swin.pt

DATASET:
CLASS: 'Milly_multifeature_v6'

TRAIN:
OPT: "adamw" #adam adamw sgd rmsprop
LR: 0.0001
LINEAR_WARMUP_EPOCHS: 3 #default None
WEIGHT_DECAY: 0.05
SCHEDULER: 'cos'
EPOCHS: 30
5 changes: 5 additions & 0 deletions config/M5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ DATASET:
OUTPUT:
LOCATION: "/home/user/output"

TRAIN:
OPT: "adamw" #adam adamw sgd rmsprop
LR: 0.001
WEIGHT_DECAY: 0.00001

SKILLS:
- NAME: M5 - X-Stat
STEPS:
Expand Down
11 changes: 11 additions & 0 deletions config/R16_swin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,14 @@ _BASE_: R16.yaml
MODEL:
CLASS: OmniTransformer_v3
OMNIGRU_CHECKPOINT_URL: /home/user/models/R16_swin.pt

DATASET:
CLASS: 'Milly_multifeature_v6'

TRAIN:
OPT: "adamw" #adam adamw sgd rmsprop
LR: 0.0001
LINEAR_WARMUP_EPOCHS: 3 #default None
WEIGHT_DECAY: 0.05
SCHEDULER: 'cos'
EPOCHS: 30
5 changes: 5 additions & 0 deletions config/R18.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ DATASET:
OUTPUT:
LOCATION: "/home/user/output"

TRAIN:
OPT: "adamw" #adam adamw sgd rmsprop
LR: 0.001
WEIGHT_DECAY: 0.01

SKILLS:
- NAME: R18 - Apply chest seal
STEPS:
Expand Down
10 changes: 10 additions & 0 deletions config/R19_swin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,13 @@ MODEL:
CLASS: OmniTransformer_v3
OMNIGRU_CHECKPOINT_URL: /home/user/models/R19_swin.pt

DATASET:
CLASS: 'Milly_multifeature_v6'

TRAIN:
OPT: "adamw" #adam adamw sgd rmsprop
LR: 0.0001
LINEAR_WARMUP_EPOCHS: 3 #default None
WEIGHT_DECAY: 0.05
SCHEDULER: 'cos'
EPOCHS: 30
6 changes: 3 additions & 3 deletions scripts/omnimix.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ sbatch <<EOSBATCH
#SBATCH --mem 64GB
#SBATCH --time 2-00:00:00
#SBATCH --gres gpu:1
- #SBATCH --job-name step-recog-k$KFOLD_ITER$DESC
+ #SBATCH --job-name step-k$KFOLD_ITER$DESC
#SBATCH --output logs/%J_step-k$KFOLD_ITER$DESC.out
#SBATCH --mail-type=BEGIN,END,ERROR
#SBATCH --mail-user=$USER@nyu.edu
Expand Down Expand Up @@ -70,8 +70,8 @@ sbatch <<EOSBATCH
#SBATCH --mem 64GB
#SBATCH --time 2-00:00:00
#SBATCH --gres gpu:1
- #SBATCH --job-name step$DESC
- #SBATCH --output logs/%J_step-recog$DESC.out
+ #SBATCH --job-name step-$DESC
+ #SBATCH --output logs/%J_step-$DESC.out
#SBATCH --mail-type=BEGIN,END,ERROR
#SBATCH --mail-user=$USER@nyu.edu
#SBATCH --account=$PROJECT_ACCOUNT
Expand Down
11 changes: 2 additions & 9 deletions step_recog/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def build_optimizer(model, cfg):
def train_step_GRU(model, criterion, criterion_t, optimizer, loader, is_training, device, cfg, progress, class_weight):
if is_training:
model.train()
- h = model.init_hidden(cfg.TRAIN.BATCH_SIZE)
else:
model.eval()

Expand All @@ -94,13 +93,9 @@ def train_step_GRU(model, criterion, criterion_t, optimizer, loader, is_training
action, obj, frame, audio, label, label_t, mask, _, _ = next(loader_iterator)
label = nn.functional.one_hot(label, model.number_classes)

- if not is_training:
- h = model.init_hidden(len(label))
-
- h = torch.zeros_like(h)
optimizer.zero_grad()

- out, h = model(action.to(device).float(), h, audio.to(device).float(), obj.to(device).float(), frame.to(device).float(), return_last_step = False)
+ out, _ = model(action=action.to(device).float(), aud=audio.to(device).float(), objs=obj.to(device).float(), frame=frame.to(device).float(), return_last_step = False)

mask = mask.to(out.device)
label = label.to(out.device)
Expand Down Expand Up @@ -468,9 +463,7 @@ def evaluate_GRU(model, data_loader, cfg):
# ipdb.set_trace()

for action, obj, frame, audio, label, _, mask, frame_idx, videos in data_loader:
- h = model.init_hidden(len(action))
-
- out, _ = model(action.to(device).float(), h, audio.to(device).float(), obj.to(device).float(), frame.to(device).float(), return_last_step = False)
+ out, _ = model(action=action.to(device).float(), aud=audio.to(device).float(), objs=obj.to(device).float(), frame=frame.to(device).float(), return_last_step = False)
out = torch.softmax(out[..., :model.number_classes], dim = -1).cpu().detach().numpy()
label = label.cpu().numpy()
frame_idx = frame_idx.cpu().numpy()
Expand Down
2 changes: 1 addition & 1 deletion step_recog/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def forward(self, action, h=None, aud=None, objs=None, frame=None, return_last_s
x.append(obj_in)

x = torch.concat(x, -1) if len(x) > 1 else x[0]
- out, h = self.gru(x, h)
+ out, h = self.gru(x, self.init_hidden(x.shape[0]) if h is None else h)
out = self.relu(out[:, -1]) if return_last_step else self.relu(out)
out = self.fc(out)
return out, h
Expand Down

0 comments on commit 9095549

Please sign in to comment.