Skip to content

Commit

Permalink
RUFF: forgot to run format after doc updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ntjohnson1 committed Dec 6, 2024
1 parent 66cbd85 commit 990b8c5
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 10 deletions.
6 changes: 3 additions & 3 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@


@pytest.fixture(autouse=True)
def add_packages(doctest_namespace): #noqa: D103
def add_packages(doctest_namespace): # noqa: D103
doctest_namespace["np"] = numpy
doctest_namespace["ttb"] = pyttb


def pytest_addoption(parser): #noqa: D103
def pytest_addoption(parser): # noqa: D103
parser.addoption(
"--packaging",
action="store_true",
Expand All @@ -27,6 +27,6 @@ def pytest_addoption(parser): #noqa: D103
)


def pytest_configure(config): #noqa: D103
def pytest_configure(config): # noqa: D103
if not config.option.packaging:
config.option.markexpr = "not packaging"
1 change: 1 addition & 0 deletions pyttb/cp_apr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,7 @@ def calc_grad(
grad_row = (np.ones(phi_row.shape) - phi_row).transpose()
return grad_row, phi_row


# TODO verify what pi is
# Mu helper functions
def calculate_pi(
Expand Down
12 changes: 6 additions & 6 deletions pyttb/gcp/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def solve( # noqa: PLR0913
class SGD(StochasticSolver):
"""General Stochastic Gradient Descent."""

def update_step( #noqa: D102
def update_step( # noqa: D102
self, model: ttb.ktensor, gradient: List[np.ndarray], lower_bound: float
) -> Tuple[List[np.ndarray], float]:
step = self._decay**self._nfails * self._rate
Expand All @@ -254,7 +254,7 @@ def update_step( #noqa: D102
]
return factor_matrices, step

def set_failed_epoch(self): #noqa: D102
def set_failed_epoch(self): # noqa: D102
# No additional internal state for SGD
pass

Expand Down Expand Up @@ -318,14 +318,14 @@ def __init__( # noqa: PLR0913
self._v: List[np.ndarray] = []
self._v_prev: List[np.ndarray] = []

def set_failed_epoch( #noqa: D102
def set_failed_epoch( # noqa: D102
self,
):
self._total_iterations -= self._epoch_iters
self._m = self._m_prev.copy()
self._v = self._v_prev.copy()

def update_step( #noqa: D102
def update_step( # noqa: D102
self, model: ttb.ktensor, gradient: List[np.ndarray], lower_bound: float
) -> Tuple[List[np.ndarray], float]:
if self._total_iterations == 0:
Expand Down Expand Up @@ -387,12 +387,12 @@ def __init__( # noqa: PLR0913
)
self._gnormsum = 0.0

def set_failed_epoch( #noqa: D102
def set_failed_epoch( # noqa: D102
self,
):
self._gnormsum = 0.0

def update_step( #noqa: D102
def update_step( # noqa: D102
self, model: ttb.ktensor, gradient: List[np.ndarray], lower_bound: float
) -> Tuple[List[np.ndarray], float]:
self._gnormsum += np.sum([np.sum(gk**2) for gk in gradient])
Expand Down
2 changes: 1 addition & 1 deletion pyttb/pyttb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def tt_subsubsref(obj: np.ndarray, s: Any) -> Union[float, np.ndarray]:
Returns
-------
Still uncertain to this functionality
""" # noqa: D401
""" # noqa: D401
# TODO figure out when subsref yields key of length>1 for now ignore this logic and
# just return
# if len(s) == 1:
Expand Down

0 comments on commit 990b8c5

Please sign in to comment.