diff --git a/.github/workflows/pytvzhen-web.yaml b/.github/workflows/pytvzhen-web.yaml index 11981d2..85c82c6 100644 --- a/.github/workflows/pytvzhen-web.yaml +++ b/.github/workflows/pytvzhen-web.yaml @@ -24,3 +24,5 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + - name: Run PytvzhenAPI integration tests + run: python -m unittest -v test_app.PytvzhenAPITest diff --git a/app.py b/app.py index 9ad6163..156962f 100644 --- a/app.py +++ b/app.py @@ -1,4 +1,5 @@ import os +import json from pytube import YouTube from moviepy.editor import VideoFileClip @@ -15,12 +16,7 @@ from werkzeug.utils import secure_filename app = Flask(__name__) - -app_path = os.path.abspath(__file__) -dir_path = os.path.dirname(app_path) -output_path = os.path.join(dir_path, "output") -model_path = os.path.join(dir_path, "models") -baseline_path = os.path.join(model_path, "baseline.pth") +app.config.from_file("./pytvzhen-config.json", load=json.load) def log_info_return_str(message): @@ -55,6 +51,9 @@ def decorated_func(*args, **kwargs): return decorated_func +log_info_return_str(f"Launching Pytvzhen config: \n\t{app.config}") + + @app.route('/', methods=['GET']) def index(): return render_template('index.html') @@ -79,6 +78,8 @@ def unique_video_fn_with_extension(extension): @app.route('/video_upload', methods=['POST']) def video_upload(): + output_path = app.config['OUTPUT_PATH'] + # check if the post request has the file part if 'file' not in request.files: return jsonify(error='No file part in the POST request'), 400 @@ -104,6 +105,8 @@ def video_upload(): @app.route('/yt_download', methods=['POST']) @require_video_id_from_post_request def yt_download(video_id): + output_path = app.config['OUTPUT_PATH'] + video_fn = f"{video_id}.mp4" video_fhd = f"{video_id}_fhd.mp4" video_save_path = os.path.join(output_path, video_fn) @@ -138,6 +141,7 @@ def yt_download(video_id): @app.route('/yt/', methods=['GET']) def yt_serve(video_id): + output_path = 
app.config['OUTPUT_PATH'] video_fn = f'{video_id}.mp4' if os.path.exists(os.path.join(output_path, video_fn)): @@ -149,6 +153,8 @@ def yt_serve(video_id): @app.route('/extra_audio', methods=['POST']) @require_video_id_from_post_request def extra_audio(video_id): + output_path = app.config['OUTPUT_PATH'] + video_fn = f'{video_id}.mp4' audio_fn = f'{video_id}.wav' @@ -177,6 +183,8 @@ def extra_audio(video_id): @app.route('/audio/', methods=['GET']) def audio_serve(video_id): + output_path = app.config['OUTPUT_PATH'] + video_fn = f'{video_id}.mp4' audio_fn = f'{video_id}.wav' @@ -191,6 +199,8 @@ def audio_serve(video_id): @app.route('/remove_audio_bg', methods=['POST']) @require_video_id_from_post_request def remove_audio_bg(video_id): + output_path = app.config['OUTPUT_PATH'] + video_fn = f'{video_id}.mp4' audio_fn = f'{video_id}.wav' audio_no_bg_fn, audio_bg_fn = f'{video_id}_no_bg.wav', f'{video_id}_bg.wav' @@ -211,7 +221,9 @@ def remove_audio_bg(video_id): f'not found at {output_path}, please extract it first')}), 404 try: - audio_remove(audio_path, audio_no_bg_path, audio_bg_fn_path, baseline_path) + baseline_path = app.config['REMOVE_BACKGROUND_MUSIC_BASELINE_MODEL_PATH'] + audio_remove(audio_path, audio_no_bg_path, audio_bg_fn_path, baseline_path, + app.config['REMOVE_BACKGROUND_MUSIC_TORCH_DEVICE']) return jsonify({"message": log_info_return_str( f"Remove remove background music for {audio_fn} as {audio_no_bg_fn} and {audio_bg_fn_path} successfully."), "video_id": video_id}), 200 @@ -225,6 +237,8 @@ def remove_audio_bg(video_id): @app.route('/audio_no_bg/', methods=['GET']) def audio_no_bg_serve(video_id): + output_path = app.config['OUTPUT_PATH'] + audio_no_bg_fn = f'{video_id}_no_bg.wav' audio_no_bg_path = os.path.join(output_path, audio_no_bg_fn) @@ -237,6 +251,8 @@ def audio_no_bg_serve(video_id): @app.route('/audio_bg/', methods=['GET']) def audio_bg_serve(video_id): + output_path = app.config['OUTPUT_PATH'] + audio_bg_fn = f'{video_id}_bg.wav' 
audio_bg_path = os.path.join(output_path, audio_bg_fn) @@ -250,6 +266,8 @@ def audio_bg_serve(video_id): @app.route('/transcribe', methods=['POST']) @require_video_id_from_post_request def transcribe(video_id): + output_path = app.config['OUTPUT_PATH'] + transcribe_model = "medium" en_srt_fn, en_srt_merged_fn, audio_no_bg_fn = f'{video_id}_en.srt', f'{video_id}_en_merged.srt', f'{video_id}_no_bg.wav' @@ -285,6 +303,8 @@ def transcribe(video_id): @app.route('/srt_en/', methods=['GET']) def srt_en_serve(video_id): + output_path = app.config['OUTPUT_PATH'] + en_srt_fn = f'{video_id}_en.srt' en_srt_path = os.path.join(output_path, en_srt_fn) @@ -297,6 +317,8 @@ def srt_en_serve(video_id): @app.route('/srt_en_merged/', methods=['GET']) def srt_en_merged_serve(video_id): + output_path = app.config['OUTPUT_PATH'] + en_srt_merged_fn = f'{video_id}_en_merged.srt' en_srt_merged_path = os.path.join(output_path, en_srt_merged_fn) @@ -310,6 +332,8 @@ def srt_en_merged_serve(video_id): @app.route('/translate_to_zh', methods=['POST']) @require_video_id_from_post_request def transhlate_to_zh(video_id): + output_path = app.config['OUTPUT_PATH'] + data = request.get_json() video_id = data['video_id'] translateVendor = data['translate_vendor'] @@ -375,6 +399,8 @@ def transhlate_to_zh(video_id): @app.route('/srt_zh_merged/', methods=['GET']) def srt_zh_merged_serve(video_id): + output_path = app.config['OUTPUT_PATH'] + zh_srt_merged_fn = f'{video_id}_zh_merged.srt' zh_srt_merged_path = os.path.join(output_path, zh_srt_merged_fn) @@ -388,6 +414,8 @@ def srt_zh_merged_serve(video_id): @app.route('/voice_connect', methods=['POST']) @require_video_id_from_post_request def voice_connect(video_id): + output_path = app.config['OUTPUT_PATH'] + data = request.get_json() video_id = data['video_id'] voiceDir = os.path.join(output_path, video_id + "_zh_source") @@ -411,6 +439,8 @@ def voice_connect(video_id): @app.route('/voice_connect_log/', methods=['GET']) def 
voice_connect_log_serve(video_id): + output_path = app.config['OUTPUT_PATH'] + warning_log_fn = video_id + "_connect_warning.log" warning_log_path = os.path.join(output_path, warning_log_fn) @@ -423,6 +453,8 @@ def voice_connect_log_serve(video_id): @app.route('/voice_connect/', methods=['GET']) def voice_connect_serve(video_id): + output_path = app.config['OUTPUT_PATH'] + voice_connect_fn = video_id + "_zh.wav" voice_connect_path = os.path.join(output_path, voice_connect_fn) @@ -436,6 +468,8 @@ def voice_connect_serve(video_id): @app.route('/tts', methods=['POST']) @require_video_id_from_post_request def tts(video_id): + output_path = app.config['OUTPUT_PATH'] + data = request.get_json() video_id = data['video_id'] srt_fn = f'{video_id}_zh_merged.srt' @@ -467,6 +501,8 @@ def tts(video_id): @app.route('/tts/', methods=['GET']) def tts_serve(video_id): + output_path = app.config['OUTPUT_PATH'] + tts_dir = os.path.join(output_path, video_id + "_zh_source") tts_zip_fn = video_id + "_zh_source.zip" tts_zip_path = os.path.join(output_path, tts_zip_fn) @@ -491,6 +527,8 @@ def tts_serve(video_id): @app.route('/video_preview', methods=['POST']) @require_video_id_from_post_request def video_preview(video_id): + output_path = app.config['OUTPUT_PATH'] + data = request.get_json() video_id = data['video_id'] voice_connect_fn = video_id + "_zh.wav" @@ -533,6 +571,8 @@ def video_preview(video_id): @app.route('/video_preview/', methods=['GET']) def video_preview_serve(video_id): + output_path = app.config['OUTPUT_PATH'] + video_preview_fn = f"{video_id}_preview.mp4" video_preview_path = os.path.join(output_path, video_preview_fn) diff --git a/pytvzhen-config.json b/pytvzhen-config.json new file mode 100644 index 0000000..ab56ee5 --- /dev/null +++ b/pytvzhen-config.json @@ -0,0 +1,5 @@ +{ + "OUTPUT_PATH": "./output", + "REMOVE_BACKGROUND_MUSIC_TORCH_DEVICE": "cpu", + "REMOVE_BACKGROUND_MUSIC_BASELINE_MODEL_PATH": "./models/baseline.pth" +} \ No newline at end of file diff --git 
a/test_app.py b/test_app.py new file mode 100644 index 0000000..6b455f9 --- /dev/null +++ b/test_app.py @@ -0,0 +1,34 @@ +import os +import unittest +import tempfile + +from app import app + + +class PytvzhenAPITest(unittest.TestCase): + + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + app.config['OUTPUT_PATH'] = self.test_dir + app.config['DEBUG'] = True + + self.app = app.test_client() + self.app.testing = True + + def test_download_yt_video_with_valid_video_id(self): + print("download to " + self.test_dir) + response = self.app.post("/yt_download", json={'video_id': 'VwhT-P3pLJs'}) + assert response.status_code == 200 + assert os.path.isfile(os.path.join(self.test_dir, 'VwhT-P3pLJs.mp4')) + + def tearDown(self): + for root, dirs, files in os.walk(self.test_dir, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/audio_remove.py b/tools/audio_remove.py index 5ab8a9e..9f5054d 100644 --- a/tools/audio_remove.py +++ b/tools/audio_remove.py @@ -102,15 +102,14 @@ def separate_tta(self, X_spec): return y_spec, v_spec -def audio_remove(audioFileNameAndPath, voiceFileNameAndPath, instrumentFileNameAndPath, modelNameAndPath): - if AUDIO_REMOVE_DEVICE == "cpu": - device = torch.device('cpu') - elif AUDIO_REMOVE_DEVICE == "gpu": - device = device = torch.device('cuda:0') - else: - raise ValueError("Invalid device: {}".format(AUDIO_REMOVE_DEVICE)) - - print("Loading model " + AUDIO_REMOVE_DEVICE) +def audio_remove(audioFileNameAndPath, voiceFileNameAndPath, instrumentFileNameAndPath, modelNameAndPath, + pytorchDevice): + if pytorchDevice not in ["cpu", "cuda:0"]: + raise ValueError("Invalid device: {}, valid choices are cpu or cuda:0. 
".format(pytorchDevice)) + + device = torch.device(pytorchDevice) + + print("Loading model " + pytorchDevice) model = nets.CascadedNet(AUDIO_REMOVE_FFT_SIZE, AUDIO_REMOVE_HOP_SIZE, 32, 128) # 模型参数 model.load_state_dict(torch.load(modelNameAndPath, map_location='cpu')) model.to(device) diff --git a/work_space.py b/work_space.py index 3ecc158..a9722f8 100644 --- a/work_space.py +++ b/work_space.py @@ -606,7 +606,8 @@ def voiceConnect(logger, sourceDir, outputAndPath, warningFilePath): audioEndPosition = audioPosition + audio.duration_seconds * 1000 + MIN_GAP_DURATION * 1000 audioNextPosition = voiceMapSrt[i + 1].start.total_seconds() * 1000 if audioNextPosition < audioEndPosition: - speedUp = (audio.duration_seconds * 1000 + MIN_GAP_DURATION * 1000) / (audioNextPosition - audioPosition) + speedUp = (audio.duration_seconds * 1000 + MIN_GAP_DURATION * 1000) / ( + audioNextPosition - audioPosition) seconds = audioPosition / 1000.0 timeStr = str(datetime.timedelta(seconds=seconds)) if speedUp > MAX_SPEED_UP: @@ -765,7 +766,8 @@ def voiceConnect(logger, sourceDir, outputAndPath, warningFilePath): if paramDict["audio remove"]: print(f"Removing music from {audioFileNameAndPath} to {voiceNameAndPath} and {insturmentNameAndPath}") try: - audio_remove(audioFileNameAndPath, voiceNameAndPath, insturmentNameAndPath, audioRemoveModelNameAndPath) + audio_remove(audioFileNameAndPath, voiceNameAndPath, insturmentNameAndPath, audioRemoveModelNameAndPath, + "cuda:0") executeLog.write( f"[WORK o] Remove music from {audioFileNameAndPath} to {voiceNameAndPath} and {insturmentNameAndPath} successfully.") except Exception as e: