diff --git a/flow_matching/solver/discrete_solver.py b/flow_matching/solver/discrete_solver.py index 282c2a0..8ca84cd 100644 --- a/flow_matching/solver/discrete_solver.py +++ b/flow_matching/solver/discrete_solver.py @@ -12,7 +12,6 @@ from torch import Tensor from torch.nn import functional as F -from tqdm import tqdm from flow_matching.path import MixtureDiscreteProbPath @@ -20,6 +19,13 @@ from flow_matching.utils import categorical, ModelWrapper from .utils import get_nearest_times +try: + from tqdm import tqdm + + TQDM_AVAILABLE = True +except ImportError: + TQDM_AVAILABLE = False + class MixtureDiscreteEulerSolver(Solver): r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t` the marginal probability path of ``path``. @@ -130,6 +136,9 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: Returns: Tensor: The sampled sequence of discrete values. + + Raises: + ImportError: To run in verbose mode, tqdm must be installed. """ if not div_free == 0.0: assert ( @@ -173,6 +182,10 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: res = [x_init.clone()] if verbose: + if not TQDM_AVAILABLE: + raise ImportError( + "tqdm is required for verbose mode. Please install it." + ) ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}") else: ctx = nullcontext() diff --git a/flow_matching/solver/riemannian_ode_solver.py b/flow_matching/solver/riemannian_ode_solver.py index d851e8f..a6ff0fd 100644 --- a/flow_matching/solver/riemannian_ode_solver.py +++ b/flow_matching/solver/riemannian_ode_solver.py @@ -9,12 +9,18 @@ import torch from torch import Tensor -from tqdm import tqdm from flow_matching.solver.solver import Solver from flow_matching.utils import ModelWrapper from flow_matching.utils.manifolds import geodesic, Manifold +try: + from tqdm import tqdm + + TQDM_AVAILABLE = True +except ImportError: + TQDM_AVAILABLE = False + class RiemannianODESolver(Solver): r"""Riemannian ODE solver @@ -60,6 +66,9 @@ def sample( Returns: Tensor: The sampled sequence. Defaults to returning samples at :math:`t=1`. + + Raises: + ImportError: To run in verbose mode, tqdm must be installed. """ step_fns = { "euler": _euler_step, @@ -96,6 +105,10 @@ def velocity_func(x, t): t0s = t_discretization[:-1] if verbose: + if not TQDM_AVAILABLE: + raise ImportError( + "tqdm is required for verbose mode. Please install it." + ) t0s = tqdm(t0s) if return_intermediates: