Skip to content

Commit

Permalink
fix localmodule args
Browse files Browse the repository at this point in the history
  • Loading branch information
vijaydwivedi75 authored Apr 26, 2024
1 parent d7e7c07 commit d676c2c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ def reset_parameters(self):

def forward(self, seq, x, pos_enc=None, batch_idx=None):
if self.conv_type == "local":
out = self.local_forward(seq, pos_enc)
out = self.local_forward(seq)

elif self.conv_type == "global":
out = self.global_forward(x[: len(batch_idx)], pos_enc, batch_idx)

elif self.conv_type == "full":
out_local = self.local_forward(seq, pos_enc)
out_local = self.local_forward(seq)
out_global = self.global_forward(x[: len(batch_idx)], pos_enc, batch_idx)
out = torch.cat([out_local, out_global], dim=1)

Expand Down Expand Up @@ -147,8 +147,8 @@ def global_forward(self, x, pos_enc, batch_idx):

return out

def local_forward(self, seq, pos_enc):
return self.local_module(seq, pos_enc)
def local_forward(self, seq):
return self.local_module(seq)

def __repr__(self) -> str:
return (
Expand Down

0 comments on commit d676c2c

Please sign in to comment.