
Commit

VocabParallelEmbedding
Ubuntu committed Jul 16, 2024
1 parent a632185 commit 3e127db
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions vllm/model_executor/models/ttslm.py
@@ -13,7 +13,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead
+from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -46,8 +46,6 @@ def __init__(self,
         super().__init__()

         # static parameters, put them in config later
-        self.spk_emb_dim = 192
-        self.spk_KL = 8
         self.num_audio_tokens = 626
         self.num_text_tokens = 21178
         self.num_vq = 4
@@ -56,10 +56,9 @@
         self.gpt = LlamaModel(config)
         self.model_dim = self.gpt.config.hidden_size
         self.emb_all = nn.ModuleList([
-            nn.Embedding(self.num_audio_tokens + self.num_text_tokens, self.model_dim) for _ in range(self.num_vq)
+            VocabParallelEmbedding(self.num_audio_tokens + self.num_text_tokens, self.model_dim) for _ in range(self.num_vq)
         ])

-        self.head_text = weight_norm(nn.Linear(self.model_dim, self.num_text_tokens, bias=False), name='weight')
         self.lm_head = nn.ModuleList([
             nn.Linear(self.model_dim, self.num_audio_tokens, bias=False) for _ in range(self.num_vq)
         ])
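
For context (an explanatory note, not part of the commit): VocabParallelEmbedding is vLLM's tensor-parallel replacement for nn.Embedding. It shards the embedding table along the vocabulary dimension across tensor-parallel ranks, padding the vocabulary (hence the DEFAULT_VOCAB_PADDING_SIZE import above) so the shards divide evenly. The sketch below illustrates the idea in plain PyTorch; the class name NaiveVocabParallelEmbedding, its constructor arguments, and the explicit rank/world_size parameters are illustrative assumptions, not vLLM's actual API.

    # A minimal, single-process sketch of the vocab-parallel idea --
    # illustrative only, not vLLM's implementation. Each tensor-parallel
    # rank stores a contiguous slice of the vocabulary rows; token ids
    # outside the local slice are masked to zero, and an all-reduce
    # across ranks sums the partial results into the full lookup.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class NaiveVocabParallelEmbedding(nn.Module):  # hypothetical name
        def __init__(self, num_embeddings: int, embedding_dim: int,
                     rank: int = 0, world_size: int = 1):
            super().__init__()
            # vLLM pads the vocab so shards divide evenly; assume it here.
            assert num_embeddings % world_size == 0
            shard = num_embeddings // world_size
            self.start, self.end = rank * shard, (rank + 1) * shard
            self.weight = nn.Parameter(torch.randn(shard, embedding_dim) * 0.02)

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            # Ids owned by this rank map to local rows; the rest are masked.
            mask = (input_ids >= self.start) & (input_ids < self.end)
            local_ids = (input_ids - self.start).clamp_(0, self.end - self.start - 1)
            out = F.embedding(local_ids, self.weight) * mask.unsqueeze(-1)
            # With world_size > 1, a torch.distributed.all_reduce(out) here
            # would combine the per-rank partial embeddings.
            return out

With world_size = 1 this reduces to an ordinary embedding lookup, which is why the swap in this diff preserves single-GPU behavior while enabling vocabulary sharding under tensor parallelism.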
