Skip to content

Commit

Permalink
Cleanup connector classes, print available Terraform keys
Browse files Browse the repository at this point in the history
  • Loading branch information
Fizzadar committed Mar 9, 2024
1 parent cf3830b commit b333116
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 48 deletions.
2 changes: 1 addition & 1 deletion pyinfra/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, state: "State", host: "Host"):

@staticmethod
@abc.abstractmethod
def make_names_data(id: str) -> Iterator[tuple[str, dict, list[str]]]:
def make_names_data(name: str) -> Iterator[tuple[str, dict, list[str]]]:
"""
Generates hosts/data/groups information for inventory. This allows a
single connector reference to generate multiple target hosts.
Expand Down
11 changes: 5 additions & 6 deletions pyinfra/connectors/chroot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ def __init__(self, state: "State", host: "Host"):
self.local = LocalConnector(state, host)

@staticmethod
def make_names_data(directory: Optional[str] = None):
if not directory:
def make_names_data(name: Optional[str] = None):
if not name:
raise InventoryError("No directory provided!")

show_warning()

yield "@chroot/{0}".format(directory), {
"chroot_directory": "/{0}".format(directory.lstrip("/")),
yield "@chroot/{0}".format(name), {
"chroot_directory": "/{0}".format(name.lstrip("/")),
}, ["@chroot"]

def connect(self):
def connect(self) -> None:
self.local.connect()

chroot_directory = self.host.data.chroot_directory
Expand All @@ -65,7 +65,6 @@ def connect(self):
raise ConnectError(e.args[0])

self.host.connector_data["chroot_directory"] = chroot_directory
return True

def run_shell_command(
self,
Expand Down
15 changes: 7 additions & 8 deletions pyinfra/connectors/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pyinfra.api.state import State


class ConnectorData(TypedDict, total=False):
class ConnectorData(TypedDict):
docker_identifier: str


Expand All @@ -35,6 +35,7 @@ class ConnectorData(TypedDict, total=False):

def _find_start_docker_container(container_id) -> tuple[str, bool]:
docker_info = local.shell("docker container inspect {0}".format(container_id))
assert isinstance(docker_info, str)
docker_info = json.loads(docker_info)[0]
if docker_info["State"]["Running"] is False:
logger.info("Starting stopped container: {0}".format(container_id))
Expand Down Expand Up @@ -93,17 +94,17 @@ def __init__(self, state: "State", host: "Host"):
self.local = LocalConnector(state, host)

@staticmethod
def make_names_data(identifier=None):
if not identifier:
def make_names_data(name=None):
if not name:
raise InventoryError("No docker base ID provided!")

yield (
"@docker/{0}".format(identifier),
{"docker_identifier": identifier},
"@docker/{0}".format(name),
{"docker_identifier": name},
["@docker"],
)

def connect(self):
def connect(self) -> None:
self.local.connect()

docker_identifier = self.data["docker_identifier"]
Expand All @@ -115,8 +116,6 @@ def connect(self):
except PyinfraError:
self.container_id = _start_docker_image(docker_identifier)

return True

def disconnect(self):
container_id = self.container_id

Expand Down
6 changes: 3 additions & 3 deletions pyinfra/connectors/dockerssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def __init__(self, state: "State", host: "Host"):
self.ssh = SSHConnector(state, host)

@staticmethod
def make_names_data(host_image_str):
def make_names_data(name):
try:
hostname, image = host_image_str.split(":", 1)
except (AttributeError, ValueError): # failure to parse the host_image_str
hostname, image = name.split(":", 1)
except (AttributeError, ValueError): # failure to parse the name
raise InventoryError("No ssh host or docker base image provided!")

if not image:
Expand Down
7 changes: 3 additions & 4 deletions pyinfra/connectors/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class LocalConnector(BaseConnector):
handles_execution = True

@staticmethod
def make_names_data(_=None):
if _ is not None:
def make_names_data(name=None):
if name is not None:
raise InventoryError("Cannot have more than one @local")

yield "@local", {}, ["@local"]
Expand Down Expand Up @@ -206,8 +206,7 @@ def get_file(

return True

@staticmethod
def check_can_rsync(host):
def check_can_rsync(self):
if not find_executable("rsync"):
raise NotImplementedError("The `rsync` binary is not available on this system.")

Expand Down
17 changes: 6 additions & 11 deletions pyinfra/connectors/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ class SSHConnector(BaseConnector):

client: Optional[SSHClient] = None

def make_names_data(hostname):
yield "@ssh/{0}".format(hostname), {"ssh_hostname": hostname}, []
@staticmethod
def make_names_data(name):
yield "@ssh/{0}".format(name), {"ssh_hostname": name}, []

def make_paramiko_kwargs(self) -> dict[str, Any]:
kwargs = {
Expand Down Expand Up @@ -476,16 +477,10 @@ def put_file(
self._put_file(filename_or_io, temp_file)

# Make sure our sudo/su user can access the file
if _su_user:
command = StringCommand("setfacl", "-m", "u:{0}:r".format(_su_user), temp_file)
elif _sudo_user:
command = StringCommand("setfacl", "-m", "u:{0}:r".format(_sudo_user), temp_file)
elif _doas_user:
command = StringCommand("setfacl", "-m", "u:{0}:r".format(_doas_user), temp_file)

if _su_user or _sudo_user or _doas_user:
other_user = _su_user or _sudo_user or _doas_user
if other_user:
status, output = self.run_shell_command(
command,
StringCommand("setfacl", "-m", f"u:{other_user}:r", temp_file),
print_output=print_output,
print_input=print_input,
**arguments,
Expand Down
12 changes: 7 additions & 5 deletions pyinfra/connectors/terraform.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,23 @@ class TerraformInventoryConnector(BaseConnector):
"""

@staticmethod
def make_names_data(output_key=None):
def make_names_data(name=None):
show_warning()

if not output_key:
raise InventoryError("No Terraform output key!")
if not name:
name = ""

with progress_spinner({"fetch terraform output"}):
tf_output_raw = local.shell("terraform output -json")

assert isinstance(tf_output_raw, str)
tf_output = json.loads(tf_output_raw)
tf_output = _flatten_dict(tf_output)

tf_output_value = tf_output.get(output_key)
tf_output_value = tf_output.get(name)
if tf_output_value is None:
raise InventoryError(f"No Terraform output with key: `{output_key}`")
keys = "\n".join(f" - {k}" for k in tf_output.keys())
raise InventoryError(f"No Terraform output with key: `{name}`, valid keys:\n{keys}")

if not isinstance(tf_output_value, list):
raise InventoryError(
Expand Down
11 changes: 6 additions & 5 deletions pyinfra/connectors/vagrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ class VagrantInventoryConnector(BaseConnector):
"""

@staticmethod
def make_names_data(limit=None):
vagrant_ssh_info = get_vagrant_config(limit)
def make_names_data(name=None):
vagrant_ssh_info = get_vagrant_config(name)

logger.debug("Got Vagrant SSH info: \n%s", vagrant_ssh_info)

Expand Down Expand Up @@ -170,10 +170,11 @@ def make_names_data(limit=None):
hosts.append(_make_name_data(current_host))

if not hosts:
if limit:
if name:
raise InventoryError(
"No running Vagrant instances matching `{0}` found!".format(limit)
"No running Vagrant instances matching `{0}` found!".format(name)
)
raise InventoryError("No running Vagrant instances found!")

return hosts
for host in hosts:
yield host
13 changes: 11 additions & 2 deletions tests/test_connectors/test_terraform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,21 @@ def test_make_names_data_no_output_key(self):

@patch("pyinfra.connectors.terraform.local.shell")
def test_make_names_data_no_output(self, fake_shell):
fake_shell.return_value = json.dumps({})
fake_shell.return_value = json.dumps(
{
"hello": {
"world": [],
},
},
)

with self.assertRaises(InventoryError) as context:
list(TerraformInventoryConnector.make_names_data("output_key"))

assert context.exception.args[0] == "No Terraform output with key: `output_key`"
assert (
context.exception.args[0]
== "No Terraform output with key: `output_key`, valid keys:\n - hello.world"
)

@patch("pyinfra.connectors.terraform.local.shell")
def test_make_names_data_invalid_output(self, fake_shell):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_connectors/test_vagrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def tearDown(self):
)
@patch("pyinfra.connectors.vagrant.path.exists", lambda path: True)
def test_make_names_data_with_options(self):
data = VagrantInventoryConnector.make_names_data()
data = list(VagrantInventoryConnector.make_names_data())

assert data == [
(
Expand Down Expand Up @@ -103,7 +103,7 @@ def test_make_names_data_with_options(self):
]

def test_make_names_data_with_limit(self):
data = VagrantInventoryConnector.make_names_data(limit=("ubuntu16",))
data = list(VagrantInventoryConnector.make_names_data(name=("ubuntu16",)))

assert data == [
(
Expand All @@ -120,4 +120,4 @@ def test_make_names_data_with_limit(self):

def test_make_names_data_no_matches(self):
with self.assertRaises(InventoryError):
VagrantInventoryConnector.make_names_data(limit="nope")
list(VagrantInventoryConnector.make_names_data(name="nope"))

0 comments on commit b333116

Please sign in to comment.