Skip to content

Commit

Permalink
Add typing to ArtifactLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
Shrews committed Jul 31, 2023
1 parent 82a231d commit aece246
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 70 deletions.
4 changes: 2 additions & 2 deletions src/ansible_runner/config/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def prepare_env(self):

def prepare_command(self):
try:
cmdline_args = self.loader.load_file('args', str, encoding=None)
cmdline_args = self.loader.load_file('args', str)
self.command = shlex.split(cmdline_args)
self.execution_mode = ExecutionMode.RAW
except ConfigurationError:
Expand Down Expand Up @@ -237,7 +237,7 @@ def generate_ansible_command(self):
if self.cmdline_args:
cmdline_args = self.cmdline_args
else:
cmdline_args = self.loader.load_file('env/cmdline', str, encoding=None)
cmdline_args = self.loader.load_file('env/cmdline', str)

args = shlex.split(cmdline_args)
exec_list.extend(args)
Expand Down
104 changes: 46 additions & 58 deletions src/ansible_runner/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
# specific language governing permissions and limitations
# under the License.
#

from __future__ import annotations

import os
import json
import codecs

from typing import Any, Dict
from yaml import safe_load, YAMLError

from ansible_runner.exceptions import ConfigurationError
Expand All @@ -39,118 +43,103 @@ class ArtifactLoader:
to load the same file.
'''

def __init__(self, base_path):
self._cache = {}
def __init__(self, base_path: str):
self._cache: Dict[str, Any] = {}
self.base_path = base_path

def _load_json(self, contents):
def _load_json(self, contents: str) -> dict | None:
'''
Attempts to deserialize the contents of a JSON object
Args:
contents (string): The contents to deserialize
:param str contents: The contents to deserialize.
Returns:
dict: If the contents are JSON serialized
None: If the contents are not JSON serialized
:return: A dict if the contents are JSON serialized,
otherwise returns None.
'''
try:
return json.loads(contents)
except ValueError:
return None

def _load_yaml(self, contents):
def _load_yaml(self, contents: str) -> dict | None:
'''
Attempts to deserialize the contents of a YAML object
Args:
contents (string): The contents to deserialize
Attempts to deserialize the contents of a YAML object.
Returns:
dict: If the contents are YAML serialized
:param str contents: The contents to deserialize.
None: If the contents are not YAML serialized
'''
:return: A dict if the contents are YAML serialized,
otherwise returns None.
'''
try:
return safe_load(contents)
except YAMLError:
return None

def get_contents(self, path):
def _get_contents(self, path: str) -> str:
'''
Loads the contents of the file specified by path
Args:
path (string): The relative or absolute path to the file to
be loaded. If the path is relative, then it is combined
with the base_path to generate a full path string
:param str path: The relative or absolute path to the file to
be loaded. If the path is relative, then it is combined
with the base_path to generate a full path string
Returns:
string: The contents of the file as a string
:return: The contents of the file as a string
Raises:
ConfigurationError: If the file cannot be loaded
:raises: ConfigurationError if the file cannot be loaded.
'''
try:
if not os.path.exists(path):
raise ConfigurationError(f"specified path does not exist {path}")
with codecs.open(path, encoding='utf-8') as f:
with codecs.open(path, encoding="utf-8") as f:
data = f.read()

return data

except (IOError, OSError) as exc:
raise ConfigurationError(f"error trying to load file contents: {exc}") from exc
except ValueError as exc:
raise ConfigurationError(f"error with encoding of file {path}: {exc}") from exc

def abspath(self, path):
def abspath(self, path: str) -> str:
'''
Transform the path to an absolute path
Args:
path (string): The path to transform to an absolute path
:param str path: The path to transform to an absolute path
Returns:
string: The absolute path to the file
:return: The absolute path to the file.
'''
if not path.startswith(os.path.sep) or path.startswith('~'):
path = os.path.expanduser(os.path.join(self.base_path, path))
return path

def isfile(self, path):
def isfile(self, path: str) -> bool:
'''
Check if the path is a file
:params path: The path to the file to check. If the path is relative
:param str path: The path to the file to check. If the path is relative
it will be exanded to an absolute path
:returns: boolean
:return: True if path is a file, False otherwise.
'''
return os.path.isfile(self.abspath(path))

def load_file(self, path, objtype=None, encoding='utf-8'):
def load_file(self, path: str, objtype: Any | None = None) -> str | dict | None:
'''
Load the file specified by path
This method will first try to load the file contents from cache and
if there is a cache miss, it will load the contents from disk
Args:
path (string): The full or relative path to the file to be loaded
encoding (string): The file contents text encoding
:param str path: The full or relative path to the file to be loaded.
objtype (object): The object type of the file contents. This
is used to type check the deserialized content against the
contents loaded from disk.
Ignore serializing if objtype is str.
:param Any objtype: The object type of the file contents. This
is used to type check the deserialized content against the
contents loaded from disk. Ignore serializing if objtype is str.
Returns:
object: The deserialized file contents which could be either a
string object or a dict object
:return: The deserialized file contents which could be either a
string object or a dict object
Raises:
ConfigurationError:
:raises: ConfigurationError on error during file load or deserialization.
'''
path = self.abspath(path)
debug(f"file path is {path}")
Expand All @@ -160,14 +149,10 @@ def load_file(self, path, objtype=None, encoding='utf-8'):

try:
debug(f"cache miss, attempting to load file from disk: {path}")
contents = parsed_data = self.get_contents(path)
if encoding:
parsed_data = contents.encode(encoding)
contents = self._get_contents(path)
except ConfigurationError as exc:
debug(exc)
debug(str(exc))
raise
except UnicodeEncodeError as exc:
raise ConfigurationError('unable to encode file contents') from exc

if objtype is not str:
for deserializer in (self._load_json, self._load_yaml):
Expand All @@ -179,5 +164,8 @@ def load_file(self, path, objtype=None, encoding='utf-8'):
debug(f"specified file {path} is not of type {objtype}")
raise ConfigurationError('invalid file serialization type for contents')

self._cache[path] = parsed_data
return parsed_data
self._cache[path] = parsed_data
else:
self._cache[path] = contents

return self._cache[path]
2 changes: 1 addition & 1 deletion test/unit/config/test_doc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# pylint: disable: R0401
# pylint: disable=R0401

import os
import pytest
Expand Down
2 changes: 1 addition & 1 deletion test/unit/config/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ class MockArtifactLoader:
def __init__(self, base_path):
self.base_path = base_path

def load_file(self, path, objtype=None, encoding='utf-8'):
def load_file(self, path, objtype=None):
raise ConfigurationError

def isfile(self, _):
Expand Down
16 changes: 8 additions & 8 deletions test/unit/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_abspath(loader, tmp_path):


def test_load_file_text_cache_hit(loader, mocker, tmp_path):
mock_get_contents = mocker.patch.object(ansible_runner.loader.ArtifactLoader, 'get_contents')
mock_get_contents = mocker.patch.object(ansible_runner.loader.ArtifactLoader, '_get_contents')
mock_get_contents.return_value = 'test\nstring'

assert not loader._cache
Expand All @@ -66,20 +66,20 @@ def test_load_file_text_cache_hit(loader, mocker, tmp_path):
res = loader.load_file(testfile, str)
assert mock_get_contents.called
assert mock_get_contents.called_with_args(testfile)
assert res == b'test\nstring'
assert res == 'test\nstring'
assert testfile in loader._cache

mock_get_contents.reset_mock()

# cache hit
res = loader.load_file(testfile, str)
assert not mock_get_contents.called
assert res == b'test\nstring'
assert res == 'test\nstring'
assert testfile in loader._cache


def test_load_file_json(loader, mocker, tmp_path):
mock_get_contents = mocker.patch.object(ansible_runner.loader.ArtifactLoader, 'get_contents')
mock_get_contents = mocker.patch.object(ansible_runner.loader.ArtifactLoader, '_get_contents')
mock_get_contents.return_value = '---\ntest: string'

assert not loader._cache
Expand All @@ -94,7 +94,7 @@ def test_load_file_json(loader, mocker, tmp_path):


def test_load_file_type_check(loader, mocker, tmp_path):
mock_get_contents = mocker.patch.object(ansible_runner.loader.ArtifactLoader, 'get_contents')
mock_get_contents = mocker.patch.object(ansible_runner.loader.ArtifactLoader, '_get_contents')
mock_get_contents.return_value = '---\ntest: string'

assert not loader._cache
Expand Down Expand Up @@ -125,15 +125,15 @@ def test_get_contents_ok(loader, mocker):

mock_open.return_value.__enter__.return_value = handler

res = loader.get_contents('/tmp')
res = loader._get_contents('/tmp')
assert res == b'test string'


def test_get_contents_invalid_path(loader, tmp_path):
with raises(ConfigurationError):
loader.get_contents(tmp_path.joinpath('invalid').as_posix())
loader._get_contents(tmp_path.joinpath('invalid').as_posix())


def test_get_contents_exception(loader, tmp_path):
with raises(ConfigurationError):
loader.get_contents(tmp_path.as_posix())
loader._get_contents(tmp_path.as_posix())

0 comments on commit aece246

Please sign in to comment.