diff --git a/ceci/utils.py b/ceci/utils.py index e039a40..21a74ea 100644 --- a/ceci/utils.py +++ b/ceci/utils.py @@ -3,23 +3,36 @@ @contextmanager def extra_paths(paths, start=True): + # allow passing a single path or + # a list of them if isinstance(paths, str): paths = paths.split() + # On enter, add paths to sys.path, + # either the start or the end depending + # on the start argument for path in paths: if start: sys.path.insert(0, path) else: sys.path.append(path) + # Return control to caller try: yield + # On exit, remove the paths finally: for path in paths: - if start: - sys.path.remove(path) - else: - remove_last(sys.path, path) + try: + if start: + sys.path.remove(path) + else: + remove_last(sys.path, path) + # If e.g. user has already done this + # manually for some reason then just + # skip + except ValueError: + pass def remove_last(lst, item): """ diff --git a/tests/test_python_paths.py b/tests/test_python_paths.py index 0ca2e8f..d8f4a93 100644 --- a/tests/test_python_paths.py +++ b/tests/test_python_paths.py @@ -67,7 +67,8 @@ def test_extra_paths(): assert sys.path[0] == p[1] assert sys.path[1] == p[0] - assert p not in sys.path + for p1 in p: + assert p1 not in sys.path assert sys.path == orig_path try: @@ -78,7 +79,8 @@ def test_extra_paths(): except MyError: pass - assert p not in sys.path + for p1 in p: + assert p1 not in sys.path assert sys.path == orig_path @@ -89,7 +91,8 @@ def test_extra_paths(): assert sys.path[-1] == p[1] assert sys.path[-2] == p[0] - assert p not in sys.path + for p1 in p: + assert p1 not in sys.path assert sys.path == orig_path try: @@ -103,4 +106,22 @@ def test_extra_paths(): assert p not in sys.path assert sys.path == orig_path + # check that if the user removes the path + # themselves then it is okay + p = ['xxx111yyy222', 'aaa222333'] + with extra_paths(p, start=True): + sys.path.remove('xxx111yyy222') + + assert sys.path == orig_path + # check only one copy is removed + sys.path.append("aaa") + tmp_paths = sys.path[:] + p = "aaa" + with extra_paths(p, start=True): + pass + + assert sys.path == tmp_paths + + with extra_paths(p, start=False): + pass