-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Intermediate time generation #118
base: main
Are you sure you want to change the base?
Conversation
…to intermediate_time_generation
@@ -562,12 +574,17 @@ def generate_samples(self): | |||
assert ( | |||
self.hyper_params.diffusion_sampling_parameters is not None | |||
), "sampling parameters must be provided to create a generator." | |||
with torch.no_grad(): | |||
with ((torch.no_grad())): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is that necessary?
weighted_regularizer_loss = ( | ||
self.regularizer.compute_weighted_regularizer_loss( | ||
score_network=self.axl_network, | ||
augmented_batch=augmented_batch, | ||
current_epoch=self.current_epoch)) | ||
|
||
t2 = time.time() | ||
model_regularization_time = t2 - t1 | ||
logger.info(f" - batch {batch_idx} :: Prediction time = {model_prediction_time:2.1e} s, " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this useful beyond debugging? I'm fine with leaving it in the log, but I'm not sure it adds something.
|
||
def on_train_epoch_start(self) -> None: | ||
"""On train epoch start.""" | ||
logger.info(f"Starting Training Epoch {self.current_epoch}.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doesn't this overlap with the standard progressbar in the log? We have an option to activate / deactivate it.
|
||
def on_train_epoch_end(self) -> None: | ||
"""On train epoch end.""" | ||
logger.info(f"Ending Training Epoch {self.current_epoch}.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as previous comment
In this PR, I create an experiment to noise a validation data example to an intermediate time and denoise it back again to see when we start having problems. The basic idea is documented in this Notion note: https://www.notion.so/4f446345516e4b499126ad1e97d026cd?v=0618487565864066a0c6c96b4c655d2c&p=17a690083be4809692fef56c24bb97d2&pm=s.
In a nutshell, it lets me create these kinds of images.
Also, I make a few things into
torch.nn.Module
because it is a PITA to constantly have to figure out what device we are on. Some of the auxiliary objects we create (such as a regularizer, for example), are created before the whole PL is shipped off to the GPU, so I can't know at creation time if I'm going to CPU or GPU... Better to work with the framework instead of fighting against it, and let PL handle devices.