-
Notifications
You must be signed in to change notification settings - Fork 620
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added the decorator for patching the environment variable allowing mp…
…s to fallback to CPU
- Loading branch information
Showing
16 changed files
with
140 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .envy_patching import monkeypatch_env | ||
|
||
__all__ = ["monkeypatch_env"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from functools import wraps | ||
|
||
import pytest | ||
|
||
|
||
def monkeypatch_env(key, value): | ||
"""Decorator to monkeypatch environment variables in tests. | ||
Parameters | ||
---------- | ||
key : str | ||
The environment variable key. | ||
value : str | ||
The environment variable value. | ||
Returns | ||
------- | ||
decorator | ||
The decorator function that will monkeypatch the environment variable. | ||
Examples | ||
-------- | ||
import os | ||
import pytest | ||
from helpers import monkeypatch_env | ||
@monkeypatch_env("PYTORCH_ENABLE_MPS_FALLBACK", "1") | ||
def test_get_lstm_cell(): | ||
assert os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") == "1" | ||
""" | ||
|
||
def decorator(test_func): | ||
@wraps(test_func) | ||
def wrapper(*args, **kwargs): | ||
@pytest.fixture | ||
def _monkeypatch_env(monkeypatch): | ||
monkeypatch.setenv(key, value) | ||
|
||
@pytest.mark.usefixtures("_monkeypatch_env") | ||
def run_test(): | ||
return test_func(*args, **kwargs) | ||
|
||
return run_test() | ||
|
||
return wrapper | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.