Predict Moshi User Stream #188

Open · 1 task done · jlian2 opened this issue Jan 20, 2025 · 2 comments
Labels
question Further information is requested

Comments

jlian2 commented Jan 20, 2025

Due diligence

  • I have done my due diligence in trying to find the answer myself.

Topic

The PyTorch implementation

Question


I have a question about predicting the user stream. I think we can simply predict 8 more codes in `lm.py`:

```python
def depformer_step(
    self,
    text_token: torch.Tensor,
    transformer_out: torch.Tensor,  # [1, 1, 4096]
    generate_user: bool = False,  # let us add this
) -> torch.Tensor:
    (B,) = text_token.shape
    prev_token = text_token  # so here we can use any random token if not using text
    # prev_token = 0 * text_token  # here we get rid of text
    lm_model = self.lm_model
    depformer_tokens: list[torch.Tensor] = []
    user_tokens: list[torch.Tensor] = []
    assert not lm_model.depformer.is_streaming

    # print(f"text_token is {text_token}")  # discrete token
    # this is audio token generation
    # print(f"dep_q is {lm_model.dep_q}")  # 8

    with lm_model.depformer.streaming(B):
        assert lm_model.depformer.is_streaming
        for cb_index in range(lm_model.dep_q):  # 8, autoregressive modeling
            input_ = prev_token[:, None, None]
            logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
            # print(logits.shape)  # 2048
            next_token = sample_token(
                logits.float(),
                self.use_sampling,
                self.temp,
                self.top_k,
            )
            assert next_token.shape == (B, 1, 1)
            next_token = next_token[:, 0, 0]  # shape is B
            depformer_tokens.append(next_token)
            prev_token = next_token

        # if generate_user:
        #     for cb_index in range(lm_model.dep_q, 2 * lm_model.dep_q):
        #     # for cb_index in range(lm_model.dep_q):  # we guess it uses mod
        #         input_ = prev_token[:, None, None]
        #         logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
        #         next_token = sample_token(
        #             logits.float(),
        #             self.use_sampling,
        #             self.temp,
        #             self.top_k,
        #         )
        #         assert next_token.shape == (B, 1, 1)
        #         next_token = next_token[:, 0, 0]  # shape is B
        #         user_tokens.append(next_token)
        #         prev_token = next_token
```

However, there seem to be bugs, and the pretrained model weights only support 8 codebooks. I was wondering if there are any tricks I can apply to predict 8 codes for the user stream as well, in order to perform offline evaluation.
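For illustration, here is roughly what I imagine the extended loop would look like if a checkpoint with a `2 * dep_q`-slice depformer existed (slices 0–7 for the agent stream, 8–15 for the user stream). This is a sketch under that assumption only; with the released 8-slice weights, `forward_depformer` will fail for `cb_index >= 8`:

```python
# Hypothetical sketch: assumes a depformer trained with 2 * dep_q
# slices (0..7 agent, 8..15 user). The released checkpoints only
# expose dep_q = 8, so cb_index >= 8 does not exist there.
with lm_model.depformer.streaming(B):
    for cb_index in range(2 * lm_model.dep_q):
        input_ = prev_token[:, None, None]
        logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
        next_token = sample_token(
            logits.float(), self.use_sampling, self.temp, self.top_k
        )
        next_token = next_token[:, 0, 0]
        # the first dep_q codes belong to the agent stream, the rest to the user
        (depformer_tokens if cb_index < lm_model.dep_q else user_tokens).append(next_token)
        prev_token = next_token
```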

jlian2 commented Jan 20, 2025

BTW, I checked that `dep_q` in the released model checkpoints is 8 instead of 16. So the current model is not able to predict the user stream, right? Meaning offline prediction is currently unavailable?
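For anyone who wants to double-check this, the value can be read off the checkpoint config; a minimal sketch, assuming the Hugging Face repo ships a `config.json` with a `dep_q` field (the repo id and file layout here are assumptions, not verified):

```python
import json
from huggingface_hub import hf_hub_download

# Assumption: the checkpoint repo contains a config.json with a
# "dep_q" entry; adjust the repo id / filename to the actual release.
config_path = hf_hub_download("kyutai/moshiko-pytorch-bf16", "config.json")
with open(config_path) as f:
    config = json.load(f)
print(config.get("dep_q"))  # prints 8 for the released weights, not 16
```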

slavivo commented Feb 23, 2025

I would also like to know if that is the case. If this functionality isn't available, is it still possible to perform the experiments in section 5.4 by treating the audio as the system stream? Thank you.
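To make the question concrete, what I have in mind is roughly the following teacher-forced scoring, where `mimi` and `lm` are hypothetical stand-ins for the codec and the temporal transformer rather than the actual moshi API:

```python
import torch

# Rough sketch, not the actual moshi API: teacher-force pre-recorded
# audio as the system stream and measure per-token negative
# log-likelihood instead of sampling new audio.
@torch.no_grad()
def score_audio_as_system_stream(mimi, lm, wav: torch.Tensor) -> torch.Tensor:
    codes = mimi.encode(wav)   # assumed shape [B, 8, T] audio tokens
    logits = lm(codes)         # assumed shape [B, 8, T, card] next-token logits
    logp = torch.log_softmax(logits[:, :, :-1], dim=-1)
    targets = codes[:, :, 1:].unsqueeze(-1)
    nll = -logp.gather(-1, targets).squeeze(-1)
    return nll.mean()          # average NLL over codebooks and time
```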
