Skip to content

Commit

Permalink
Address review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Nov 15, 2023
1 parent 21f5668 commit 124b391
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions scico/functional/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

from typing import List, Optional, Union

import jax

import scico
from scico import numpy as snp
from scico.numpy import Array, BlockArray
Expand Down Expand Up @@ -44,7 +42,7 @@ def __repr__(self):
return f"""{type(self)} (has_eval = {self.has_eval}, has_prox = {self.has_prox})"""

def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
if snp.isscalar(other) or isinstance(other, jax.core.Tracer):
if snp.util.is_scalar_equiv(other):
return ScaledFunctional(self, other)
return NotImplemented

Expand Down Expand Up @@ -146,7 +144,7 @@ def __call__(self, x: Union[Array, BlockArray]) -> float:
return self.scale * self.functional(x)

def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
if snp.isscalar(other) or isinstance(other, jax.core.Tracer):
if snp.util.is_scalar_equiv(other):
return ScaledFunctional(self.functional, other * self.scale)
return NotImplemented

Expand Down

0 comments on commit 124b391

Please sign in to comment.