
Intermediate time generation #118

Open · wants to merge 33 commits into main

Conversation

rousseab
Collaborator

In this PR, I create an experiment that noises a validation example up to an intermediate time and then denoises it back, to see at what point we start having problems. The basic idea is documented in this Notion note: https://www.notion.so/4f446345516e4b499126ad1e97d026cd?v=0618487565864066a0c6c96b4c655d2c&p=17a690083be4809692fef56c24bb97d2&pm=s.

In a nutshell, it lets me create these kinds of images.
[Screenshot of an example output figure, 2025-01-16]
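Roughly, the experiment does the following (a minimal sketch; noise_to_time, denoise_from_time, and the object names below are hypothetical placeholders, not the actual API in this PR):

import torch

@torch.no_grad()
def noise_and_denoise(example: torch.Tensor, noising_process, sampler, intermediate_times):
    """Noise a validation example to each intermediate time, denoise it back,
    and record how far the reconstruction drifts from the original."""
    discrepancies = {}
    for t in intermediate_times:
        # Forward (noising) process: corrupt the clean example up to time t.
        noised = noising_process.noise_to_time(example, time=t)
        # Reverse (denoising) process: start the sampler at time t instead of t = 1.
        reconstructed = sampler.denoise_from_time(noised, starting_time=t)
        # Simple discrepancy metric, plotted against t to see where generation degrades.
        discrepancies[t] = torch.linalg.norm(reconstructed - example).item()
    return discrepancies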

Also, I turn a few things into torch.nn.Module subclasses because it is a PITA to constantly have to figure out what device we are on. Some of the auxiliary objects we create (such as the regularizer) are instantiated before the whole PL module is shipped off to the GPU, so I can't know at creation time whether they will live on CPU or GPU... Better to work with the framework instead of fighting it, and let PL handle devices.
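The pattern looks roughly like this (a minimal sketch assuming a mean-squared-error style regularizer; the class and its buffer are illustrative, not the exact code in this PR):

import torch


class Regularizer(torch.nn.Module):
    """Illustrative auxiliary object written as an nn.Module.

    Registering its tensors as buffers means PL moves them to the right device
    together with the rest of the LightningModule.
    """

    def __init__(self, weight: float):
        super().__init__()
        # A buffer follows .to(device) / PL device placement automatically,
        # so no manual device bookkeeping is needed at creation time.
        self.register_buffer("weight", torch.tensor(weight))

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        return self.weight * torch.mean((predictions - targets) ** 2)

Assigning such an object as an attribute of the LightningModule (e.g. self.regularizer = Regularizer(0.1) in __init__) is enough for PL to include it in device placement.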

rousseab requested a review from sblackburn86 on January 22, 2025 at 20:41
@@ -562,12 +574,17 @@ def generate_samples(self):
         assert (
             self.hyper_params.diffusion_sampling_parameters is not None
         ), "sampling parameters must be provided to create a generator."
-        with torch.no_grad():
+        with ((torch.no_grad())):
Collaborator

is that necessary?

        weighted_regularizer_loss = (
            self.regularizer.compute_weighted_regularizer_loss(
                score_network=self.axl_network,
                augmented_batch=augmented_batch,
                current_epoch=self.current_epoch))

        t2 = time.time()
        model_regularization_time = t2 - t1
        logger.info(f" - batch {batch_idx} :: Prediction time = {model_prediction_time:2.1e} s, "
Collaborator

is this useful beyond debugging? I'm fine with leaving it in the log, but I'm not sure it adds anything.


    def on_train_epoch_start(self) -> None:
        """On train epoch start."""
        logger.info(f"Starting Training Epoch {self.current_epoch}.")
Collaborator

doesn't this overlap with the standard progress bar in the log? We have an option to activate / deactivate it.
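(For context, a minimal sketch of that toggle, assuming it maps onto PyTorch Lightning's standard flag rather than a project-specific option:)

from lightning.pytorch import Trainer

# Turning off the built-in progress bar avoids duplicated per-epoch output in the logs.
trainer = Trainer(enable_progress_bar=False)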


    def on_train_epoch_end(self) -> None:
        """On train epoch end."""
        logger.info(f"Ending Training Epoch {self.current_epoch}.")
Collaborator

same as previous comment
