Skip to content

Commit

Permalink
Fixing dependencies (#105)
Browse files Browse the repository at this point in the history
* examples readme, numpy version specs, tqdm optional

* also fixed riemannian tqdm

* moved unrealted change

* revert numpy version fix

* shortened exception message

---------

Co-authored-by: Marton Havasi <[email protected]>
  • Loading branch information
mhavasi and mhavasi committed Jan 2, 2025
1 parent b837802 commit 47c4396
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
15 changes: 14 additions & 1 deletion flow_matching/solver/discrete_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,20 @@
from torch import Tensor

from torch.nn import functional as F
from tqdm import tqdm

from flow_matching.path import MixtureDiscreteProbPath

from flow_matching.solver.solver import Solver
from flow_matching.utils import categorical, ModelWrapper
from .utils import get_nearest_times

try:
from tqdm import tqdm

TQDM_AVAILABLE = True
except ImportError:
TQDM_AVAILABLE = False


class MixtureDiscreteEulerSolver(Solver):
r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t` the marginal probability path of ``path``.
Expand Down Expand Up @@ -130,6 +136,9 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
Returns:
Tensor: The sampled sequence of discrete values.
Raises:
ImportError: To run in verbose mode, tqdm must be installed.
"""
if not div_free == 0.0:
assert (
Expand Down Expand Up @@ -173,6 +182,10 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
res = [x_init.clone()]

if verbose:
if not TQDM_AVAILABLE:
raise ImportError(
"tqdm is required for verbose mode. Please install it."
)
ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
else:
ctx = nullcontext()
Expand Down
15 changes: 14 additions & 1 deletion flow_matching/solver/riemannian_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@

import torch
from torch import Tensor
from tqdm import tqdm

from flow_matching.solver.solver import Solver
from flow_matching.utils import ModelWrapper
from flow_matching.utils.manifolds import geodesic, Manifold

try:
from tqdm import tqdm

TQDM_AVAILABLE = True
except ImportError:
TQDM_AVAILABLE = False


class RiemannianODESolver(Solver):
r"""Riemannian ODE solver
Expand Down Expand Up @@ -60,6 +66,9 @@ def sample(
Returns:
Tensor: The sampled sequence. Defaults to returning samples at :math:`t=1`.
Raises:
ImportError: To run in verbose mode, tqdm must be installed.
"""
step_fns = {
"euler": _euler_step,
Expand Down Expand Up @@ -96,6 +105,10 @@ def velocity_func(x, t):
t0s = t_discretization[:-1]

if verbose:
if not TQDM_AVAILABLE:
raise ImportError(
"tqdm is required for verbose mode. Please install it."
)
t0s = tqdm(t0s)

if return_intermediates:
Expand Down

0 comments on commit 47c4396

Please sign in to comment.