Commit: more_timestep_sampling
gesen2egee committed Oct 4, 2024
1 parent 012e7e6 commit 2001f46
Showing 1 changed file with 45 additions and 6 deletions.
51 changes: 45 additions & 6 deletions library/train_util.py
@@ -3476,6 +3476,25 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
        default=0.1,
        help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
    )
    parser.add_argument(
        "--timestep_sampling",
        choices=["uniform", "sigmoid", "shift", "flux_shift"],
        default="uniform",
        help="Method to sample timesteps: uniform random, sigmoid of a random normal, shift of sigmoid, or FLUX.1 shifting."
        " / タイムステップをサンプリングする方法:random uniform、random normalのsigmoid、sigmoidのシフト、またはFLUX.1のシフト。",
    )
    parser.add_argument(
        "--sigmoid_scale",
        type=float,
        default=1.0,
        help='Scale factor for sigmoid timestep sampling (applied when timestep-sampling is "sigmoid", "shift" or "flux_shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"、"shift"、"flux_shift"の場合に適用)。',
    )
    parser.add_argument(
        "--discrete_flow_shift",
        type=float,
        default=1.0,
        help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
    )
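
A standalone sketch (editorial, not part of the commit) of what the three non-FLUX modes of --timestep_sampling produce; it mirrors the branch logic added to get_timesteps_and_huber_c further down, with hypothetical default values. "flux_shift" is omitted here because it also needs the latent resolution:

import torch

def sample_timesteps_sketch(mode: str, b_size: int = 4, sigmoid_scale: float = 1.0, shift: float = 3.0, num_train_timesteps: int = 1000) -> torch.Tensor:
    # "uniform": plain integer-uniform sampling over [0, num_train_timesteps)
    if mode == "uniform":
        return torch.randint(0, num_train_timesteps, (b_size,))
    # "sigmoid": squash a scaled standard normal through a sigmoid (mass concentrates near t=0.5)
    t = (torch.randn(b_size) * sigmoid_scale).sigmoid()
    if mode == "shift":
        # discrete flow shift: for shift > 1 this pushes mass toward t=1 (the noisier end)
        t = (t * shift) / (1 + (shift - 1) * t)
    return (t * num_train_timesteps).long()

for mode in ("uniform", "sigmoid", "shift"):
    print(mode, sample_timesteps_sketch(mode))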

    parser.add_argument(
        "--lowram",
Expand Down Expand Up @@ -5198,9 +5217,32 @@ def save_sd_model_on_train_end_common(
    if args.huggingface_repo_id is not None:
        huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)

def time_shift(mu: float, sigma: float, t: torch.Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
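
A note on this helper (editorial): for sigma = 1 the expression exp(mu) / (exp(mu) + (1/t - 1)) simplifies to s*t / (1 + (s - 1)*t) with s = exp(mu), i.e. the same curve as the "shift" branch below, but with a resolution-dependent shift factor. A two-line check:

import math
s, t = 3.0, 0.5  # s plays the role of exp(mu)
assert abs(s / (s + (1 / t - 1)) - (s * t) / (1 + (s - 1) * t)) < 1e-12  # both give 0.75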

def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b
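
For orientation (editorial note): this interpolates mu linearly between (256, 0.5) and (4096, 1.15), where x is the number of 2x2 latent patches, i.e. image tokens in FLUX.1. Assuming the usual 8x VAE downscale, a 1024x1024 image gives a 128x128 latent, so (128 // 2) * (128 // 2) = 4096 patches and mu = 1.15, a shift factor of exp(1.15) ~= 3.16:

mu = get_lin_function(y1=0.5, y2=1.15)((128 // 2) * (128 // 2))
print(mu)  # 1.15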

def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, latents, device):
    # Sample a random timestep for each image
    b_size, _, h, w = latents.shape

    if args.timestep_sampling != "uniform":
        shift = args.discrete_flow_shift
        logits_norm = torch.randn(b_size, device="cpu")
        logits_norm = logits_norm * args.sigmoid_scale
        timesteps = logits_norm.sigmoid()
        if args.timestep_sampling == "flux_shift":
            mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
            timesteps = time_shift(mu, 1.0, timesteps)
        else:
            timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
        timesteps = min_timestep + (timesteps * (max_timestep - min_timestep))
    else:
        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
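
To make the "shift" branch concrete (editorial sketch): with, e.g., shift = 3.0, the mapping t -> (t * shift) / (1 + (shift - 1) * t) bends the sigmoid samples toward the noisy end:

shift = 3.0
for t in (0.25, 0.5, 0.75):
    print(t, (t * shift) / (1 + (shift - 1) * t))  # 0.25 -> 0.5, 0.5 -> 0.75, 0.75 -> 0.9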

    if args.loss_type == "huber" or args.loss_type == "smooth_l1":
        if args.huber_schedule == "exponential":
@@ -5223,7 +5265,6 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler,
    timesteps = timesteps.long().to(device)
    return timesteps, huber_c
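
The body of the huber_c schedule is collapsed in this view. For orientation, a hedged, standalone sketch of the exponential schedule as upstream sd-scripts implements it (the collapsed lines in this fork may differ): huber_c decays from 1.0 at timestep 0 down to the configured --huber_c at the final timestep.

import math

def exponential_huber_c_sketch(timestep: int, huber_c: float = 0.1, num_train_timesteps: int = 1000) -> float:
    alpha = -math.log(huber_c) / num_train_timesteps
    return math.exp(-alpha * timestep)  # 1.0 at timestep 0, huber_c at num_train_timesteps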


def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents, device=latents.device)
@@ -5238,12 +5279,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
            noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
        )
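
For readers unfamiliar with the collapsed call above: multires ("pyramid") noise mixes progressively coarser noise, upsampled back to full resolution, into the base noise with geometrically decaying weight. A minimal sketch assuming an interface like sd-scripts' pyramid_noise_like (names and details here are illustrative, not the fork's exact code):

import random
import torch

def pyramid_noise_sketch(noise: torch.Tensor, iterations: int, discount: float) -> torch.Tensor:
    b, c, h, w = noise.shape
    up = torch.nn.Upsample(size=(h, w), mode="bilinear")
    for i in range(iterations):
        r = random.random() * 2 + 2  # random downscale ratio per octave
        hn, wn = max(1, int(h / (r ** i))), max(1, int(w / (r ** i)))
        noise = noise + up(torch.randn(b, c, hn, wn)) * discount ** i
        if hn == 1 or wn == 1:
            break
    return noise / noise.std()  # renormalize to roughly unit variance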

    # Sample a random timestep for each image
    b_size = latents.shape[0]
    min_timestep = 0 if args.min_timestep is None else args.min_timestep
    max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, latents, latents.device)

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
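
The line that follows is collapsed in this view; in diffusers-style schedulers it is typically noise_scheduler.add_noise(latents, noise, timesteps). A hedged sketch of the standard DDPM forward process that call implements, x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise:

abar = noise_scheduler.alphas_cumprod.to(latents.device)[timesteps].view(-1, 1, 1, 1)
noisy_latents = abar.sqrt() * latents + (1 - abar).sqrt() * noise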
