Commit
clean code
niuzheng168 committed Sep 19, 2024
1 parent bac660b commit 9017387
Showing 5 changed files with 131 additions and 137 deletions.
15 changes: 1 addition & 14 deletions testllama.py
@@ -21,17 +21,4 @@
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
texts.append(generated_text)
end = time.time()
print(f"Time taken: {end - start:.2f}s")
# for text in texts:
# print(text)

# for i in range(5):
# prompts.append(prompts[0])
# prompts.append(prompts[1])

# sampling_params = SamplingParams(temperature=1, top_k=1, max_tokens=100)
# outputs = llm.generate(prompts, sampling_params)
# for output in outputs:
# prompt = output.prompt
# generated_text = output.outputs[0].text
# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Time taken: {end - start:.2f}s")
216 changes: 128 additions & 88 deletions tts_fish.py
@@ -1,83 +1,90 @@
import asyncio
from vllm import LLM, SamplingParams
from tokenizers import Tokenizer
import pypinyin
import torch

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
torch.random.manual_seed(999)
# tts1 = torch.load('/home/zhn/ttslm_dev/GPT_merged_emb_nonorm.pt')
# tts2 = torch.load('/home/zhn/fishtts/checkpoint-1400000.bak')

# layer = 24
# dim = 1536
# num_audio_tokens = 1026
# num_text_tokens = 7002
# llama = tts2['model']['llama']

# llama.pop('freqs_cis')
# llama.pop('causal_mask')

# text_emb = llama['text_embeddings.weight']
# for i in range(100):
# text_emb = torch.cat([text_emb, torch.zeros((1,dim), device=text_emb.device)], 0)
# llama['emb_text.weight'] = text_emb
# llama.pop('text_embeddings.weight')

# llama['emb_code.0.weight'] = llama['code_embeddings.weight'][0:num_audio_tokens]
# llama['emb_code.1.weight'] = llama['code_embeddings.weight'][num_audio_tokens-2:num_audio_tokens - 2 + num_audio_tokens]
# llama.pop('code_embeddings.weight')

# for i in range(layer):
# qkv_name = f'layers.{i}.attention.wqkv.weight'
# q = llama[qkv_name][0:dim]
# k = llama[qkv_name][dim:2*dim]
# v = llama[qkv_name][2*dim:]
# llama[f'gpt.layers.{i}.self_attn.q_proj.weight'] = q
# llama[f'gpt.layers.{i}.self_attn.k_proj.weight'] = k
# llama[f'gpt.layers.{i}.self_attn.v_proj.weight'] = v
# llama.pop(qkv_name)

# wo_name = f'layers.{i}.attention.wo.weight'
# wo = llama[wo_name]
# llama[f'gpt.layers.{i}.self_attn.o_proj.weight'] = wo
# llama.pop(wo_name)

# gate_proj_name = f'layers.{i}.feed_forward.w1.weight'
# w_gate = llama[gate_proj_name]
# llama[f'gpt.layers.{i}.mlp.gate_proj.weight'] = w_gate
# llama.pop(gate_proj_name)

# gate_up_proj_name = f'layers.{i}.feed_forward.w3.weight'
# w_gate_up = llama[gate_up_proj_name]
# llama[f'gpt.layers.{i}.mlp.up_proj.weight'] = w_gate_up
# llama.pop(gate_up_proj_name)

# gate_down_proj_name = f'layers.{i}.feed_forward.w2.weight'
# w_gate_down = llama[gate_down_proj_name]
# llama[f'gpt.layers.{i}.mlp.down_proj.weight'] = w_gate_down
# llama.pop(gate_down_proj_name)

# attn_norm_name = f'layers.{i}.attention_norm.weight'
# w_attn_norm = llama[attn_norm_name]
# llama[f'gpt.layers.{i}.input_layernorm.weight'] = w_attn_norm
# llama.pop(attn_norm_name)

# ffn_norm_name = f'layers.{i}.ffn_norm.weight'
# w_ffn_norm = llama[ffn_norm_name]
# llama[f'gpt.layers.{i}.post_attention_layernorm.weight'] = w_ffn_norm
# llama.pop(ffn_norm_name)


# norm_name = 'norm.weight'
# w_norm = llama[norm_name]
# llama['gpt.norm.weight'] = w_norm
# llama.pop(norm_name)

# output_name = 'output.weight'
# w_output = llama[output_name]
# llama['lm_head.0.weight'] = w_output[num_text_tokens:num_text_tokens+num_audio_tokens]
# llama['lm_head.1.weight'] = w_output[num_text_tokens+num_audio_tokens:num_text_tokens+num_audio_tokens*2]
# llama.pop(output_name)

# torch.save(llama, '/home/zhn/fishtts/llama.pt')

def convert_model():
    tts2 = torch.load('/home/zhn/fishtts/checkpoint-1400000.bak')

    layer = 24
    dim = 1536
    num_audio_tokens = 1026
    num_text_tokens = 7002
    llama = tts2['model']['llama']

    # Drop buffers that are rebuilt at load time.
    llama.pop('freqs_cis')
    llama.pop('causal_mask')

    # Pad the text embedding table with 100 extra rows so ids just above
    # num_text_tokens (e.g. the 7003 marker appended to each prompt below) have embeddings.
    text_emb = llama['text_embeddings.weight']
    for i in range(100):
        text_emb = torch.cat([text_emb, torch.zeros((1, dim), device=text_emb.device)], 0)
    llama['emb_text.weight'] = text_emb
    llama.pop('text_embeddings.weight')

    # Split the fused code embedding table into one table per codebook.
    llama['emb_code.0.weight'] = llama['code_embeddings.weight'][0:num_audio_tokens]
    llama['emb_code.1.weight'] = llama['code_embeddings.weight'][num_audio_tokens-2:num_audio_tokens - 2 + num_audio_tokens]
    llama.pop('code_embeddings.weight')

    for i in range(layer):
        # Split the fused wqkv projection into separate q/k/v projections.
        qkv_name = f'layers.{i}.attention.wqkv.weight'
        q = llama[qkv_name][0:dim]
        k = llama[qkv_name][dim:2*dim]
        v = llama[qkv_name][2*dim:]
        llama[f'gpt.layers.{i}.self_attn.q_proj.weight'] = q
        llama[f'gpt.layers.{i}.self_attn.k_proj.weight'] = k
        llama[f'gpt.layers.{i}.self_attn.v_proj.weight'] = v
        llama.pop(qkv_name)

        wo_name = f'layers.{i}.attention.wo.weight'
        wo = llama[wo_name]
        llama[f'gpt.layers.{i}.self_attn.o_proj.weight'] = wo
        llama.pop(wo_name)

        gate_proj_name = f'layers.{i}.feed_forward.w1.weight'
        w_gate = llama[gate_proj_name]
        llama[f'gpt.layers.{i}.mlp.gate_proj.weight'] = w_gate
        llama.pop(gate_proj_name)

        up_proj_name = f'layers.{i}.feed_forward.w3.weight'
        w_up = llama[up_proj_name]
        llama[f'gpt.layers.{i}.mlp.up_proj.weight'] = w_up
        llama.pop(up_proj_name)

        down_proj_name = f'layers.{i}.feed_forward.w2.weight'
        w_down = llama[down_proj_name]
        llama[f'gpt.layers.{i}.mlp.down_proj.weight'] = w_down
        llama.pop(down_proj_name)

        attn_norm_name = f'layers.{i}.attention_norm.weight'
        w_attn_norm = llama[attn_norm_name]
        llama[f'gpt.layers.{i}.input_layernorm.weight'] = w_attn_norm
        llama.pop(attn_norm_name)

        ffn_norm_name = f'layers.{i}.ffn_norm.weight'
        w_ffn_norm = llama[ffn_norm_name]
        llama[f'gpt.layers.{i}.post_attention_layernorm.weight'] = w_ffn_norm
        llama.pop(ffn_norm_name)

    norm_name = 'norm.weight'
    w_norm = llama[norm_name]
    llama['gpt.norm.weight'] = w_norm
    llama.pop(norm_name)

    # Keep only the audio-token rows of the shared output projection, split
    # into one lm head per codebook; the text-token rows are dropped.
    output_name = 'output.weight'
    w_output = llama[output_name]
    llama['lm_head.0.weight'] = w_output[num_text_tokens:num_text_tokens+num_audio_tokens]
    llama['lm_head.1.weight'] = w_output[num_text_tokens+num_audio_tokens:num_text_tokens+num_audio_tokens*2]
    llama.pop(output_name)

    torch.save(llama, '/home/zhn/fishtts/llama.pt')
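# Illustrative note, not in the original file: convert_model() is a one-time
# offline conversion step; a guarded call like the sketch below (paths are the
# same as above, the guard itself is an assumption) would run it only when the
# converted checkpoint is missing.
# import os
# if not os.path.exists('/home/zhn/fishtts/llama.pt'):
#     convert_model()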

streaming = True

# Chinese test sentences for the TTS model.
texts = [
    '城市霓虹,夜幕低垂,梦想之光,闪烁不已。心向未来,勇往直前,在星空下,奋斗的旋律。',
@@ -96,16 +103,49 @@
token_ids.append(7003)
llm_inputs.append(token_ids)

llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True)
prompts = [
    {"prompt_token_ids": llm_input} for llm_input in llm_inputs
]

sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt)
    token_ids = output.outputs[0].token_ids
    for token_id in token_ids:
        print([x - 0 for x in token_id])
    print(len(token_ids))
if not streaming:
    llm = LLM(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True)
    prompts = [
        {"prompt_token_ids": llm_input} for llm_input in llm_inputs
    ]

    sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], ignore_eos=True, max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16)
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(output.prompt)
        token_ids = output.outputs[0].token_ids
        for token_id in token_ids:
            print(list(token_id))
        print(len(token_ids))

else:
    engine_args = AsyncEngineArgs(model='/home/zhn/fishtts', gpu_memory_utilization=0.5, dtype=torch.float32, skip_tokenizer_init=True)
    model = AsyncLLMEngine.from_engine_args(engine_args)
    sampling_params = SamplingParams(temperature=1, detokenize=False, stop_token_ids=[1025], max_tokens=2048, top_k=1, repetition_penalty=1.5, repetition_window=16)
    prompts = [
        {"prompt_token_ids": llm_input} for llm_input in llm_inputs
    ]

    async def generate_streaming(prompt, id):
        results_generator = model.generate(prompt, sampling_params, request_id=id)
        count = 0
        tokens = []
        async for request_output in results_generator:
            token_ids = request_output.outputs[0].token_ids
            print(f'{id} {list(token_ids[-1])}')
            tokens.append(list(token_ids[-1]))
            count += 1

        print(id)
        print(len(tokens))
        for token in tokens:
            print(token)

    async def generate():
        tasks = []
        for i in range(2):
            t = generate_streaming(prompts[i % 2], i)
            tasks.append(t)
        await asyncio.gather(*tasks)

    asyncio.run(generate())
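
As a side note, here is a minimal sketch, not part of this commit, of how the same AsyncLLMEngine.generate call could hand tokens to a downstream consumer (for example a vocoder task) instead of printing them; the queue hand-off and all names below are assumptions:

import asyncio

async def generate_to_queue(engine, prompt, sampling_params, request_id, queue: asyncio.Queue):
    # Forward each newly generated multi-head token tuple as soon as it arrives.
    results_generator = engine.generate(prompt, sampling_params, request_id=request_id)
    async for request_output in results_generator:
        await queue.put(list(request_output.outputs[0].token_ids[-1]))
    await queue.put(None)  # end-of-stream marker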
2 changes: 1 addition & 1 deletion vllm/model_executor/models/__init__.py
@@ -65,7 +65,7 @@
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
"ChatTtsLlm": ("ttslm", "ChatTtsLlm")
"FishTtsLlm": ("fishtts", "FishTtsLlm")
}
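# Note (illustrative, not part of this diff): the "FishTtsLlm" key above is the
# architecture name vLLM matches against the "architectures" field of the model
# directory's config.json, so /home/zhn/fishtts/config.json is assumed to list
# "FishTtsLlm" for LLM(model='/home/zhn/fishtts') to resolve this class.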

_EMBEDDING_MODELS = {
@@ -45,7 +45,7 @@ def get_max_speech_tokens(ctx: InputContext):
@MULTIMODAL_REGISTRY.register_speech_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ttsllm)
@MULTIMODAL_REGISTRY.register_max_speech_tokens(get_max_speech_tokens)
class ChatTtsLlm(nn.Module):
class FishTtsLlm(nn.Module):
    def __init__(self,
                 config: LlamaConfig,
                 cache_config: Optional[CacheConfig] = None,
Expand All @@ -72,7 +72,6 @@ def __init__(self,
        ])
        self.logits_processor = LogitsProcessor(self.num_audio_tokens)
        self.sampler = MultiheadSampler()
        # self.samplers = [Sampler(head_idx) for head_idx in range(self.num_output_head)]

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
@@ -149,12 +148,6 @@ def sample(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        # head_logits = logits.permute(1, 0, 2)
        # next_tokens = self.samplers[0](head_logits[0], sampling_metadata)
        # for i in range(self.num_output_head - 1):
        #     output = self.samplers[i](head_logits[i + 1], sampling_metadata)
        #     self.merge_sample_results(next_tokens, output)

        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

@@ -173,8 +166,6 @@ def forward(
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids, is_prompt)
        # spk_emb = kwargs.get("speech", None)
        # self.apply_spk_emb(hidden_states, spk_emb, attn_metadata, input_ids)
        model_output = self.gpt(
            input_ids=input_ids,
            inputs_embeds=hidden_states,
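
For intuition, a rough sketch of a greedy multi-head sampling step; the shape and helper name are assumptions, and the real MultiheadSampler in this fork also consumes SamplingMetadata and applies penalties:

import torch

def greedy_multihead_step(head_logits: torch.Tensor) -> list:
    # head_logits assumed shaped [num_output_head, num_audio_tokens] for one
    # sequence; with top_k=1 each head takes its argmax, producing one audio
    # token per codebook per step (the tuples printed by tts_fish.py above).
    return torch.argmax(head_logits, dim=-1).tolist()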
24 changes: 0 additions & 24 deletions vllm/multimodal/speech.py
@@ -18,30 +18,6 @@
import base64
import pickle

class FishSpeechPlugin(MultiModalPlugin):

    def get_data_key(self) -> str:
        return "audio"

    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
        if isinstance(data, str):
            base64_decoded = base64.b64decode(data)
            deserialized_data = pickle.loads(base64_decoded)
            tensor = torch.from_numpy(deserialized_data)
            return MultiModalInputs({"audio": tensor})
        elif isinstance(data, torch.Tensor):
            raise NotImplementedError("Embeddings input is not supported yet")

        raise TypeError(f"Invalid audio type: {type(data)}")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 16

    @staticmethod
    def get_default_audio():
        return 'a'

class SpeechPlugin(MultiModalPlugin):

    def get_data_key(self) -> str:
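
For reference, the removed mapper decodes a base64-encoded, pickled NumPy array; a minimal client-side encoding that mirrors it would look like the sketch below (the array contents and shape are placeholders):

import base64
import pickle
import numpy as np

audio = np.zeros((1, 768), dtype=np.float32)  # placeholder feature array
# Mirror of the decode path above: pickle -> base64 here,
# base64.b64decode -> pickle.loads -> torch.from_numpy on the server side.
encoded = base64.b64encode(pickle.dumps(audio)).decode('utf-8')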
