Skip to content

Commit

Permalink
added preload of z
Browse files Browse the repository at this point in the history
  • Loading branch information
NilsDem committed Dec 18, 2024
1 parent f0a4579 commit 8fc6b2b
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 19 deletions.
2 changes: 1 addition & 1 deletion diffusion/configs/main.gin
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ X_LENGTH = 131072
AE_EMBSIZE = 32
AE_FACTOR = 1024
SR = 24000
ZT_CHANNELS = 32
ZT_CHANNELS = 16
ZS_CHANNELS = 16

N_MELS = 256
Expand Down
88 changes: 70 additions & 18 deletions train_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def main(args):
)

if args.restart > 0:
config_path = "./runs/" + args.name + "/config.gin"
config_path = "./diffusion/runs/" + args.name + "/config.gin"
with gin.unlock_config():
gin.parse_config_files_and_bindings([config_path], [])

Expand All @@ -60,31 +60,70 @@ def main(args):
"cpu") #if args.use_accelerator else Accelerator()
model = model.to(model.accelerator.device)

### GET AE RATIO ###
dummy_x = torch.randn(1, 1, 4096).to(model.accelerator.device)
z = model.emb_model.encode(dummy_x)
ae_ratio = dummy_x.shape[-1] // z.shape[-1]

######### GET THE DATASET #########

if args.dataset_type == "waveform":
dataset = SimpleDataset(path=args.db_path, keys=["waveform"])
dataset = SimpleDataset(path=args.db_path, keys=["waveform", "z"])

try:
dataset[0]["z"]
z_precomputed = True
except:
z_precomputed = False
dataset.buffer_keys = ["waveform"]
print(
"Using on the fly AE encoding, training will be slow. Use split_to_lmdb.py with emb_model arg to precompute z"
)

dataset, valset = torch.utils.data.random_split(
dataset,
(len(dataset) - int(0.95 * len(dataset)), int(
0.95 * len(dataset))))

x_length = gin.query_parameter("%X_LENGTH")
z_length = x_length // ae_ratio

def collate_fn(L):

x = np.stack([l["waveform"] for l in L])
x = torch.from_numpy(x).float().reshape((x.shape[0], 1, -1))

i0 = np.random.randint(0, x.shape[-1] - x_length, x.shape[0])
x_diff = torch.stack(
[xc[..., i:i + x_length] for i, xc in zip(i0, x)])
if z_precomputed:
z = np.stack([l["z"] for l in L])
z = torch.from_numpy(z).float()

i1 = np.random.randint(0, x.shape[-1] - x_length, x.shape[0])
x_toz = torch.stack(
[xc[..., i:i + x_length] for i, xc in zip(i1, x)])
i0 = np.random.randint(0, x.shape[-1] // ae_ratio - z_length,
x.shape[0])

i1 = np.random.randint(0, x.shape[-1] // ae_ratio - z_length,
x.shape[0])

x_diff = torch.stack([
xc[..., i * ae_ratio:i * ae_ratio + x_length]
for i, xc in zip(i0, x)
])

if z_precomputed:
z_diff = torch.stack(
[xc[..., i:i + z_length] for i, xc in zip(i0, z)])

x_toz = torch.stack(
[xc[..., i:i + z_length] for i, xc in zip(i1, z)])

else:
z_diff = x_diff
x_toz = torch.stack([
xc[..., i * ae_ratio:i * ae_ratio + x_length]
for i, xc in zip(i1, x)
])

return {
"x": x_diff,
"x": z_diff,
"x_time_cond": x_diff,
"x_toz": x_toz,
}
Expand All @@ -97,11 +136,17 @@ def collate_fn(L):
0.95 * len(dataset))))

x_length = gin.query_parameter("%X_LENGTH")
ae_ratio = gin.query_parameter("%AE_FACTOR")

def collate_fn(L):
x = np.stack([l["waveform"] for l in L])
x = torch.from_numpy(x).float().reshape((x.shape[0], 1, -1))
if z_precomputed:
x = np.stack([l["z"] for l in L])
x = torch.from_numpy(x).float()
length = x_length // ae_ratio

else:
x = np.stack([l["waveform"] for l in L])
x = torch.from_numpy(x).float().reshape((x.shape[0], 1, -1))
length = x_length

pr = [l["pr"] for l in L]
pr = map(normalize, pr)
Expand All @@ -110,16 +155,23 @@ def collate_fn(L):

i0 = np.random.randint(0, pr.shape[-1] - x_length // ae_ratio,
x.shape[0])
x_diff = torch.stack([
xc[..., i * ae_ratio:i * ae_ratio + x_length]
for i, xc in zip(i0, x)
])

if z_precomputed:
x_diff = torch.stack(
[xc[..., i:i + length] for i, xc in zip(i0, x)])

else:
x_diff = torch.stack([
xc[..., i * ae_ratio:i * ae_ratio + x_length]
for i, xc in zip(i0, x)
])

pr = torch.stack(
[xc[..., i:i + x_length // ae_ratio] for i, xc in zip(i0, pr)])

i1 = np.random.randint(0, x.shape[-1] - x_length, x.shape[0])
i1 = np.random.randint(0, x.shape[-1] - length, x.shape[0])
x_toz = torch.stack(
[xc[..., i:i + x_length] for i, xc in zip(i1, x)])
[xc[..., i:i + length] for i, xc in zip(i1, x)])

return {
"x": x_diff,
Expand Down

0 comments on commit 8fc6b2b

Please sign in to comment.