Skip to content

Commit

Permalink
Refactor positional encoding and document.
Browse files Browse the repository at this point in the history
Refactor mag_min, mag_max to amp_min, amp_max.
  • Loading branch information
tibuch committed Mar 30, 2021
1 parent 88f1982 commit c1afa98
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 224 deletions.
13 changes: 6 additions & 7 deletions fit/modules/SResTransformerModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class SResTransformerModule(LightningModule):
def __init__(self, d_model, img_shape,
x_coords_img, y_coords_img, dst_flatten_order, dst_order,
coords, dst_flatten_order, dst_order,
loss='prod',
lr=0.0001,
weight_decay=0.01,
Expand All @@ -34,8 +34,7 @@ def __init__(self, d_model, img_shape,
"dropout",
"attention_dropout")

self.x_coords_img = x_coords_img
self.y_coords_img = y_coords_img
self.coords = coords
self.dst_flatten_order = dst_flatten_order
self.dst_order = dst_order
self.dft_shape = (img_shape, img_shape // 2 + 1)
Expand All @@ -46,7 +45,7 @@ def __init__(self, d_model, img_shape,
self.loss = _fc_sum_loss

self.sres = SResTransformerTrain(d_model=self.hparams.d_model,
y_coords_img=self.y_coords_img, x_coords_img=self.x_coords_img,
coords=self.coords,
flatten_order=self.dst_flatten_order,
attention_type='causal-linear',
n_layers=self.hparams.n_layers,
Expand All @@ -73,8 +72,8 @@ def configure_optimizers(self):
}

def criterion(self, pred_fc, target_fc, mag_min, mag_max):
fc_loss, amp_loss, phi_loss = self.loss(pred_fc=pred_fc, target_fc=target_fc, mag_min=mag_min,
mag_max=mag_max)
fc_loss, amp_loss, phi_loss = self.loss(pred_fc=pred_fc, target_fc=target_fc, amp_min=mag_min,
amp_max=mag_max)
return fc_loss, amp_loss, phi_loss

def training_step(self, batch, batch_idx):
Expand Down Expand Up @@ -135,7 +134,7 @@ def validation_epoch_end(self, outputs):

def load_test_model(self, path):
self.sres_pred = SResTransformerPredict(self.hparams.d_model,
y_coords_img=self.y_coords_img, x_coords_img=self.x_coords_img,
coords=self.coords,
flatten_order=self.dst_flatten_order,
attention_type='causal-linear',
n_layers=self.hparams.n_layers,
Expand Down
Loading

0 comments on commit c1afa98

Please sign in to comment.