From 9095549b46bf0f371340b040bda172fbc2998014 Mon Sep 17 00:00:00 2001
From: Fabio
Date: Wed, 20 Nov 2024 14:23:29 -0500
Subject: [PATCH] Adjust hidden state of training

---
 config/A8_swin.yaml     | 11 +++++++++++
 config/M2.yaml          |  5 +++++
 config/M3.yaml          |  7 ++++++-
 config/M4_swin.yaml     | 11 +++++++++++
 config/M5.yaml          |  5 +++++
 config/R16_swin.yaml    | 11 +++++++++++
 config/R18.yaml         |  5 +++++
 config/R19_swin.yaml    | 10 ++++++++++
 scripts/omnimix.sh      |  6 +++---
 step_recog/iterators.py | 11 ++---------
 step_recog/models.py    |  2 +-
 11 files changed, 70 insertions(+), 14 deletions(-)

diff --git a/config/A8_swin.yaml b/config/A8_swin.yaml
index 4ef02bb..581b025 100644
--- a/config/A8_swin.yaml
+++ b/config/A8_swin.yaml
@@ -2,3 +2,14 @@ _BASE_: A8.yaml
 MODEL:
   CLASS: OmniTransformer_v4
   OMNIGRU_CHECKPOINT_URL: /home/user/models/A8_swin.pt
+
+DATASET:
+  CLASS: 'Milly_multifeature_v6'
+
+TRAIN:
+  OPT: "adamw" #adam adamw sgd rmsprop
+  LR: 0.0001
+  LINEAR_WARMUP_EPOCHS: 3 #default None
+  WEIGHT_DECAY: 0.05
+  SCHEDULER: 'cos'
+  EPOCHS: 30
\ No newline at end of file
diff --git a/config/M2.yaml b/config/M2.yaml
index cb67856..021a6b9 100644
--- a/config/M2.yaml
+++ b/config/M2.yaml
@@ -12,6 +12,11 @@ DATASET:
 OUTPUT:
   LOCATION: "/home/user/output"
 
+TRAIN:
+  OPT: "adamw" #adam adamw sgd rmsprop
+  LR: 0.001
+  WEIGHT_DECAY: 0.0001
+
 SKILLS:
 - NAME: M2 - Apply Tourniquet
   STEPS:
diff --git a/config/M3.yaml b/config/M3.yaml
index 3662375..6a62f74 100644
--- a/config/M3.yaml
+++ b/config/M3.yaml
@@ -11,7 +11,12 @@ DATASET:
 OUTPUT:
   LOCATION: "/home/user/output"
 
-
+
+TRAIN:
+  OPT: "adamw" #adam adamw sgd rmsprop
+  LR: 0.001
+  WEIGHT_DECAY: 0.001
+
 SKILLS:
 - NAME: M3 - Apply pressure dressing
   STEPS:
diff --git a/config/M4_swin.yaml b/config/M4_swin.yaml
index fc50524..be78ab3 100644
--- a/config/M4_swin.yaml
+++ b/config/M4_swin.yaml
@@ -2,3 +2,14 @@ _BASE_: M4.yaml
 MODEL:
   CLASS: OmniTransformer_v3
   OMNIGRU_CHECKPOINT_URL: /home/user/models/M4_swin.pt
+
+DATASET:
+  CLASS: 'Milly_multifeature_v6'
+
+TRAIN:
+  OPT: "adamw" #adam adamw sgd rmsprop
+  LR: 0.0001
+  LINEAR_WARMUP_EPOCHS: 3 #default None
+  WEIGHT_DECAY: 0.05
+  SCHEDULER: 'cos'
+  EPOCHS: 30
\ No newline at end of file
diff --git a/config/M5.yaml b/config/M5.yaml
index db2d7fe..eab3e07 100644
--- a/config/M5.yaml
+++ b/config/M5.yaml
@@ -12,6 +12,11 @@ DATASET:
 OUTPUT:
   LOCATION: "/home/user/output"
 
+TRAIN:
+  OPT: "adamw" #adam adamw sgd rmsprop
+  LR: 0.001
+  WEIGHT_DECAY: 0.00001
+
 SKILLS:
 - NAME: M5 - X-Stat
   STEPS:
diff --git a/config/R16_swin.yaml b/config/R16_swin.yaml
index f0c681a..87272dc 100644
--- a/config/R16_swin.yaml
+++ b/config/R16_swin.yaml
@@ -2,3 +2,14 @@ _BASE_: R16.yaml
 MODEL:
   CLASS: OmniTransformer_v3
   OMNIGRU_CHECKPOINT_URL: /home/user/models/R16_swin.pt
+
+DATASET:
+  CLASS: 'Milly_multifeature_v6'
+
+TRAIN:
+  OPT: "adamw" #adam adamw sgd rmsprop
+  LR: 0.0001
+  LINEAR_WARMUP_EPOCHS: 3 #default None
+  WEIGHT_DECAY: 0.05
+  SCHEDULER: 'cos'
+  EPOCHS: 30
\ No newline at end of file
diff --git a/config/R18.yaml b/config/R18.yaml
index e2357c3..8b30030 100644
--- a/config/R18.yaml
+++ b/config/R18.yaml
@@ -12,6 +12,11 @@ DATASET:
 OUTPUT:
   LOCATION: "/home/user/output"
 
+TRAIN:
+  OPT: "adamw" #adam adamw sgd rmsprop
+  LR: 0.001
+  WEIGHT_DECAY: 0.01
+
 SKILLS:
 - NAME: R18 - Apply chest seal
   STEPS:
diff --git a/config/R19_swin.yaml b/config/R19_swin.yaml
index 08ae5e2..7151017 100644
--- a/config/R19_swin.yaml
+++ b/config/R19_swin.yaml
@@ -3,3 +3,13 @@ MODEL:
   CLASS: OmniTransformer_v3
   OMNIGRU_CHECKPOINT_URL: /home/user/models/R19_swin.pt
 
+DATASET:
+  CLASS: 'Milly_multifeature_v6'
+
+TRAIN:
+  OPT: "adamw" #adam adamw sgd rmsprop
+  LR: 0.0001
+  LINEAR_WARMUP_EPOCHS: 3 #default None
+  WEIGHT_DECAY: 0.05
+  SCHEDULER: 'cos'
+  EPOCHS: 30
\ No newline at end of file
diff --git a/scripts/omnimix.sh b/scripts/omnimix.sh
index fcae475..6da5ee4 100644
--- a/scripts/omnimix.sh
+++ b/scripts/omnimix.sh
@@ -41,7 +41,7 @@ sbatch <
[...]
 1 else x[0]
-    out, h = self.gru(x, h)
+    out, h = self.gru(x, self.init_hidden(x.shape[0]) if h is None else h)
     out = self.relu(out[:, -1]) if return_last_step else self.relu(out)
     out = self.fc(out)
     return out, h
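Note: the step_recog/models.py hunk above changes the forward pass to fall back to an explicitly initialized hidden state when the caller does not pass one. The sketch below illustrates that pattern in isolation; the class name, layer sizes, and constructor arguments are placeholders chosen for the example, not taken from the repository.

    import torch
    import torch.nn as nn

    class GRUHeadSketch(nn.Module):
        """Toy GRU -> ReLU -> Linear head with the same hidden-state guard."""
        def __init__(self, input_size=1024, hidden_size=512, n_layers=2, n_classes=10):
            super().__init__()
            self.hidden_size = hidden_size
            self.n_layers = n_layers
            self.gru = nn.GRU(input_size, hidden_size, n_layers, batch_first=True)
            self.relu = nn.ReLU()
            self.fc = nn.Linear(hidden_size, n_classes)

        def init_hidden(self, batch_size):
            # Zero hidden state created on the same device/dtype as the model weights.
            weight = next(self.parameters())
            return weight.new_zeros(self.n_layers, batch_size, self.hidden_size)

        def forward(self, x, h=None, return_last_step=True):
            # Mirror of the patched call: build a fresh hidden state when h is None,
            # otherwise reuse the state carried over from the previous window.
            out, h = self.gru(x, self.init_hidden(x.shape[0]) if h is None else h)
            out = self.relu(out[:, -1]) if return_last_step else self.relu(out)
            out = self.fc(out)
            return out, h

    # Usage: the first window of a sequence passes h=None (state initialized
    # internally); later windows reuse the returned h to keep temporal context.
    model = GRUHeadSketch()
    feats = torch.randn(4, 8, 1024)   # (batch, time, feature)
    logits, h = model(feats)          # hidden state initialized inside forward()
    logits, h = model(feats, h)       # hidden state carried over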