Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Test] Improve download test #29944

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
21 changes: 9 additions & 12 deletions test/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from youtube_dl.utils import (
IDENTITY,
preferredencoding,
variadic,
write_string,
)

Expand Down Expand Up @@ -66,7 +67,7 @@ def report_warning(message):
class FakeYDL(YoutubeDL):
def __init__(self, override=None):
# Different instances of the downloader can't share the same dictionary
# some test set the "sublang" parameter, which would break the md5 checks.
# some tests set the "sublang" parameter, which would break the md5 checks.
params = get_params(override=override)
super(FakeYDL, self).__init__(params, auto_init=False)
self.result = []
Expand All @@ -83,13 +84,7 @@ def download(self, x):

def expect_warning(self, regex):
# Silence an expected warning matching a regex
old_report_warning = self.report_warning

def report_warning(self, message):
if re.match(regex, message):
return
old_report_warning(message)
self.report_warning = types.MethodType(report_warning, self)
expect_warnings(self, regex)


class FakeLogger(object):
Expand Down Expand Up @@ -285,12 +280,14 @@ def assertEqual(self, got, expected, msg=None):

def expect_warnings(ydl, warnings_re):
real_warning = ydl.report_warning
# to facilitate matching, don't prettify messages
ydl.params['no_color'] = True

def _report_warning(w):
if not any(re.search(w_re, w) for w_re in warnings_re):
real_warning(w)
def _report_warning(self, w, *args, **kwargs):
if not any(re.search(w_re, w) for w_re in variadic(warnings_re)):
real_warning(w, *args, **kwargs)

ydl.report_warning = _report_warning
ydl.report_warning = types.MethodType(_report_warning, ydl)


def http_server_port(httpd):
Expand Down
82 changes: 67 additions & 15 deletions test/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@


import hashlib
import itertools
import json
import socket
import re

import youtube_dl.YoutubeDL
from youtube_dl.compat import (
dirkf marked this conversation as resolved.
Show resolved Hide resolved
compat_filter as filter,
compat_http_client,
compat_HTTPError,
compat_map as map,
compat_open as open,
compat_urllib_error,
)
Expand All @@ -35,9 +39,11 @@
ExtractorError,
error_to_compat_str,
format_bytes,
std_headers,
UnavailableVideoError,
)
from youtube_dl.extractor import get_info_extractor
from youtube_dl.downloader.common import FileDownloader

RETRIES = 3

Expand All @@ -48,7 +54,7 @@ def __init__(self, *args, **kwargs):
self.processed_info_dicts = []
super(YoutubeDL, self).__init__(*args, **kwargs)

def report_warning(self, message):
def report_warning(self, message, *args, **kwargs):
# Don't accept warnings during tests
raise ExtractorError(message)

Expand All @@ -57,9 +63,10 @@ def process_info(self, info_dict):
return super(YoutubeDL, self).process_info(info_dict)


def _file_md5(fn):
def _file_md5(fn, length=None):
with open(fn, 'rb') as f:
return hashlib.md5(f.read()).hexdigest()
return hashlib.md5(
f.read() if length is None else f.read(length)).hexdigest()


defs = gettestcases()
Expand All @@ -84,6 +91,13 @@ def strclass(cls):
strclass(self.__class__),
' [%s]' % add_ie if add_ie else '')

@classmethod
def addTest(cls, test_method, test_method_name, add_ie):
test_method.__name__ = str(test_method_name)
test_method.add_ie = add_ie
setattr(cls, test_method.__name__, test_method)
del test_method

def setUp(self):
self.defs = defs

Expand Down Expand Up @@ -125,6 +139,17 @@ def print_skipping(reason):
params.setdefault('playlistend', test_case.get('playlist_mincount'))
params.setdefault('skip_download', True)

if 'user_agent' in params:
std_headers['User-Agent'] = params['user_agent']

if 'referer' in params:
std_headers['Referer'] = params['referer']

for h in params.get('headers', []):
h = h.split(':', 1)
if len(h) > 1:
std_headers[h[0]] = h[1]

ydl = YoutubeDL(params, auto_init=False)
ydl.add_default_info_extractors()
finished_hook_called = set()
Expand All @@ -151,8 +176,7 @@ def try_rm_tcs_files(tcs=None):

try_rm_tcs_files()
try:
try_num = 1
while True:
for try_num in itertools.count(1):
try:
# We're not using .download here since that is just a shim
# for outside error handling, and returns the exit code
Expand All @@ -161,7 +185,7 @@ def try_rm_tcs_files(tcs=None):
test_case['url'],
force_generic_extractor=params.get('force_generic_extractor', False))
except (DownloadError, ExtractorError) as err:
# Check if the exception is not a network related one
# Retry, or raise if the exception is not network-related
if not err.exc_info[0] in (compat_urllib_error.URLError, socket.timeout, UnavailableVideoError, compat_http_client.BadStatusLine) or (err.exc_info[0] == compat_HTTPError and err.exc_info[1].code == 503):
msg = getattr(err, 'msg', error_to_compat_str(err))
err.msg = '%s (%s)' % (msg, tname, )
Expand All @@ -172,8 +196,6 @@ def try_rm_tcs_files(tcs=None):
return

print('Retrying: {0} failed tries\n\n##########\n\n'.format(try_num))

try_num += 1
else:
break

Expand Down Expand Up @@ -237,7 +259,7 @@ def try_rm_tcs_files(tcs=None):
(tc_filename, format_bytes(expected_minsize),
format_bytes(got_fsize)))
if 'md5' in tc:
md5_for_file = _file_md5(tc_filename)
md5_for_file = _file_md5(tc_filename) if not params.get('test') else _file_md5(tc_filename, FileDownloader._TEST_FILE_SIZE)
dirkf marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(tc['md5'], md5_for_file)
# Finally, check test cases' data again but this time against
# extracted data from info JSON file written during processing
Expand Down Expand Up @@ -267,12 +289,42 @@ def try_rm_tcs_files(tcs=None):
tname = 'test_%s_%d' % (test_case['name'], i)
i += 1
test_method = generator(test_case, tname)
test_method.__name__ = str(tname)
ie_list = test_case.get('add_ie')
test_method.add_ie = ie_list and ','.join(ie_list)
setattr(TestDownload, test_method.__name__, test_method)
del test_method

ie_list = ','.join(test_case.get('add_ie', []))
TestDownload.addTest(test_method, tname, ie_list)


def tests_for_ie(ie_key):
return filter(
lambda a: callable(getattr(TestDownload, a, None)),
filter(lambda a: re.match(r'test_%s(?:_\d+)?$' % ie_key, a),
dir(TestDownload)))


def gen_test_suite(ie_key):
def test_all(self):
print(self)
suite = unittest.TestSuite(
map(TestDownload, tests_for_ie(ie_key)))
result = self.defaultTestResult()
suite.run(result)
print('Errors: %d\t Failures: %d\tSkipped: %d' %
tuple(map(len, (result.errors, result.failures, result.skipped))))
print('Expected failures: %d\tUnexpected successes: %d' %
tuple(map(len, (result.expectedFailures, result.unexpectedSuccesses))))
return result

return test_all


for ie_key in set(
map(lambda a: a[5:],
filter(
lambda x: callable(getattr(TestDownload, x, None)),
filter(
lambda t: re.match(r"test_.+(?<!(?:_all|.._\d|._\d\d|_\d\d\d))$", t),
dir(TestDownload))))):
test_all = gen_test_suite(ie_key)
TestDownload.addTest(test_all, 'test_%s_all' % ie_key, 'Test all: %s' % ie_key)

if __name__ == '__main__':
unittest.main()
Loading