diff --git a/test.py b/test.py index 0a604c7..78bdada 100644 --- a/test.py +++ b/test.py @@ -104,6 +104,16 @@ def test_sys_pipes(): assert stderr.getvalue() == u"Hi, stdérr\n" +def test_sys_pipes_check(): + # pytest redirects stdout; un-redirect it for the test + with mock.patch('sys.stdout', sys.__stdout__), mock.patch( + 'sys.stderr', sys.__stderr__ + ): + with pytest.raises(ValueError): + with sys_pipes(): + pass + + def test_redirect_everything(): stdout = io.StringIO() stderr = io.StringIO() diff --git a/wurlitzer.py b/wurlitzer.py index 7e7db95..39089a7 100644 --- a/wurlitzer.py +++ b/wurlitzer.py @@ -532,10 +532,29 @@ def __exit__(self, *exc_info): def sys_pipes(encoding=_default_encoding, bufsize=None): """Redirect C-level stdout/stderr to sys.stdout/stderr - This is useful of sys.sdout/stderr are already being forwarded somewhere. + This is useful of sys.sdout/stderr are already being forwarded somewhere, + e.g. in a Jupyter kernel. DO NOT USE THIS if sys.stdout and sys.stderr are not already being forwarded. """ + # check that we aren't forwarding stdout to itself + for name in ("stdout", "stderr"): + stream = getattr(sys, name) + capture_stream = getattr(sys, "__{}__".format(name)) + try: + fd = stream.fileno() + capture_fd = capture_stream.fileno() + except Exception: + # ignore errors - if sys.stdout doesn't need a fileno, + # it's definitely not the original sys.__stdout__ + continue + else: + if fd == capture_fd: + raise ValueError( + "Cannot forward sys.__{0}__ to sys.{0}: they are the same! Maybe you want wurlitzer.pipes()?".format( + name + ) + ) return pipes(sys.stdout, sys.stderr, encoding=encoding, bufsize=bufsize)