Due diligence
I have done my due diligence in trying to find the answer myself.
Topic
The PyTorch implementation
Question
I have a question about predicting the user stream. I think we can simply predict 8 more codes in lm.py:
```python
    def depformer_step(
        self,
        text_token: torch.Tensor,
        transformer_out: torch.Tensor,  # [1, 1, 4096]
        generate_user: bool = False,  # proposed new flag
    ) -> torch.Tensor:
        (B,) = text_token.shape
        prev_token = text_token  # could be any token if not conditioning on text
        # prev_token = 0 * text_token  # variant: drop the text conditioning
        lm_model = self.lm_model
        depformer_tokens: list[torch.Tensor] = []
        user_tokens: list[torch.Tensor] = []
        assert not lm_model.depformer.is_streaming
        # Audio token generation: lm_model.dep_q == 8 codebooks,
        # sampled autoregressively, each conditioned on the previous one.
        with lm_model.depformer.streaming(B):
            assert lm_model.depformer.is_streaming
            for cb_index in range(lm_model.dep_q):
                input_ = prev_token[:, None, None]
                logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
                next_token = sample_token(
                    logits.float(),
                    self.use_sampling,
                    self.temp,
                    self.top_k,
                )
                assert next_token.shape == (B, 1, 1)
                next_token = next_token[:, 0, 0]  # shape is (B,)
                depformer_tokens.append(next_token)
                prev_token = next_token
            # Proposed extension: sample 8 more codes for the user stream.
            # if generate_user:
            #     for cb_index in range(lm_model.dep_q, 2 * lm_model.dep_q):
            #     # for cb_index in range(lm_model.dep_q):  # or reuse the same heads?
            #         input_ = prev_token[:, None, None]
            #         logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
            #         next_token = sample_token(
            #             logits.float(),
            #             self.use_sampling,
            #             self.temp,
            #             self.top_k,
            #         )
            #         assert next_token.shape == (B, 1, 1)
            #         next_token = next_token[:, 0, 0]  # shape is (B,)
            #         user_tokens.append(next_token)
            #         prev_token = next_token
```
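To make the structural limitation concrete, here is a minimal toy sketch of the per-codebook autoregressive loop above. It is not the Moshi API: `toy_depformer_step` and `heads` are stand-ins I invented for `forward_depformer` and the per-codebook output projections. The point is that each step conditions only on the previously sampled code, and there are exactly `dep_q` heads to query, so a pretrained model with `dep_q == 8` cannot produce more than 8 codes per step.

```python
import torch

def toy_depformer_step(prev_token, transformer_out, heads, dep_q):
    """Toy sketch of a depformer step: sample one code per codebook,
    feeding each sampled code back as the next step's input.
    `heads` is a list of dep_q linear layers standing in for the real
    per-codebook projections (names here are illustrative)."""
    tokens = []
    for cb_index in range(dep_q):
        # Condition on the main transformer output plus the previous code.
        logits = heads[cb_index](transformer_out) + prev_token.float()
        next_token = logits.argmax(dim=-1)  # greedy stand-in for sample_token
        tokens.append(next_token)
        prev_token = next_token
    return torch.stack(tokens, dim=-1)

dep_q, card, dim = 8, 16, 4
torch.manual_seed(0)
heads = [torch.nn.Linear(dim, card) for _ in range(dep_q)]
out = toy_depformer_step(torch.zeros(1, dtype=torch.long),
                         torch.randn(1, dim), heads, dep_q)
print(tuple(out.shape))  # (1, 8)
```

Extending the loop to `2 * lm_model.dep_q` would index heads that simply do not exist in the released weights, which is consistent with the errors described below.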
However, there seem to be bugs, and the pretrained model weights only support 8 codes. I was wondering if there are any tricks I could apply to predict 8 codes for the user stream as well, in order to perform offline evaluation.
BTW, I checked that dep_q in the released model checkpoints is 8 instead of 16. So the current model is not able to predict the user stream, right? Meaning that offline prediction is currently unavailable?
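One way to sanity-check this directly from a checkpoint is to count the distinct per-codebook indices in the state dict. The key pattern below (`depformer_in.{i}.`) is an assumption about the parameter naming and may need adjusting to match the actual Moshi checkpoint layout; the state dict here is a fake one for illustration.

```python
import re

def count_depformer_codebooks(state_dict, pattern=r"depformer_in\.(\d+)\."):
    """Count distinct per-codebook indices among state-dict keys.
    The default key pattern is a guess and may differ in real checkpoints."""
    indices = set()
    for key in state_dict:
        m = re.search(pattern, key)
        if m:
            indices.add(int(m.group(1)))
    return len(indices)

# Illustrative fake state dict with 8 per-codebook input projections:
fake_sd = {f"depformer_in.{i}.weight": None for i in range(8)}
fake_sd["transformer.layers.0.weight"] = None
print(count_depformer_codebooks(fake_sd))  # 8
```

If the real checkpoint only yields 8 indices, then there are no extra heads for a second (user) stream, matching the dep_q == 8 observation above.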
I would also like to know if that is the case. If this functionality isn't available, is it still possible to perform the experiments in section 5.4 by treating the audio as the system stream? Thank you.