Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve #468 #470

Merged
merged 8 commits into from
Nov 15, 2023
Merged

Resolve #468 #470

merged 8 commits into from
Nov 15, 2023

Conversation

bwohlberg
Copy link
Collaborator

@bwohlberg bwohlberg commented Nov 14, 2023

Resolve #468.

Also:

  • Improve scaling of ScaledFunctional, so that e.g. 5.0 * (2.0 * functional.L2Norm()) is a ScaledFunctional with a scale of 10 of an L2Norm rather than a ScaledFunctional with a scale of 5 of a ScaledFunctional.
  • Some cleaning up and code re-organization.
  • Update change summary CHANGES.rst.

Copy link

codecov bot commented Nov 14, 2023

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison is base (5ffd1f9) 94.60% compared to head (124b391) 94.57%.

Files Patch % Lines
scico/functional/_functional.py 95.24% 1 Missing ⚠️
scico/loss.py 66.67% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #470      +/-   ##
==========================================
- Coverage   94.60%   94.57%   -0.03%     
==========================================
  Files          90       90              
  Lines        5578     5585       +7     
==========================================
+ Hits         5277     5282       +5     
- Misses        301      303       +2     
Flag Coverage Δ
unittests 94.57% <91.67%> (-0.03%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -38,15 +43,15 @@ def __init__(self):
def __repr__(self):
return f"""{type(self)} (has_eval = {self.has_eval}, has_prox = {self.has_prox})"""

def __mul__(self, other):
def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
if snp.isscalar(other) or isinstance(other, jax.core.Tracer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(1) I suggest snp.util.is_scalar_equiv here, remember snp.isscalar(jnp.sum(1.0)) == False
(2) I suggest removing the jax.core.Tracer check. That would return true for traced arrays as well as scalars, which seems wrong. And snp.util.is_scalar_equiv should be true for scalar tracers (if not, let's fix it)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Done.

@bwohlberg bwohlberg merged commit 4efcba4 into main Nov 15, 2023
@bwohlberg bwohlberg deleted the brendt/issue468 branch November 15, 2023 18:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Incorrect handling of scale in Loss.grad
2 participants