Skip to content

Commit

Permalink
Fixed Sequential to not have extra arguments. If extra arguments ar…
Browse files Browse the repository at this point in the history
…e necessary, they should be packaged into the data
  • Loading branch information
ArvinSKushwaha committed Jul 26, 2024
1 parent d990d42 commit b806827
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions src/jaximal/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,9 @@ def init(key: PRNGKeyArray) -> Self:

return init

def __call__(self, data: Any, *args: dict[str, Any]) -> Any:
assert len(args) == len(self.modules), (
'Expected `self.modules` and `args` to have the same length '
f'but got {len(self.modules)} and {len(args)}, respectively.'
)
for kwargs, modules in zip(args, self.modules):
data = modules(data, **kwargs)
def __call__(self, data: Any) -> Any:
for modules in self.modules:
data = modules(data)

return data

Expand Down

0 comments on commit b806827

Please sign in to comment.