From 2239666542e431760903d45f2b799b954d813988 Mon Sep 17 00:00:00 2001 From: dirkf Date: Wed, 21 Feb 2024 15:44:04 +0000 Subject: [PATCH] Handle `expected_warnings` better * make fake `report_warning()` method signatures correct (per 640d39f) * support single warning to expect as well as sequence * don't colour text to be matched * use `expected_warnings()` function throughout --- test/helper.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/test/helper.py b/test/helper.py index 5b7e3dfe20b..527ce3a94e3 100644 --- a/test/helper.py +++ b/test/helper.py @@ -20,6 +20,7 @@ from youtube_dl.utils import ( IDENTITY, preferredencoding, + variadic, write_string, ) @@ -66,7 +67,7 @@ def report_warning(message): class FakeYDL(YoutubeDL): def __init__(self, override=None): # Different instances of the downloader can't share the same dictionary - # some test set the "sublang" parameter, which would break the md5 checks. + # some tests set the "sublang" parameter, which would break the md5 checks. params = get_params(override=override) super(FakeYDL, self).__init__(params, auto_init=False) self.result = [] @@ -83,14 +84,7 @@ def download(self, x): def expect_warning(self, regex): # Silence an expected warning matching a regex - old_report_warning = self.report_warning - - def report_warning(self, message): - if re.match(regex, message): - return - old_report_warning(message) - self.report_warning = types.MethodType(report_warning, self) - + expect_warnings(self, regex) class FakeLogger(object): def debug(self, msg): @@ -285,12 +279,14 @@ def assertEqual(self, got, expected, msg=None): def expect_warnings(ydl, warnings_re): real_warning = ydl.report_warning + # to facilitate matching, don't prettify messages + ydl.params['no_color'] = True - def _report_warning(w): - if not any(re.search(w_re, w) for w_re in warnings_re): - real_warning(w) + def _report_warning(self, w, *args, **kwargs): + if not any(re.search(w_re, w) for w_re in variadic(warnings_re)): + real_warning(w, *args, **kwargs) - ydl.report_warning = _report_warning + ydl.report_warning = types.MethodType(_report_warning, ydl) def http_server_port(httpd):