Skip to content

Commit

Permalink
add karras sampling to the Scheduler abstract class, default is quadr…
Browse files Browse the repository at this point in the history
…atic
  • Loading branch information
limiteinductive committed Dec 3, 2023
1 parent 4176868 commit 39433e2
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, tensor
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler


class DDIM(Scheduler):
Expand All @@ -9,6 +9,7 @@ def __init__(
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
) -> None:
Expand All @@ -17,6 +18,7 @@ def __init__(
num_train_timesteps,
initial_diffusion_rate,
final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
import numpy as np
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype
from collections import deque
Expand All @@ -16,6 +16,7 @@ def __init__(
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
Expand All @@ -24,6 +25,7 @@ def __init__(
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)
Expand Down
44 changes: 33 additions & 11 deletions src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from abc import ABC, abstractmethod
from enum import Enum
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
from typing import TypeVar

T = TypeVar("T", bound="Scheduler")


class NoiseSchedule(str, Enum):
UNIFORM = "uniform"
QUADRATIC = "quadratic"
KARRAS = "karras"


class Scheduler(ABC):
"""
A base class for creating a diffusion model scheduler.
Expand All @@ -24,6 +31,7 @@ def __init__(
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: DType = float32,
):
Expand All @@ -33,17 +41,8 @@ def __init__(
self.num_train_timesteps = num_train_timesteps
self.initial_diffusion_rate = initial_diffusion_rate
self.final_diffusion_rate = final_diffusion_rate
self.scale_factors = (
1.0
- linspace(
start=initial_diffusion_rate**0.5,
end=final_diffusion_rate**0.5,
steps=num_train_timesteps,
device=device,
dtype=dtype,
)
** 2
)
self.noise_schedule = noise_schedule
self.scale_factors = self.sample_noise_schedule()
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
Expand Down Expand Up @@ -71,6 +70,29 @@ def _generate_timesteps(self) -> Tensor:
def steps(self) -> list[int]:
return list(range(self.num_inference_steps))

def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(
start=self.initial_diffusion_rate ** (1 / power),
end=self.final_diffusion_rate ** (1 / power),
steps=self.num_train_timesteps,
device=self.device,
dtype=self.dtype,
)
** power
)

def sample_noise_schedule(self) -> Tensor:
match self.noise_schedule:
case "uniform":
return 1 - self.sample_power_distribution(1)
case "quadratic":
return 1 - self.sample_power_distribution(2)
case "karras":
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")

def add_noise(
self,
x: Tensor,
Expand Down

0 comments on commit 39433e2

Please sign in to comment.