Skip to content

Commit

Permalink
Add integration tests for Pytvzhen APIs.
Browse files Browse the repository at this point in the history
  • Loading branch information
HanFa committed May 28, 2024
1 parent 1a91030 commit bd4b836
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 18 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pytvzhen-web.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Run PytvzhenAPI integration tests
run: python -m unittest -v test_app.PytvzhenAPITest
54 changes: 47 additions & 7 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import json

from pytube import YouTube
from moviepy.editor import VideoFileClip
Expand All @@ -15,12 +16,7 @@
from werkzeug.utils import secure_filename

app = Flask(__name__)

app_path = os.path.abspath(__file__)
dir_path = os.path.dirname(app_path)
output_path = os.path.join(dir_path, "output")
model_path = os.path.join(dir_path, "models")
baseline_path = os.path.join(model_path, "baseline.pth")
app.config.from_file("./pytvzhen-config.json", load=json.load)


def log_info_return_str(message):
Expand Down Expand Up @@ -55,6 +51,9 @@ def decorated_func(*args, **kwargs):
return decorated_func


log_info_return_str(f"Launching Pytvzhen config: \n\t{app.config}")


@app.route('/', methods=['GET'])
def index():
return render_template('index.html')
Expand All @@ -79,6 +78,8 @@ def unique_video_fn_with_extension(extension):

@app.route('/video_upload', methods=['POST'])
def video_upload():
output_path = app.config['OUTPUT_PATH']

# check if the post request has the file part
if 'file' not in request.files:
return jsonify(error='No file part in the POST request'), 400
Expand All @@ -104,6 +105,8 @@ def video_upload():
@app.route('/yt_download', methods=['POST'])
@require_video_id_from_post_request
def yt_download(video_id):
output_path = app.config['OUTPUT_PATH']

video_fn = f"{video_id}.mp4"
video_fhd = f"{video_id}_fhd.mp4"
video_save_path = os.path.join(output_path, video_fn)
Expand Down Expand Up @@ -138,6 +141,7 @@ def yt_download(video_id):

@app.route('/yt/<video_id>', methods=['GET'])
def yt_serve(video_id):
output_path = app.config['OUTPUT_PATH']
video_fn = f'{video_id}.mp4'

if os.path.exists(os.path.join(output_path, video_fn)):
Expand All @@ -149,6 +153,8 @@ def yt_serve(video_id):
@app.route('/extra_audio', methods=['POST'])
@require_video_id_from_post_request
def extra_audio(video_id):
output_path = app.config['OUTPUT_PATH']

video_fn = f'{video_id}.mp4'
audio_fn = f'{video_id}.wav'

Expand Down Expand Up @@ -177,6 +183,8 @@ def extra_audio(video_id):

@app.route('/audio/<video_id>', methods=['GET'])
def audio_serve(video_id):
output_path = app.config['OUTPUT_PATH']

video_fn = f'{video_id}.mp4'
audio_fn = f'{video_id}.wav'

Expand All @@ -191,6 +199,8 @@ def audio_serve(video_id):
@app.route('/remove_audio_bg', methods=['POST'])
@require_video_id_from_post_request
def remove_audio_bg(video_id):
output_path = app.config['OUTPUT_PATH']

video_fn = f'{video_id}.mp4'
audio_fn = f'{video_id}.wav'
audio_no_bg_fn, audio_bg_fn = f'{video_id}_no_bg.wav', f'{video_id}_bg.wav'
Expand All @@ -211,7 +221,9 @@ def remove_audio_bg(video_id):
f'not found at {output_path}, please extract it first')}), 404

try:
audio_remove(audio_path, audio_no_bg_path, audio_bg_fn_path, baseline_path)
baseline_path = app.config['REMOVE_BACKGROUND_MUSIC_BASELINE_MODEL_PATH']
audio_remove(audio_path, audio_no_bg_path, audio_bg_fn_path, baseline_path,
app.config['REMOVE_BACKGROUND_MUSIC_TORCH_DEVICE'])
return jsonify({"message": log_info_return_str(
f"Remove remove background music for {audio_fn} as {audio_no_bg_fn} and {audio_bg_fn_path} successfully."),
"video_id": video_id}), 200
Expand All @@ -225,6 +237,8 @@ def remove_audio_bg(video_id):

@app.route('/audio_no_bg/<video_id>', methods=['GET'])
def audio_no_bg_serve(video_id):
output_path = app.config['OUTPUT_PATH']

audio_no_bg_fn = f'{video_id}_no_bg.wav'
audio_no_bg_path = os.path.join(output_path, audio_no_bg_fn)

Expand All @@ -237,6 +251,8 @@ def audio_no_bg_serve(video_id):

@app.route('/audio_bg/<video_id>', methods=['GET'])
def audio_bg_serve(video_id):
output_path = app.config['OUTPUT_PATH']

audio_bg_fn = f'{video_id}_bg.wav'
audio_bg_path = os.path.join(output_path, audio_bg_fn)

Expand All @@ -250,6 +266,8 @@ def audio_bg_serve(video_id):
@app.route('/transcribe', methods=['POST'])
@require_video_id_from_post_request
def transcribe(video_id):
output_path = app.config['OUTPUT_PATH']

transcribe_model = "medium"
en_srt_fn, en_srt_merged_fn, audio_no_bg_fn = f'{video_id}_en.srt', f'{video_id}_en_merged.srt', f'{video_id}_no_bg.wav'

Expand Down Expand Up @@ -285,6 +303,8 @@ def transcribe(video_id):

@app.route('/srt_en/<video_id>', methods=['GET'])
def srt_en_serve(video_id):
output_path = app.config['OUTPUT_PATH']

en_srt_fn = f'{video_id}_en.srt'
en_srt_path = os.path.join(output_path, en_srt_fn)

Expand All @@ -297,6 +317,8 @@ def srt_en_serve(video_id):

@app.route('/srt_en_merged/<video_id>', methods=['GET'])
def srt_en_merged_serve(video_id):
output_path = app.config['OUTPUT_PATH']

en_srt_merged_fn = f'{video_id}_en_merged.srt'
en_srt_merged_path = os.path.join(output_path, en_srt_merged_fn)

Expand All @@ -310,6 +332,8 @@ def srt_en_merged_serve(video_id):
@app.route('/translate_to_zh', methods=['POST'])
@require_video_id_from_post_request
def transhlate_to_zh(video_id):
output_path = app.config['OUTPUT_PATH']

data = request.get_json()
video_id = data['video_id']
translateVendor = data['translate_vendor']
Expand Down Expand Up @@ -375,6 +399,8 @@ def transhlate_to_zh(video_id):

@app.route('/srt_zh_merged/<video_id>', methods=['GET'])
def srt_zh_merged_serve(video_id):
output_path = app.config['OUTPUT_PATH']

zh_srt_merged_fn = f'{video_id}_zh_merged.srt'
zh_srt_merged_path = os.path.join(output_path, zh_srt_merged_fn)

Expand All @@ -388,6 +414,8 @@ def srt_zh_merged_serve(video_id):
@app.route('/voice_connect', methods=['POST'])
@require_video_id_from_post_request
def voice_connect(video_id):
output_path = app.config['OUTPUT_PATH']

data = request.get_json()
video_id = data['video_id']
voiceDir = os.path.join(output_path, video_id + "_zh_source")
Expand All @@ -411,6 +439,8 @@ def voice_connect(video_id):

@app.route('/voice_connect_log/<video_id>', methods=['GET'])
def voice_connect_log_serve(video_id):
output_path = app.config['OUTPUT_PATH']

warning_log_fn = video_id + "_connect_warning.log"
warning_log_path = os.path.join(output_path, warning_log_fn)

Expand All @@ -423,6 +453,8 @@ def voice_connect_log_serve(video_id):

@app.route('/voice_connect/<video_id>', methods=['GET'])
def voice_connect_serve(video_id):
output_path = app.config['OUTPUT_PATH']

voice_connect_fn = video_id + "_zh.wav"
voice_connect_path = os.path.join(output_path, voice_connect_fn)

Expand All @@ -436,6 +468,8 @@ def voice_connect_serve(video_id):
@app.route('/tts', methods=['POST'])
@require_video_id_from_post_request
def tts(video_id):
output_path = app.config['OUTPUT_PATH']

data = request.get_json()
video_id = data['video_id']
srt_fn = f'{video_id}_zh_merged.srt'
Expand Down Expand Up @@ -467,6 +501,8 @@ def tts(video_id):

@app.route('/tts/<video_id>', methods=['GET'])
def tts_serve(video_id):
output_path = app.config['OUTPUT_PATH']

tts_dir = os.path.join(output_path, video_id + "_zh_source")
tts_zip_fn = video_id + "_zh_source.zip"
tts_zip_path = os.path.join(output_path, tts_zip_fn)
Expand All @@ -491,6 +527,8 @@ def tts_serve(video_id):
@app.route('/video_preview', methods=['POST'])
@require_video_id_from_post_request
def video_preview(video_id):
output_path = app.config['OUTPUT_PATH']

data = request.get_json()
video_id = data['video_id']
voice_connect_fn = video_id + "_zh.wav"
Expand Down Expand Up @@ -533,6 +571,8 @@ def video_preview(video_id):

@app.route('/video_preview/<video_id>', methods=['GET'])
def video_preview_serve(video_id):
output_path = app.config['OUTPUT_PATH']

video_preview_fn = f"{video_id}_preview.mp4"
video_preview_path = os.path.join(output_path, video_preview_fn)

Expand Down
5 changes: 5 additions & 0 deletions pytvzhen-config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"OUTPUT_PATH": "./output",
"REMOVE_BACKGROUND_MUSIC_TORCH_DEVICE": "cpu",
"REMOVE_BACKGROUND_MUSIC_BASELINE_MODEL_PATH": "./models/baseline.pth"
}
34 changes: 34 additions & 0 deletions test_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
import unittest
import tempfile

from app import app


class PytvzhenAPITest(unittest.TestCase):

def setUp(self):
self.test_dir = tempfile.mkdtemp()

app.config['OUTPUT_PATH'] = self.test_dir
app.config['DEBUG'] = True

self.app = app.test_client()
self.app.testing = True

def test_download_yt_video_with_valid_video_id(self):
print("download to " + self.test_dir)
response = self.app.post("/yt_download", json={'video_id': 'VwhT-P3pLJs'})
assert response.status_code == 200
assert os.path.isfile(os.path.join(self.test_dir, 'VwhT-P3pLJs.mp4'))

def tearDown(self):
for root, dirs, files in os.walk(self.test_dir, topdown=False):
for name in files:
os.remove(os.path.join(root, name))
for name in dirs:
os.rmdir(os.path.join(root, name))


if __name__ == '__main__':
unittest.main()
17 changes: 8 additions & 9 deletions tools/audio_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,14 @@ def separate_tta(self, X_spec):
return y_spec, v_spec


def audio_remove(audioFileNameAndPath, voiceFileNameAndPath, instrumentFileNameAndPath, modelNameAndPath):
if AUDIO_REMOVE_DEVICE == "cpu":
device = torch.device('cpu')
elif AUDIO_REMOVE_DEVICE == "gpu":
device = device = torch.device('cuda:0')
else:
raise ValueError("Invalid device: {}".format(AUDIO_REMOVE_DEVICE))

print("Loading model " + AUDIO_REMOVE_DEVICE)
def audio_remove(audioFileNameAndPath, voiceFileNameAndPath, instrumentFileNameAndPath, modelNameAndPath,
pytorchDevice):
if pytorchDevice not in ["cpu", "cuda:0"]:
raise ValueError("Invalid device: {}, valid choices are cpu or cuda:0. ".format(AUDIO_REMOVE_DEVICE))

device = torch.device(pytorchDevice)

print("Loading model " + pytorchDevice)
model = nets.CascadedNet(AUDIO_REMOVE_FFT_SIZE, AUDIO_REMOVE_HOP_SIZE, 32, 128) # 模型参数
model.load_state_dict(torch.load(modelNameAndPath, map_location='cpu'))
model.to(device)
Expand Down
6 changes: 4 additions & 2 deletions work_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,8 @@ def voiceConnect(logger, sourceDir, outputAndPath, warningFilePath):
audioEndPosition = audioPosition + audio.duration_seconds * 1000 + MIN_GAP_DURATION * 1000
audioNextPosition = voiceMapSrt[i + 1].start.total_seconds() * 1000
if audioNextPosition < audioEndPosition:
speedUp = (audio.duration_seconds * 1000 + MIN_GAP_DURATION * 1000) / (audioNextPosition - audioPosition)
speedUp = (audio.duration_seconds * 1000 + MIN_GAP_DURATION * 1000) / (
audioNextPosition - audioPosition)
seconds = audioPosition / 1000.0
timeStr = str(datetime.timedelta(seconds=seconds))
if speedUp > MAX_SPEED_UP:
Expand Down Expand Up @@ -765,7 +766,8 @@ def voiceConnect(logger, sourceDir, outputAndPath, warningFilePath):
if paramDict["audio remove"]:
print(f"Removing music from {audioFileNameAndPath} to {voiceNameAndPath} and {insturmentNameAndPath}")
try:
audio_remove(audioFileNameAndPath, voiceNameAndPath, insturmentNameAndPath, audioRemoveModelNameAndPath)
audio_remove(audioFileNameAndPath, voiceNameAndPath, insturmentNameAndPath, audioRemoveModelNameAndPath,
"cuda:0")
executeLog.write(
f"[WORK o] Remove music from {audioFileNameAndPath} to {voiceNameAndPath} and {insturmentNameAndPath} successfully.")
except Exception as e:
Expand Down

0 comments on commit bd4b836

Please sign in to comment.