update gradio tools
parent 7a672bc167
commit d3003b918f
@@ -528,8 +528,6 @@ def generate_lowmem(args):
 if __name__ == "__main__":
     # limit the model's GPU memory usage to 0.6
     torch.cuda.set_per_process_memory_fraction(0.55)
     torch.backends.cudnn.enabled = False
     OmegaConf.register_new_resolver("eval", lambda x: eval(x))
     OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
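For reference, a runnable sketch of what the two resolvers registered above do; the config keys below are hypothetical examples, not taken from this repo:

# Sketch of the "eval" and "concat" resolvers registered in generate_lowmem.
# The keys dim/layers_a/layers_b/all_layers are made-up illustrations.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", lambda x: eval(x))
OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])

cfg = OmegaConf.create({
    "dim": "${eval:'512 * 2'}",          # "eval" runs the string as Python -> 1024
    "layers_a": [1, 2],
    "layers_b": [3, 4],
    "all_layers": "${concat:${layers_a},${layers_b}}",  # flattens the lists -> [1, 2, 3, 4]
})

print(cfg.dim)         # 1024
print(cfg.all_layers)  # [1, 2, 3, 4]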
Binary file not shown.
@@ -6,6 +6,7 @@ import yaml
 import time
+import re
 import os.path as op
 import torch
 from levo_inference_lowmem import LeVoInference
 
 EXAMPLE_LYRICS = """
@@ -98,7 +99,7 @@ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_co
     progress(0.0, "Start Generation")
     start = time.time()
 
-    audio_data = MODEL(lyric_norm, description, prompt_audio, genre, op.join(APP_DIR, "ckpt/prompt.pt"), gen_type, params).cpu().permute(1, 0).float().numpy()
+    audio_data = MODEL(lyric_norm, description, prompt_audio, genre, op.join(APP_DIR, "tools/new_prompt.pt"), gen_type, params).cpu().permute(1, 0).float().numpy()
 
     end = time.time()
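Incidentally, the .cpu().permute(1, 0).float().numpy() chain reshapes the model output from a [channels, samples] tensor into the [samples, channels] numpy layout that audio writers and Gradio expect; a small sketch with a dummy tensor (the shape, dtype, and sample rate are assumptions):

# Sketch of the output conversion used above, on a fake tensor.
# Assumes the model returns [channels, samples] audio in half precision.
import torch

sample_rate = 48000  # assumed sample rate
fake_output = torch.randn(2, sample_rate * 3, dtype=torch.float16)

audio_data = fake_output.cpu().permute(1, 0).float().numpy()
print(audio_data.shape)  # (144000, 2) -> (samples, channels)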
@@ -239,4 +240,5 @@ lyrics
 
 # launch the app
 if __name__ == "__main__":
+    torch.set_num_threads(1)
     demo.launch(server_name="0.0.0.0", server_port=8081)
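A self-contained sketch of this launch pattern, with a placeholder echo interface standing in for the real app:

# Minimal sketch of the launch pattern above; the echo textboxes are
# placeholders, not the actual song-generation UI.
import torch
import gradio as gr

torch.set_num_threads(1)  # keep PyTorch on a single CPU thread

with gr.Blocks() as demo:
    inp = gr.Textbox(label="lyric")
    out = gr.Textbox(label="result")
    inp.submit(lambda s: s, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=8081)  # listen on all interfaces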
@@ -71,11 +71,7 @@ class LeVoInference(torch.nn.Module):
             melody_is_wav = True
         elif genre is not None and auto_prompt_path is not None:
             auto_prompt = torch.load(auto_prompt_path)
-            merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
-            if genre == "Auto":
-                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
-            else:
-                prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
+            prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
             pmt_wav = prompt_token[:,[0],:]
             vocal_wav = prompt_token[:,[1],:]
             bgm_wav = prompt_token[:,[2],:]
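The removed branch pooled prompts from every genre when genre was "Auto"; after the change the genre key is indexed directly. A sketch of the prompt-file layout this code assumes (the genre names, tensor contents, and length are illustrative):

# Sketch of the assumed structure behind torch.load(auto_prompt_path):
# a dict mapping genre -> list of token tensors shaped [1, 3, T], where
# channel 0 is the mixed prompt, 1 the vocal, 2 the bgm (per the slicing above).
import numpy as np
import torch

auto_prompt = {
    "Pop":  [torch.zeros(1, 3, 150, dtype=torch.long) for _ in range(4)],
    "Rock": [torch.zeros(1, 3, 150, dtype=torch.long) for _ in range(4)],
}

genre = "Pop"
prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
pmt_wav   = prompt_token[:, [0], :]
vocal_wav = prompt_token[:, [1], :]
bgm_wav   = prompt_token[:, [2], :]
print(pmt_wav.shape)  # torch.Size([1, 1, 150])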
@@ -66,11 +66,7 @@ class LeVoInference(torch.nn.Module):
             torch.cuda.empty_cache()
         elif genre is not None and auto_prompt_path is not None:
             auto_prompt = torch.load(auto_prompt_path)
-            merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
-            if genre == "Auto":
-                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
-            else:
-                prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
+            prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
             pmt_wav = prompt_token[:,[0],:]
             vocal_wav = prompt_token[:,[1],:]
             bgm_wav = prompt_token[:,[2],:]
@@ -1,3 +1,7 @@
+export OMP_NUM_THREADS=1
+export MKL_NUM_THREADS=1
+export CUDA_LAUNCH_BLOCKING=0
+
 export USER=root
 export PYTHONDONTWRITEBYTECODE=1
 export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
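These new exports pin OpenMP and MKL to one CPU thread before Python starts; achieving the same effect from Python would look roughly like this sketch (the env vars must be set before torch or numpy are imported to take effect):

# Sketch: the same thread caps set from Python instead of the shell.
import os

os.environ["OMP_NUM_THREADS"] = "1"  # OpenMP worker threads
os.environ["MKL_NUM_THREADS"] = "1"  # Intel MKL threads

import torch  # imported after the env vars on purpose

torch.set_num_threads(1)        # intra-op parallelism in PyTorch itself
print(torch.get_num_threads())  # 1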