Skip to content

Commit

Permalink
Update train_network.py
Browse files · Browse the repository at this point in the history
  • Loading branch information
gesen2egee authored Aug 4, 2024
1 parent a593e83 commit f6dbf7c
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,26 +1034,26 @@ def remove_model(old_ckpt_name):
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
accelerator.log(logs, step=global_step)

if len(val_dataloader) > 0:
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
accelerator.print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
total_loss += loss.detach().item()
current_loss = total_loss / validation_steps
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
if len(val_dataloader) > 0:
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
accelerator.print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
total_loss += loss.detach().item()
current_loss = total_loss / validation_steps
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)

if args.logging_dir is not None:
logs = {"loss/current_val_loss": current_loss}
accelerator.log(logs, step=global_step)
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/average_val_loss": avr_loss}
accelerator.log(logs, step=global_step)

if args.logging_dir is not None:
logs = {"loss/current_val_loss": current_loss}
accelerator.log(logs, step=global_step)
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/average_val_loss": avr_loss}
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break

Expand Down

0 comments on commit f6dbf7c

Please sign in to comment.