diff --git a/generate.py b/generate.py
index 9a47d4f..876c34a 100755
--- a/generate.py
+++ b/generate.py
@@ -528,8 +528,6 @@ def generate_lowmem(args):
 
 
 if __name__ == "__main__":
-    # Limit the model's GPU memory usage to 0.6
-    torch.cuda.set_per_process_memory_fraction(0.55)
     torch.backends.cudnn.enabled = False
     OmegaConf.register_new_resolver("eval", lambda x: eval(x))
     OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
diff --git a/tmp/sample_prompt_audio_vocal.flac b/tmp/sample_prompt_audio_vocal.flac
deleted file mode 100644
index 6630678..0000000
Binary files a/tmp/sample_prompt_audio_vocal.flac and /dev/null differ
diff --git a/tools/gradio/app.py b/tools/gradio/app.py
index 76b976b..dc8eedc 100644
--- a/tools/gradio/app.py
+++ b/tools/gradio/app.py
@@ -6,6 +6,7 @@ import yaml
 import time
 import re
 import os.path as op
+import torch
 from levo_inference_lowmem import LeVoInference
 
 EXAMPLE_LYRICS = """
@@ -98,7 +99,7 @@ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_co
     progress(0.0, "Start Generation")
     start = time.time()
 
-    audio_data = MODEL(lyric_norm, description, prompt_audio, genre, op.join(APP_DIR, "ckpt/prompt.pt"), gen_type, params).cpu().permute(1, 0).float().numpy()
+    audio_data = MODEL(lyric_norm, description, prompt_audio, genre, op.join(APP_DIR, "tools/new_prompt.pt"), gen_type, params).cpu().permute(1, 0).float().numpy()
 
     end = time.time()
@@ -239,4 +240,5 @@ lyrics
 
 # Launch the app
 if __name__ == "__main__":
+    torch.set_num_threads(1)
     demo.launch(server_name="0.0.0.0", server_port=8081)
diff --git a/tools/gradio/levo_inference.py b/tools/gradio/levo_inference.py
index 87d5241..4699cb0 100644
--- a/tools/gradio/levo_inference.py
+++ b/tools/gradio/levo_inference.py
@@ -71,11 +71,7 @@ class LeVoInference(torch.nn.Module):
             melody_is_wav = True
         elif genre is not None and auto_prompt_path is not None:
             auto_prompt = torch.load(auto_prompt_path)
-            merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
-            if genre == "Auto":
-                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
-            else:
-                prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
+            prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
             pmt_wav = prompt_token[:,[0],:]
             vocal_wav = prompt_token[:,[1],:]
             bgm_wav = prompt_token[:,[2],:]
diff --git a/tools/gradio/levo_inference_lowmem.py b/tools/gradio/levo_inference_lowmem.py
index e97f131..e59148b 100644
--- a/tools/gradio/levo_inference_lowmem.py
+++ b/tools/gradio/levo_inference_lowmem.py
@@ -66,11 +66,7 @@ class LeVoInference(torch.nn.Module):
             torch.cuda.empty_cache()
         elif genre is not None and auto_prompt_path is not None:
            auto_prompt = torch.load(auto_prompt_path)
-            merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
-            if genre == "Auto":
-                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
-            else:
-                prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
+            prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
             pmt_wav = prompt_token[:,[0],:]
             vocal_wav = prompt_token[:,[1],:]
             bgm_wav = prompt_token[:,[2],:]
diff --git a/tools/gradio/run.sh b/tools/gradio/run.sh
index 6961792..a90e62e 100644
--- a/tools/gradio/run.sh
+++ b/tools/gradio/run.sh
@@ -1,3 +1,7 @@
+export OMP_NUM_THREADS=1
+export MKL_NUM_THREADS=1
+export CUDA_LAUNCH_BLOCKING=0
+
 export USER=root
 export PYTHONDONTWRITEBYTECODE=1
 export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
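
Note on the prompt-selection change: this patch drops the "Auto" genre fallback from both LeVoInference classes, so `auto_prompt[genre]` now raises a KeyError when the caller passes genre == "Auto" (or any key absent from the new prompt file). Below is a minimal sketch of a guarded lookup that keeps the old cross-genre random sampling as a fallback; `pick_prompt_token` is a hypothetical helper, and the genre -> list-of-tensors layout of the loaded dict is assumed from how the patched code indexes `auto_prompt`.

import numpy as np
import torch

def pick_prompt_token(auto_prompt_path, genre):
    # Assumed layout: auto_prompt maps a genre name to a list of
    # prompt-token tensors, matching how levo_inference.py indexes it.
    auto_prompt = torch.load(auto_prompt_path)
    tokens = auto_prompt.get(genre)
    if tokens is None:
        # Fall back to sampling across every genre -- the behaviour
        # this patch removes for genre == "Auto".
        tokens = [t for ts in auto_prompt.values() for t in ts]
    return tokens[np.random.randint(0, len(tokens))]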