From 625646b020ea4bccf8a0c91d7f6875e7b9b67c7b Mon Sep 17 00:00:00 2001 From: Tomasz Zajac Date: Wed, 12 Feb 2025 15:28:48 +0100 Subject: [PATCH 1/4] fixed python compatibility issue --- infrahub_sync/adapters/infrahub.py | 8 +++++++- infrahub_sync/adapters/ipfabricsync.py | 8 +++++++- infrahub_sync/adapters/librenms.py | 8 +++++++- infrahub_sync/adapters/nautobot.py | 7 ++++++- infrahub_sync/adapters/netbox.py | 8 ++++++-- infrahub_sync/adapters/observium.py | 8 ++++++-- infrahub_sync/adapters/peeringmanager.py | 8 ++++++-- infrahub_sync/adapters/slurpitsync.py | 8 ++++++-- 8 files changed, 51 insertions(+), 12 deletions(-) diff --git a/infrahub_sync/adapters/infrahub.py b/infrahub_sync/adapters/infrahub.py index 7c62847..fe1a474 100644 --- a/infrahub_sync/adapters/infrahub.py +++ b/infrahub_sync/adapters/infrahub.py @@ -1,8 +1,14 @@ +import sys from __future__ import annotations import copy import os -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any +if sys.version_info.minor < 11: + from typing_extensions import Self +else: + from typing import Self + from diffsync import Adapter, DiffSyncModel from infrahub_sdk import ( diff --git a/infrahub_sync/adapters/ipfabricsync.py b/infrahub_sync/adapters/ipfabricsync.py index 0a12f66..da6c5f6 100644 --- a/infrahub_sync/adapters/ipfabricsync.py +++ b/infrahub_sync/adapters/ipfabricsync.py @@ -1,6 +1,12 @@ +import sys from __future__ import annotations -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any + +if sys.version_info.minor < 11: + from typing_extensions import Self +else: + from typing import Self try: from ipfabric import IPFClient diff --git a/infrahub_sync/adapters/librenms.py b/infrahub_sync/adapters/librenms.py index bc71f58..ae6e90e 100644 --- a/infrahub_sync/adapters/librenms.py +++ b/infrahub_sync/adapters/librenms.py @@ -1,7 +1,13 @@ +import sys from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any + +if sys.version_info.minor < 11: + from typing_extensions import Self +else: + from typing import Self from diffsync import Adapter, DiffSyncModel diff --git a/infrahub_sync/adapters/nautobot.py b/infrahub_sync/adapters/nautobot.py index 89b0742..92cffd4 100644 --- a/infrahub_sync/adapters/nautobot.py +++ b/infrahub_sync/adapters/nautobot.py @@ -1,8 +1,13 @@ +import sys from __future__ import annotations # pylint: disable=R0801 import os -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any +if sys.version_info.minor < 11: + from typing_extensions import Self +else: + from typing import Self import pynautobot from diffsync import Adapter, DiffSyncModel diff --git a/infrahub_sync/adapters/netbox.py b/infrahub_sync/adapters/netbox.py index 4116bfd..f14b830 100644 --- a/infrahub_sync/adapters/netbox.py +++ b/infrahub_sync/adapters/netbox.py @@ -1,9 +1,13 @@ +import sys from __future__ import annotations # pylint: disable=R0801 import os -from typing import TYPE_CHECKING, Any, Self - +from typing import TYPE_CHECKING, Any +if sys.version_info.minor < 11: + from typing_extensions import Self +else: + from typing import Self import pynetbox from diffsync import Adapter, DiffSyncModel diff --git a/infrahub_sync/adapters/observium.py b/infrahub_sync/adapters/observium.py index 3a9b325..7b5a235 100644 --- a/infrahub_sync/adapters/observium.py +++ b/infrahub_sync/adapters/observium.py @@ -1,8 +1,12 @@ +import sys from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Self - +from typing import TYPE_CHECKING, Any +if sys.version_info.minor < 11: + from typing_extensions import Self +else: + from typing import Self from diffsync import Adapter, DiffSyncModel from infrahub_sync import ( diff --git a/infrahub_sync/adapters/peeringmanager.py b/infrahub_sync/adapters/peeringmanager.py index b6d6ef1..b758283 100644 --- a/infrahub_sync/adapters/peeringmanager.py +++ b/infrahub_sync/adapters/peeringmanager.py @@ -1,8 +1,12 @@ +import sys from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Self - +from typing import TYPE_CHECKING, Any +if sys.version_info.minor < 11: + from typing_extensions import Self +else: + from typing import Self import requests from diffsync import Adapter, DiffSyncModel diff --git a/infrahub_sync/adapters/slurpitsync.py b/infrahub_sync/adapters/slurpitsync.py index 0a04fe2..1d3d292 100644 --- a/infrahub_sync/adapters/slurpitsync.py +++ b/infrahub_sync/adapters/slurpitsync.py @@ -1,9 +1,13 @@ +import sys from __future__ import annotations import asyncio import ipaddress -from typing import TYPE_CHECKING, Any, Self - +from typing import TYPE_CHECKING, Any +if sys.version_info.minor < 11: + from typing_extensions import Self +else: + from typing import Self import slurpit from diffsync import Adapter, DiffSyncModel From 3ab28011213606c0d0c5ef8574945e64051ca67d Mon Sep 17 00:00:00 2001 From: Tomasz Zajac Date: Wed, 12 Feb 2025 16:27:01 +0100 Subject: [PATCH 2/4] fixed lint, changed import --- .../infrahub/sync_models.py | 42 ++++++- .../peeringmanager/sync_models.py | 42 ++++++- .../infrahub/sync_models.py | 19 ++- .../ipfabricsync/sync_models.py | 19 ++- .../infrahub/sync_models.py | 30 ++++- .../nautobot/sync_models.py | 30 ++++- .../infrahub/sync_models.py | 42 ++++++- .../nautobot/sync_models.py | 42 ++++++- .../infrahub/sync_models.py | 19 ++- .../netbox_to_infrahub/netbox/sync_models.py | 19 ++- .../infrahub/sync_models.py | 44 ++++++- .../peeringmanager/sync_models.py | 44 ++++++- .../infrahub/sync_models.py | 25 +++- .../slurpitsync/sync_models.py | 25 +++- infrahub_sync/__init__.py | 42 +++++-- infrahub_sync/adapters/infrahub.py | 115 +++++++++++++----- infrahub_sync/adapters/ipfabricsync.py | 50 +++++--- infrahub_sync/adapters/librenms.py | 54 +++++--- infrahub_sync/adapters/nautobot.py | 57 ++++++--- infrahub_sync/adapters/netbox.py | 56 ++++++--- infrahub_sync/adapters/observium.py | 48 +++++--- infrahub_sync/adapters/peeringmanager.py | 97 +++++++++++---- infrahub_sync/adapters/rest_api_client.py | 17 ++- infrahub_sync/adapters/slurpitsync.py | 84 +++++++++---- infrahub_sync/adapters/utils.py | 4 +- infrahub_sync/cli.py | 88 ++++++++++---- infrahub_sync/generator/__init__.py | 25 +++- infrahub_sync/potenda/__init__.py | 12 +- infrahub_sync/utils.py | 38 ++++-- tasks/docs.py | 4 +- 30 files changed, 981 insertions(+), 252 deletions(-) diff --git a/examples/infrahub_to_peering-manager/infrahub/sync_models.py b/examples/infrahub_to_peering-manager/infrahub/sync_models.py index f22cca2..a93816b 100644 --- a/examples/infrahub_to_peering-manager/infrahub/sync_models.py +++ b/examples/infrahub_to_peering-manager/infrahub/sync_models.py @@ -13,7 +13,14 @@ class InfraAutonomousSystem(InfrahubModel): _modelname = "InfraAutonomousSystem" _identifiers = ("asn",) - _attributes = ("name", "description", "irr_as_set", "ipv4_max_prefixes", "ipv6_max_prefixes", "affiliated") + _attributes = ( + "name", + "description", + "irr_as_set", + "ipv4_max_prefixes", + "ipv6_max_prefixes", + "affiliated", + ) name: str asn: int description: str | None = None @@ -29,7 +36,13 @@ class InfraAutonomousSystem(InfrahubModel): class InfraBGPPeerGroup(InfrahubModel): _modelname = "InfraBGPPeerGroup" _identifiers = ("name",) - _attributes = ("import_policies", "export_policies", "bgp_communities", "description", "status") + _attributes = ( + "import_policies", + "export_policies", + "bgp_communities", + "description", + "status", + ) name: str description: str | None = None status: str | None = None @@ -58,7 +71,14 @@ class InfraBGPCommunity(InfrahubModel): class InfraBGPRoutingPolicy(InfrahubModel): _modelname = "InfraBGPRoutingPolicy" _identifiers = ("name",) - _attributes = ("bgp_communities", "label", "description", "policy_type", "weight", "address_family") + _attributes = ( + "bgp_communities", + "label", + "description", + "policy_type", + "weight", + "address_family", + ) name: str label: str | None = None description: str | None = None @@ -74,7 +94,13 @@ class InfraBGPRoutingPolicy(InfrahubModel): class InfraIXP(InfrahubModel): _modelname = "InfraIXP" _identifiers = ("name",) - _attributes = ("import_policies", "export_policies", "bgp_communities", "description", "status") + _attributes = ( + "import_policies", + "export_policies", + "bgp_communities", + "description", + "status", + ) name: str description: str | None = None status: str | None = "enabled" @@ -89,7 +115,13 @@ class InfraIXP(InfrahubModel): class InfraIXPConnection(InfrahubModel): _modelname = "InfraIXPConnection" _identifiers = ("name",) - _attributes = ("internet_exchange_point", "description", "peeringdb_netixlan", "status", "vlan") + _attributes = ( + "internet_exchange_point", + "description", + "peeringdb_netixlan", + "status", + "vlan", + ) name: str description: str | None = None peeringdb_netixlan: int | None = None diff --git a/examples/infrahub_to_peering-manager/peeringmanager/sync_models.py b/examples/infrahub_to_peering-manager/peeringmanager/sync_models.py index 9eed820..35a218e 100644 --- a/examples/infrahub_to_peering-manager/peeringmanager/sync_models.py +++ b/examples/infrahub_to_peering-manager/peeringmanager/sync_models.py @@ -13,7 +13,14 @@ class InfraAutonomousSystem(PeeringmanagerModel): _modelname = "InfraAutonomousSystem" _identifiers = ("asn",) - _attributes = ("name", "description", "irr_as_set", "ipv4_max_prefixes", "ipv6_max_prefixes", "affiliated") + _attributes = ( + "name", + "description", + "irr_as_set", + "ipv4_max_prefixes", + "ipv6_max_prefixes", + "affiliated", + ) name: str asn: int description: str | None = None @@ -29,7 +36,13 @@ class InfraAutonomousSystem(PeeringmanagerModel): class InfraBGPPeerGroup(PeeringmanagerModel): _modelname = "InfraBGPPeerGroup" _identifiers = ("name",) - _attributes = ("import_policies", "export_policies", "bgp_communities", "description", "status") + _attributes = ( + "import_policies", + "export_policies", + "bgp_communities", + "description", + "status", + ) name: str description: str | None = None status: str | None = None @@ -58,7 +71,14 @@ class InfraBGPCommunity(PeeringmanagerModel): class InfraBGPRoutingPolicy(PeeringmanagerModel): _modelname = "InfraBGPRoutingPolicy" _identifiers = ("name",) - _attributes = ("bgp_communities", "label", "description", "policy_type", "weight", "address_family") + _attributes = ( + "bgp_communities", + "label", + "description", + "policy_type", + "weight", + "address_family", + ) name: str label: str | None = None description: str | None = None @@ -74,7 +94,13 @@ class InfraBGPRoutingPolicy(PeeringmanagerModel): class InfraIXP(PeeringmanagerModel): _modelname = "InfraIXP" _identifiers = ("name",) - _attributes = ("import_policies", "export_policies", "bgp_communities", "description", "status") + _attributes = ( + "import_policies", + "export_policies", + "bgp_communities", + "description", + "status", + ) name: str description: str | None = None status: str | None = "enabled" @@ -89,7 +115,13 @@ class InfraIXP(PeeringmanagerModel): class InfraIXPConnection(PeeringmanagerModel): _modelname = "InfraIXPConnection" _identifiers = ("name",) - _attributes = ("internet_exchange_point", "description", "peeringdb_netixlan", "status", "vlan") + _attributes = ( + "internet_exchange_point", + "description", + "peeringdb_netixlan", + "status", + "vlan", + ) name: str description: str | None = None peeringdb_netixlan: int | None = None diff --git a/examples/ipfabric_to_infrahub/infrahub/sync_models.py b/examples/ipfabric_to_infrahub/infrahub/sync_models.py index 430abb2..7cef38d 100644 --- a/examples/ipfabric_to_infrahub/infrahub/sync_models.py +++ b/examples/ipfabric_to_infrahub/infrahub/sync_models.py @@ -13,7 +13,15 @@ class InfraDevice(InfrahubModel): _modelname = "InfraDevice" _identifiers = ("hostname",) - _attributes = ("model", "location", "platform", "version", "fqdn", "serial_number", "hardware_serial_number") + _attributes = ( + "model", + "location", + "platform", + "version", + "fqdn", + "serial_number", + "hardware_serial_number", + ) fqdn: str | None = None hostname: str serial_number: str @@ -71,7 +79,14 @@ class InfraNOSVersion(InfrahubModel): class InfraPartNumber(InfrahubModel): _modelname = "InfraPartNumber" _identifiers = ("device", "name") - _attributes = ("model", "manufacturer", "part_vid", "part_id", "description", "part_sn") + _attributes = ( + "model", + "manufacturer", + "part_vid", + "part_id", + "description", + "part_sn", + ) name: str part_vid: str | None = None part_id: str | None = None diff --git a/examples/ipfabric_to_infrahub/ipfabricsync/sync_models.py b/examples/ipfabric_to_infrahub/ipfabricsync/sync_models.py index 61a5985..d0325b6 100644 --- a/examples/ipfabric_to_infrahub/ipfabricsync/sync_models.py +++ b/examples/ipfabric_to_infrahub/ipfabricsync/sync_models.py @@ -13,7 +13,15 @@ class InfraDevice(IpfabricsyncModel): _modelname = "InfraDevice" _identifiers = ("hostname",) - _attributes = ("model", "location", "platform", "version", "fqdn", "serial_number", "hardware_serial_number") + _attributes = ( + "model", + "location", + "platform", + "version", + "fqdn", + "serial_number", + "hardware_serial_number", + ) fqdn: str | None = None hostname: str serial_number: str @@ -71,7 +79,14 @@ class InfraNOSVersion(IpfabricsyncModel): class InfraPartNumber(IpfabricsyncModel): _modelname = "InfraPartNumber" _identifiers = ("device", "name") - _attributes = ("model", "manufacturer", "part_vid", "part_id", "description", "part_sn") + _attributes = ( + "model", + "manufacturer", + "part_vid", + "part_id", + "description", + "part_sn", + ) name: str part_vid: str | None = None part_id: str | None = None diff --git a/examples/nautobot-v1_to_infrahub/infrahub/sync_models.py b/examples/nautobot-v1_to_infrahub/infrahub/sync_models.py index 76e2a21..2a86c9d 100644 --- a/examples/nautobot-v1_to_infrahub/infrahub/sync_models.py +++ b/examples/nautobot-v1_to_infrahub/infrahub/sync_models.py @@ -60,7 +60,15 @@ class InfraCircuit(InfrahubModel): class InfraDevice(InfrahubModel): _modelname = "InfraDevice" _identifiers = ("location", "organization", "name") - _attributes = ("model", "rack", "role", "tags", "platform", "serial_number", "asset_tag") + _attributes = ( + "model", + "rack", + "role", + "tags", + "platform", + "serial_number", + "asset_tag", + ) name: str | None = None serial_number: str | None = None asset_tag: str | None = None @@ -104,7 +112,15 @@ class InfraIPAddress(InfrahubModel): class InfraInterfaceL2L3(InfrahubModel): _modelname = "InfraInterfaceL2L3" _identifiers = ("name", "device") - _attributes = ("tagged_vlan", "tags", "l2_mode", "description", "mgmt_only", "mac_address", "interface_type") + _attributes = ( + "tagged_vlan", + "tags", + "l2_mode", + "description", + "mgmt_only", + "mac_address", + "interface_type", + ) l2_mode: str | None = None name: str description: str | None = None @@ -161,7 +177,15 @@ class InfraProviderNetwork(InfrahubModel): class InfraRack(InfrahubModel): _modelname = "InfraRack" _identifiers = ("name",) - _attributes = ("location", "role", "tags", "height", "facility_id", "serial_number", "asset_tag") + _attributes = ( + "location", + "role", + "tags", + "height", + "facility_id", + "serial_number", + "asset_tag", + ) name: str height: int | None = None facility_id: str | None = None diff --git a/examples/nautobot-v1_to_infrahub/nautobot/sync_models.py b/examples/nautobot-v1_to_infrahub/nautobot/sync_models.py index 5c49ad1..19caebf 100644 --- a/examples/nautobot-v1_to_infrahub/nautobot/sync_models.py +++ b/examples/nautobot-v1_to_infrahub/nautobot/sync_models.py @@ -64,7 +64,15 @@ class InfraCircuit(NautobotModel): class InfraDevice(NautobotModel): _modelname = "InfraDevice" _identifiers = ("location", "organization", "name") - _attributes = ("model", "rack", "role", "tags", "platform", "serial_number", "asset_tag") + _attributes = ( + "model", + "rack", + "role", + "tags", + "platform", + "serial_number", + "asset_tag", + ) name: str | None = None serial_number: str | None = None asset_tag: str | None = None @@ -111,7 +119,15 @@ class InfraIPAddress(NautobotModel): class InfraInterfaceL2L3(NautobotModel): _modelname = "InfraInterfaceL2L3" _identifiers = ("name", "device") - _attributes = ("tagged_vlan", "tags", "l2_mode", "description", "mgmt_only", "mac_address", "interface_type") + _attributes = ( + "tagged_vlan", + "tags", + "l2_mode", + "description", + "mgmt_only", + "mac_address", + "interface_type", + ) l2_mode: str | None = None name: str description: str | None = None @@ -172,7 +188,15 @@ class InfraProviderNetwork(NautobotModel): class InfraRack(NautobotModel): _modelname = "InfraRack" _identifiers = ("name",) - _attributes = ("location", "role", "tags", "height", "facility_id", "serial_number", "asset_tag") + _attributes = ( + "location", + "role", + "tags", + "height", + "facility_id", + "serial_number", + "asset_tag", + ) name: str height: int | None = None facility_id: str | None = None diff --git a/examples/nautobot-v2_to_infrahub/infrahub/sync_models.py b/examples/nautobot-v2_to_infrahub/infrahub/sync_models.py index 20ea030..d775966 100644 --- a/examples/nautobot-v2_to_infrahub/infrahub/sync_models.py +++ b/examples/nautobot-v2_to_infrahub/infrahub/sync_models.py @@ -64,7 +64,16 @@ class InfraCircuit(InfrahubModel): class InfraDevice(InfrahubModel): _modelname = "InfraDevice" _identifiers = ("location", "organization", "name") - _attributes = ("model", "tags", "rack", "role", "status", "platform", "serial_number", "asset_tag") + _attributes = ( + "model", + "tags", + "rack", + "role", + "status", + "platform", + "serial_number", + "asset_tag", + ) serial_number: str | None = None asset_tag: str | None = None name: str | None = None @@ -113,7 +122,16 @@ class InfraIPAddress(InfrahubModel): class InfraInterfaceL2L3(InfrahubModel): _modelname = "InfraInterfaceL2L3" _identifiers = ("name", "device") - _attributes = ("status", "tags", "tagged_vlan", "l2_mode", "mac_address", "description", "mgmt_only", "interface_type") + _attributes = ( + "status", + "tags", + "tagged_vlan", + "l2_mode", + "mac_address", + "description", + "mgmt_only", + "interface_type", + ) l2_mode: str | None = None mac_address: str | None = None description: str | None = None @@ -178,7 +196,15 @@ class InfraProviderNetwork(InfrahubModel): class InfraRack(InfrahubModel): _modelname = "InfraRack" _identifiers = ("name",) - _attributes = ("role", "location", "tags", "facility_id", "asset_tag", "serial_number", "height") + _attributes = ( + "role", + "location", + "tags", + "facility_id", + "asset_tag", + "serial_number", + "height", + ) facility_id: str | None = None asset_tag: str | None = None name: str @@ -220,7 +246,15 @@ class InfraRouteTarget(InfrahubModel): class InfraVLAN(InfrahubModel): _modelname = "InfraVLAN" _identifiers = ("name",) - _attributes = ("role", "organization", "vlan_group", "status", "locations", "description", "vlan_id") + _attributes = ( + "role", + "organization", + "vlan_group", + "status", + "locations", + "description", + "vlan_id", + ) name: str description: str | None = None vlan_id: int diff --git a/examples/nautobot-v2_to_infrahub/nautobot/sync_models.py b/examples/nautobot-v2_to_infrahub/nautobot/sync_models.py index d8f4e34..92fd0aa 100644 --- a/examples/nautobot-v2_to_infrahub/nautobot/sync_models.py +++ b/examples/nautobot-v2_to_infrahub/nautobot/sync_models.py @@ -64,7 +64,16 @@ class InfraCircuit(NautobotModel): class InfraDevice(NautobotModel): _modelname = "InfraDevice" _identifiers = ("location", "organization", "name") - _attributes = ("model", "tags", "rack", "role", "status", "platform", "serial_number", "asset_tag") + _attributes = ( + "model", + "tags", + "rack", + "role", + "status", + "platform", + "serial_number", + "asset_tag", + ) serial_number: str | None = None asset_tag: str | None = None name: str | None = None @@ -113,7 +122,16 @@ class InfraIPAddress(NautobotModel): class InfraInterfaceL2L3(NautobotModel): _modelname = "InfraInterfaceL2L3" _identifiers = ("name", "device") - _attributes = ("status", "tags", "tagged_vlan", "l2_mode", "mac_address", "description", "mgmt_only", "interface_type") + _attributes = ( + "status", + "tags", + "tagged_vlan", + "l2_mode", + "mac_address", + "description", + "mgmt_only", + "interface_type", + ) l2_mode: str | None = None mac_address: str | None = None description: str | None = None @@ -178,7 +196,15 @@ class InfraProviderNetwork(NautobotModel): class InfraRack(NautobotModel): _modelname = "InfraRack" _identifiers = ("name",) - _attributes = ("role", "location", "tags", "facility_id", "asset_tag", "serial_number", "height") + _attributes = ( + "role", + "location", + "tags", + "facility_id", + "asset_tag", + "serial_number", + "height", + ) facility_id: str | None = None asset_tag: str | None = None name: str @@ -220,7 +246,15 @@ class InfraRouteTarget(NautobotModel): class InfraVLAN(NautobotModel): _modelname = "InfraVLAN" _identifiers = ("name",) - _attributes = ("role", "organization", "vlan_group", "status", "locations", "description", "vlan_id") + _attributes = ( + "role", + "organization", + "vlan_group", + "status", + "locations", + "description", + "vlan_id", + ) name: str description: str | None = None vlan_id: int diff --git a/examples/netbox_to_infrahub/infrahub/sync_models.py b/examples/netbox_to_infrahub/infrahub/sync_models.py index 2c8ebf1..b50fd9d 100644 --- a/examples/netbox_to_infrahub/infrahub/sync_models.py +++ b/examples/netbox_to_infrahub/infrahub/sync_models.py @@ -96,7 +96,15 @@ class TemplateDeviceType(InfrahubModel): class InfraInterfaceL2L3(InfrahubModel): _modelname = "InfraInterfaceL2L3" _identifiers = ("device", "name") - _attributes = ("tagged_vlan", "tags", "l2_mode", "description", "mgmt_only", "mac_address", "interface_type") + _attributes = ( + "tagged_vlan", + "tags", + "l2_mode", + "description", + "mgmt_only", + "mac_address", + "interface_type", + ) l2_mode: str | None = None name: str description: str | None = None @@ -157,7 +165,14 @@ class InfraPrefix(InfrahubModel): class InfraRack(InfrahubModel): _modelname = "InfraRack" _identifiers = ("name", "location") - _attributes = ("role", "tags", "asset_tag", "height", "serial_number", "facility_id") + _attributes = ( + "role", + "tags", + "asset_tag", + "height", + "serial_number", + "facility_id", + ) asset_tag: str | None = None height: int | None = None serial_number: str | None = None diff --git a/examples/netbox_to_infrahub/netbox/sync_models.py b/examples/netbox_to_infrahub/netbox/sync_models.py index 09ef1a6..c35768a 100644 --- a/examples/netbox_to_infrahub/netbox/sync_models.py +++ b/examples/netbox_to_infrahub/netbox/sync_models.py @@ -96,7 +96,15 @@ class TemplateDeviceType(NetboxModel): class InfraInterfaceL2L3(NetboxModel): _modelname = "InfraInterfaceL2L3" _identifiers = ("device", "name") - _attributes = ("tagged_vlan", "tags", "l2_mode", "description", "mgmt_only", "mac_address", "interface_type") + _attributes = ( + "tagged_vlan", + "tags", + "l2_mode", + "description", + "mgmt_only", + "mac_address", + "interface_type", + ) l2_mode: str | None = None name: str description: str | None = None @@ -157,7 +165,14 @@ class InfraPrefix(NetboxModel): class InfraRack(NetboxModel): _modelname = "InfraRack" _identifiers = ("name", "location") - _attributes = ("role", "tags", "asset_tag", "height", "serial_number", "facility_id") + _attributes = ( + "role", + "tags", + "asset_tag", + "height", + "serial_number", + "facility_id", + ) asset_tag: str | None = None height: int | None = None serial_number: str | None = None diff --git a/examples/peering-manager_to_infrahub/infrahub/sync_models.py b/examples/peering-manager_to_infrahub/infrahub/sync_models.py index 4c1468e..0bbfde5 100644 --- a/examples/peering-manager_to_infrahub/infrahub/sync_models.py +++ b/examples/peering-manager_to_infrahub/infrahub/sync_models.py @@ -13,7 +13,14 @@ class InfraAutonomousSystem(InfrahubModel): _modelname = "InfraAutonomousSystem" _identifiers = ("asn",) - _attributes = ("affiliated", "name", "irr_as_set", "description", "ipv4_max_prefixes", "ipv6_max_prefixes") + _attributes = ( + "affiliated", + "name", + "irr_as_set", + "description", + "ipv4_max_prefixes", + "ipv6_max_prefixes", + ) affiliated: bool | None = None name: str asn: int @@ -29,7 +36,13 @@ class InfraAutonomousSystem(InfrahubModel): class InfraBGPPeerGroup(InfrahubModel): _modelname = "InfraBGPPeerGroup" _identifiers = ("name",) - _attributes = ("bgp_communities", "import_policies", "export_policies", "status", "description") + _attributes = ( + "bgp_communities", + "import_policies", + "export_policies", + "status", + "description", + ) status: str | None = None name: str description: str | None = None @@ -69,7 +82,14 @@ class InfraBGPCommunity(InfrahubModel): class InfraBGPRoutingPolicy(InfrahubModel): _modelname = "InfraBGPRoutingPolicy" _identifiers = ("name",) - _attributes = ("bgp_communities", "label", "policy_type", "address_family", "weight", "description") + _attributes = ( + "bgp_communities", + "label", + "policy_type", + "address_family", + "weight", + "description", + ) name: str label: str | None = None policy_type: str @@ -85,7 +105,13 @@ class InfraBGPRoutingPolicy(InfrahubModel): class InfraIXP(InfrahubModel): _modelname = "InfraIXP" _identifiers = ("name",) - _attributes = ("export_policies", "bgp_communities", "import_policies", "status", "description") + _attributes = ( + "export_policies", + "bgp_communities", + "import_policies", + "status", + "description", + ) name: str status: str | None = "enabled" description: str | None = None @@ -100,7 +126,15 @@ class InfraIXP(InfrahubModel): class InfraIXPConnection(InfrahubModel): _modelname = "InfraIXPConnection" _identifiers = ("name",) - _attributes = ("ipv4_address", "ipv6_address", "internet_exchange_point", "status", "description", "peeringdb_netixlan", "vlan") + _attributes = ( + "ipv4_address", + "ipv6_address", + "internet_exchange_point", + "status", + "description", + "peeringdb_netixlan", + "vlan", + ) status: str | None = "enabled" description: str | None = None peeringdb_netixlan: int | None = None diff --git a/examples/peering-manager_to_infrahub/peeringmanager/sync_models.py b/examples/peering-manager_to_infrahub/peeringmanager/sync_models.py index 29912ff..63db5aa 100644 --- a/examples/peering-manager_to_infrahub/peeringmanager/sync_models.py +++ b/examples/peering-manager_to_infrahub/peeringmanager/sync_models.py @@ -13,7 +13,14 @@ class InfraAutonomousSystem(PeeringmanagerModel): _modelname = "InfraAutonomousSystem" _identifiers = ("asn",) - _attributes = ("affiliated", "name", "irr_as_set", "description", "ipv4_max_prefixes", "ipv6_max_prefixes") + _attributes = ( + "affiliated", + "name", + "irr_as_set", + "description", + "ipv4_max_prefixes", + "ipv6_max_prefixes", + ) affiliated: bool | None = None name: str asn: int @@ -29,7 +36,13 @@ class InfraAutonomousSystem(PeeringmanagerModel): class InfraBGPPeerGroup(PeeringmanagerModel): _modelname = "InfraBGPPeerGroup" _identifiers = ("name",) - _attributes = ("bgp_communities", "import_policies", "export_policies", "status", "description") + _attributes = ( + "bgp_communities", + "import_policies", + "export_policies", + "status", + "description", + ) status: str | None = None name: str description: str | None = None @@ -69,7 +82,14 @@ class InfraBGPCommunity(PeeringmanagerModel): class InfraBGPRoutingPolicy(PeeringmanagerModel): _modelname = "InfraBGPRoutingPolicy" _identifiers = ("name",) - _attributes = ("bgp_communities", "label", "policy_type", "address_family", "weight", "description") + _attributes = ( + "bgp_communities", + "label", + "policy_type", + "address_family", + "weight", + "description", + ) name: str label: str | None = None policy_type: str @@ -85,7 +105,13 @@ class InfraBGPRoutingPolicy(PeeringmanagerModel): class InfraIXP(PeeringmanagerModel): _modelname = "InfraIXP" _identifiers = ("name",) - _attributes = ("export_policies", "bgp_communities", "import_policies", "status", "description") + _attributes = ( + "export_policies", + "bgp_communities", + "import_policies", + "status", + "description", + ) name: str status: str | None = "enabled" description: str | None = None @@ -100,7 +126,15 @@ class InfraIXP(PeeringmanagerModel): class InfraIXPConnection(PeeringmanagerModel): _modelname = "InfraIXPConnection" _identifiers = ("name",) - _attributes = ("ipv4_address", "ipv6_address", "internet_exchange_point", "status", "description", "peeringdb_netixlan", "vlan") + _attributes = ( + "ipv4_address", + "ipv6_address", + "internet_exchange_point", + "status", + "description", + "peeringdb_netixlan", + "vlan", + ) status: str | None = "enabled" description: str | None = None peeringdb_netixlan: int | None = None diff --git a/examples/slurpit_to_infrahub/infrahub/sync_models.py b/examples/slurpit_to_infrahub/infrahub/sync_models.py index f32f88c..676ad26 100644 --- a/examples/slurpit_to_infrahub/infrahub/sync_models.py +++ b/examples/slurpit_to_infrahub/infrahub/sync_models.py @@ -1,10 +1,10 @@ - from __future__ import annotations from typing import Any from infrahub_sync.adapters.infrahub import InfrahubModel + # ------------------------------------------------------- # AUTO-GENERATED FILE, DO NOT MODIFY # This file has been generated with the command `infrahub-sync generate` @@ -24,6 +24,7 @@ class InfraDevice(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class InfraHardwareInfo(InfrahubModel): _modelname = "InfraHardwareInfo" _identifiers = ("device", "serial") @@ -38,6 +39,7 @@ class InfraHardwareInfo(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class InfraIPAddress(InfrahubModel): _modelname = "InfraIPAddress" _identifiers = ("address", "ip_prefix") @@ -49,6 +51,7 @@ class InfraIPAddress(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class InfraInterface(InfrahubModel): _modelname = "InfraInterface" _identifiers = ("device", "name") @@ -61,6 +64,7 @@ class InfraInterface(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class InfraPlatform(InfrahubModel): _modelname = "InfraPlatform" _identifiers = ("name",) @@ -70,6 +74,7 @@ class InfraPlatform(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class InfraPrefix(InfrahubModel): _modelname = "InfraPrefix" _identifiers = ("vrf", "prefix") @@ -80,6 +85,7 @@ class InfraPrefix(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class InfraVLAN(InfrahubModel): _modelname = "InfraVLAN" _identifiers = ("vlan_id", "name") @@ -90,6 +96,7 @@ class InfraVLAN(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class InfraVRF(InfrahubModel): _modelname = "InfraVRF" _identifiers = ("name",) @@ -99,6 +106,7 @@ class InfraVRF(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class InfraVersion(InfrahubModel): _modelname = "InfraVersion" _identifiers = ("version",) @@ -110,10 +118,21 @@ class InfraVersion(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class LocationGeneric(InfrahubModel): _modelname = "LocationGeneric" _identifiers = ("name",) - _attributes = ("description", "number", "street", "city", "county", "state", "zipcode", "country", "phonenumber") + _attributes = ( + "description", + "number", + "street", + "city", + "county", + "state", + "zipcode", + "country", + "phonenumber", + ) name: str description: str | None = None number: str | None = None @@ -128,6 +147,7 @@ class LocationGeneric(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class OrganizationGeneric(InfrahubModel): _modelname = "OrganizationGeneric" _identifiers = ("name",) @@ -138,6 +158,7 @@ class OrganizationGeneric(InfrahubModel): local_id: str | None = None local_data: Any | None = None + class TemplateDeviceType(InfrahubModel): _modelname = "TemplateDeviceType" _identifiers = ("name",) diff --git a/examples/slurpit_to_infrahub/slurpitsync/sync_models.py b/examples/slurpit_to_infrahub/slurpitsync/sync_models.py index 31b06f4..fafb464 100644 --- a/examples/slurpit_to_infrahub/slurpitsync/sync_models.py +++ b/examples/slurpit_to_infrahub/slurpitsync/sync_models.py @@ -1,10 +1,10 @@ - from __future__ import annotations from typing import Any from infrahub_sync.adapters.slurpitsync import SlurpitsyncModel + # ------------------------------------------------------- # AUTO-GENERATED FILE, DO NOT MODIFY # This file has been generated with the command `infrahub-sync generate` @@ -24,6 +24,7 @@ class InfraDevice(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class InfraHardwareInfo(SlurpitsyncModel): _modelname = "InfraHardwareInfo" _identifiers = ("device", "serial") @@ -38,6 +39,7 @@ class InfraHardwareInfo(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class InfraIPAddress(SlurpitsyncModel): _modelname = "InfraIPAddress" _identifiers = ("address", "ip_prefix") @@ -49,6 +51,7 @@ class InfraIPAddress(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class InfraInterface(SlurpitsyncModel): _modelname = "InfraInterface" _identifiers = ("device", "name") @@ -61,6 +64,7 @@ class InfraInterface(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class InfraPlatform(SlurpitsyncModel): _modelname = "InfraPlatform" _identifiers = ("name",) @@ -70,6 +74,7 @@ class InfraPlatform(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class InfraPrefix(SlurpitsyncModel): _modelname = "InfraPrefix" _identifiers = ("vrf", "prefix") @@ -80,6 +85,7 @@ class InfraPrefix(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class InfraVLAN(SlurpitsyncModel): _modelname = "InfraVLAN" _identifiers = ("vlan_id", "name") @@ -90,6 +96,7 @@ class InfraVLAN(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class InfraVRF(SlurpitsyncModel): _modelname = "InfraVRF" _identifiers = ("name",) @@ -99,6 +106,7 @@ class InfraVRF(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class InfraVersion(SlurpitsyncModel): _modelname = "InfraVersion" _identifiers = ("version",) @@ -110,10 +118,21 @@ class InfraVersion(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class LocationGeneric(SlurpitsyncModel): _modelname = "LocationGeneric" _identifiers = ("name",) - _attributes = ("description", "number", "street", "city", "county", "state", "zipcode", "country", "phonenumber") + _attributes = ( + "description", + "number", + "street", + "city", + "county", + "state", + "zipcode", + "country", + "phonenumber", + ) name: str description: str | None = None number: str | None = None @@ -128,6 +147,7 @@ class LocationGeneric(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class OrganizationGeneric(SlurpitsyncModel): _modelname = "OrganizationGeneric" _identifiers = ("name",) @@ -138,6 +158,7 @@ class OrganizationGeneric(SlurpitsyncModel): local_id: str | None = None local_data: Any | None = None + class TemplateDeviceType(SlurpitsyncModel): _modelname = "TemplateDeviceType" _identifiers = ("name",) diff --git a/infrahub_sync/__init__.py b/infrahub_sync/__init__.py index 501163e..8ca6f9b 100644 --- a/infrahub_sync/__init__.py +++ b/infrahub_sync/__init__.py @@ -79,8 +79,12 @@ def convert_to_int(value: Any) -> int: "!=": operator.ne, ">": lambda field, value: operator.gt(convert_to_int(field), convert_to_int(value)), "<": lambda field, value: operator.lt(convert_to_int(field), convert_to_int(value)), - ">=": lambda field, value: operator.ge(convert_to_int(field), convert_to_int(value)), - "<=": lambda field, value: operator.le(convert_to_int(field), convert_to_int(value)), + ">=": lambda field, value: operator.ge( + convert_to_int(field), convert_to_int(value) + ), + "<=": lambda field, value: operator.le( + convert_to_int(field), convert_to_int(value) + ), "in": lambda field, value: value and field in value, "not in": lambda field, value: field not in value, "contains": lambda field, value: field and value in field, @@ -89,7 +93,9 @@ def convert_to_int(value: Any) -> int: "is_not_empty": lambda field: field is not None and field, "regex": lambda field, pattern: re.match(pattern, field) is not None, # Netutils - "is_ip_within": lambda field, value: is_ip_within_filter(ip=field, ip_compare=value), + "is_ip_within": lambda field, value: is_ip_within_filter( + ip=field, ip_compare=value + ), } @@ -125,17 +131,25 @@ def apply_filter(cls, field_value: Any, operation: str, value: Any) -> bool: return operation_func(field_value, value) @classmethod - def apply_filters(cls, item: dict[str, Any], filters: list[SchemaMappingFilter]) -> bool: + def apply_filters( + cls, item: dict[str, Any], filters: list[SchemaMappingFilter] + ) -> bool: """Apply filters to an item and return True if it passes all filters.""" for filter_obj in filters: # Use dot notation to access attributes field_value = get_value(obj=item, name=filter_obj.field) - if not cls.apply_filter(field_value=field_value, operation=filter_obj.operation, value=filter_obj.value): + if not cls.apply_filter( + field_value=field_value, + operation=filter_obj.operation, + value=filter_obj.value, + ): return False return True @classmethod - def apply_transform(cls, item: dict[str, Any], transform_expr: str, field: str) -> None: + def apply_transform( + cls, item: dict[str, Any], transform_expr: str, field: str + ) -> None: """Apply a transformation expression using Jinja2 to a specified field in the item.""" try: # Create a Jinja2 template from the transformation expression @@ -152,7 +166,9 @@ def apply_transform(cls, item: dict[str, Any], transform_expr: str, field: str) raise ValueError(msg) from exc @classmethod - def apply_transforms(cls, item: dict[str, Any], transforms: list[SchemaMappingTransform]) -> dict[str, Any]: + def apply_transforms( + cls, item: dict[str, Any], transforms: list[SchemaMappingTransform] + ) -> dict[str, Any]: """Apply a list of structured transformations to an item.""" for transform_obj in transforms: field = transform_obj.field @@ -161,7 +177,9 @@ def apply_transforms(cls, item: dict[str, Any], transforms: list[SchemaMappingTr return item @classmethod - def filter_records(cls, records: list[dict], schema_mapping: SchemaMappingModel) -> list[dict]: + def filter_records( + cls, records: list[dict], schema_mapping: SchemaMappingModel + ) -> list[dict]: """ Apply filters to the records based on the schema mapping configuration. """ @@ -175,7 +193,9 @@ def filter_records(cls, records: list[dict], schema_mapping: SchemaMappingModel) return filtered_records @classmethod - def transform_records(cls, records: list[dict], schema_mapping: SchemaMappingModel) -> list[dict]: + def transform_records( + cls, records: list[dict], schema_mapping: SchemaMappingModel + ) -> list[dict]: """ Apply transformations to the records based on the schema mapping configuration. """ @@ -184,7 +204,9 @@ def transform_records(cls, records: list[dict], schema_mapping: SchemaMappingMod return records transformed_records = [] for record in records: - transformed_record = cls.apply_transforms(item=record, transforms=transforms) + transformed_record = cls.apply_transforms( + item=record, transforms=transforms + ) transformed_records.append(transformed_record) return transformed_records diff --git a/infrahub_sync/adapters/infrahub.py b/infrahub_sync/adapters/infrahub.py index fe1a474..a43d8f0 100644 --- a/infrahub_sync/adapters/infrahub.py +++ b/infrahub_sync/adapters/infrahub.py @@ -1,13 +1,13 @@ -import sys from __future__ import annotations import copy import os from typing import TYPE_CHECKING, Any -if sys.version_info.minor < 11: - from typing_extensions import Self -else: - from typing import Self + +try: + from typing import Self +except ImportError: + from typing_extensions import Self from diffsync import Adapter, DiffSyncModel @@ -46,12 +46,18 @@ def update_node(node: InfrahubNodeSync, attrs: dict) -> InfrahubNodeSync: if attr_value: if rel_schema.kind != "Generic": peer = node._client.store.get( - key=attr_value, kind=rel_schema.peer, raise_when_missing=False + key=attr_value, + kind=rel_schema.peer, + raise_when_missing=False, ) else: - peer = node._client.store.get(key=attr_value, raise_when_missing=False) + peer = node._client.store.get( + key=attr_value, raise_when_missing=False + ) if not peer: - print(f"Unable to find {rel_schema.peer} [{attr_value}] in the Store - Ignored") + print( + f"Unable to find {rel_schema.peer} [{attr_value}] in the Store - Ignored" + ) continue setattr(node, attr_name, peer) else: @@ -62,9 +68,12 @@ def update_node(node: InfrahubNodeSync, attrs: dict) -> InfrahubNodeSync: attr = getattr(node, attr_name) existing_peer_ids = attr.peer_ids new_peer_ids = [ - node._client.store.get(key=value, kind=rel_schema.peer).id for value in list(attr_value) + node._client.store.get(key=value, kind=rel_schema.peer).id + for value in list(attr_value) ] - _, existing_only, new_only = compare_lists(existing_peer_ids, new_peer_ids) + _, existing_only, new_only = compare_lists( + existing_peer_ids, new_peer_ids + ) for existing_id in existing_only: attr.remove(existing_id) @@ -77,7 +86,10 @@ def update_node(node: InfrahubNodeSync, attrs: dict) -> InfrahubNodeSync: def diffsync_to_infrahub( - ids: Mapping[Any, Any], attrs: Mapping[Any, Any], store: NodeStoreSync, schema: NodeSchema + ids: Mapping[Any, Any], + attrs: Mapping[Any, Any], + store: NodeStoreSync, + schema: NodeSchema, ) -> dict[Any, Any]: data = copy.deepcopy(dict(ids)) data.update(dict(attrs)) @@ -90,11 +102,17 @@ def diffsync_to_infrahub( del data[key] continue if rel_schema.kind != "Generic": - peer = store.get(key=data[key], kind=rel_schema.peer, raise_when_missing=False) + peer = store.get( + key=data[key], + kind=rel_schema.peer, + raise_when_missing=False, + ) else: peer = store.get(key=data[key], raise_when_missing=False) if not peer: - print(f"Unable to find {rel_schema.peer} [{data[key]}] in the Store - Ignored") + print( + f"Unable to find {rel_schema.peer} [{data[key]}] in the Store - Ignored" + ) continue data[key] = peer.id @@ -102,7 +120,10 @@ def diffsync_to_infrahub( if data[key] is None: del data[key] continue - new_values = [store.get(key=value, kind=rel_schema.peer).id for value in list(data[key])] + new_values = [ + store.get(key=value, kind=rel_schema.peer).id + for value in list(data[key]) + ] data[key] = new_values return data @@ -112,14 +133,24 @@ class InfrahubAdapter(DiffSyncMixin, Adapter): type = "Infrahub" def __init__( - self, target: str, adapter: SyncAdapter, config: SyncConfig, branch: str | None = None, *args, **kwargs + self, + target: str, + adapter: SyncAdapter, + config: SyncConfig, + branch: str | None = None, + *args, + **kwargs, ) -> None: super().__init__(*args, **kwargs) self.target = target self.config = config settings = adapter.settings or {} - infrahub_url = os.environ.get("INFRAHUB_ADDRESS") or os.environ.get("INFRAHUB_URL") or settings.get("url") + infrahub_url = ( + os.environ.get("INFRAHUB_ADDRESS") + or os.environ.get("INFRAHUB_URL") + or settings.get("url") + ) infrahub_token = os.environ.get("INFRAHUB_API_TOKEN") or settings.get("token") infrahub_branch = settings.get("branch") or branch @@ -128,7 +159,9 @@ def __init__( raise ValueError(msg) if infrahub_branch: - sdk_config = Config(timeout=60, default_branch=infrahub_branch, api_token=infrahub_token) + sdk_config = Config( + timeout=60, default_branch=infrahub_branch, api_token=infrahub_token + ) else: sdk_config = Config(timeout=60, api_token=infrahub_token) @@ -137,7 +170,9 @@ def __init__( # We need to identify with an account until we have some auth in place remote_account = config.source.name try: - self.account = self.client.get(kind="CoreAccount", name__value=remote_account) + self.account = self.client.get( + kind="CoreAccount", name__value=remote_account + ) except NodeNotFoundError: self.account = None @@ -148,13 +183,17 @@ def model_loader(self, model_name: str, model: InfrahubModel) -> None: This method retrieves data from Infrahub, applies filters and transformations as specified in the schema mapping, and loads the processed data into the adapter. """ - element = next((el for el in self.config.schema_mapping if el.name == model_name), None) + element = next( + (el for el in self.config.schema_mapping if el.name == model_name), None + ) if element: # Retrieve all nodes corresponding to model_name (list of InfrahubNodeSync) nodes = self.client.all(kind=model_name, populate_store=True) # Transform the list of InfrahubNodeSync into a list of (node, dict) tuples - node_dict_pairs = [(node, self.infrahub_node_to_diffsync(node=node)) for node in nodes] + node_dict_pairs = [ + (node, self.infrahub_node_to_diffsync(node=node)) for node in nodes + ] total = len(node_dict_pairs) # Extract the list of dicts for filtering and transforming @@ -162,17 +201,23 @@ def model_loader(self, model_name: str, model: InfrahubModel) -> None: if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records(records=list_obj, schema_mapping=element) + filtered_objs = model.filter_records( + records=list_obj, schema_mapping=element + ) print(f"{self.type}: Loading {len(filtered_objs)}/{total} {model_name}") # Transform records - transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) + transformed_objs = model.transform_records( + records=filtered_objs, schema_mapping=element + ) else: print(f"{self.type}: Loading all {total} {model_name}") transformed_objs = list_obj # Create model instances after filtering and transforming for transformed_obj in transformed_objs: - original_node = next(node for node, obj in node_dict_pairs if obj == transformed_obj) + original_node = next( + node for node, obj in node_dict_pairs if obj == transformed_obj + ) item = model(**transformed_obj) unique_id = item.get_unique_id() self.client.store.set(key=unique_id, node=original_node) @@ -194,20 +239,30 @@ def infrahub_node_to_diffsync(self, node: InfrahubNodeSync) -> dict: data[attr_name] = attr.value for rel_schema in node._schema.relationships: - if not has_field(config=self.config, name=node._schema.kind, field=rel_schema.name): + if not has_field( + config=self.config, name=node._schema.kind, field=rel_schema.name + ): continue if rel_schema.cardinality == "one": rel = getattr(node, rel_schema.name) if not rel.id: continue if rel_schema.kind != "Generic": - peer_node = self.client.store.get(key=rel.id, kind=rel_schema.peer, raise_when_missing=False) + peer_node = self.client.store.get( + key=rel.id, kind=rel_schema.peer, raise_when_missing=False + ) else: - peer_node = self.client.store.get(key=rel.id, raise_when_missing=False) + peer_node = self.client.store.get( + key=rel.id, raise_when_missing=False + ) if not peer_node: # I am not sure if we should end up here "normaly" - print(f"Debug Unable to find {rel_schema.peer} [{rel.id}] in the Store - Pulling from Infrahub") - peer_node = self.client.get(id=rel.id, kind=rel_schema.peer, populate_store=True) + print( + f"Debug Unable to find {rel_schema.peer} [{rel.id}] in the Store - Pulling from Infrahub" + ) + peer_node = self.client.get( + id=rel.id, kind=rel_schema.peer, populate_store=True + ) if not peer_node: print(f"Unable to find {rel_schema.peer} [{rel.id}]") continue @@ -245,7 +300,9 @@ def create( attrs: Mapping[Any, Any], ) -> Self | None: schema = adapter.client.schema.get(kind=cls.__name__) - data = diffsync_to_infrahub(ids=ids, attrs=attrs, schema=schema, store=adapter.client.store) + data = diffsync_to_infrahub( + ids=ids, attrs=attrs, schema=schema, store=adapter.client.store + ) unique_id = cls(**ids, **attrs).get_unique_id() source_id = None if adapter.account: diff --git a/infrahub_sync/adapters/ipfabricsync.py b/infrahub_sync/adapters/ipfabricsync.py index da6c5f6..2d18b22 100644 --- a/infrahub_sync/adapters/ipfabricsync.py +++ b/infrahub_sync/adapters/ipfabricsync.py @@ -1,12 +1,11 @@ -import sys from __future__ import annotations from typing import TYPE_CHECKING, Any -if sys.version_info.minor < 11: - from typing_extensions import Self -else: - from typing import Self +try: + from typing import Self +except ImportError: + from typing_extensions import Self try: from ipfabric import IPFClient @@ -36,7 +35,9 @@ class IpfabricsyncAdapter(DiffSyncMixin, Adapter): type = "IPFabricsync" - def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: + def __init__( + self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.target = target @@ -57,10 +58,19 @@ def _create_ipfabric_client(self, adapter: SyncAdapter) -> IPFClient: def build_mapping(self, reference, obj) -> str: # Get object class and model name from the store - object_class, modelname = self.store._get_object_class_and_model(model=reference) + object_class, modelname = self.store._get_object_class_and_model( + model=reference + ) # Find the schema element matching the model name - schema_element = next((element for element in self.config.schema_mapping if element.name == modelname), None) + schema_element = next( + ( + element + for element in self.config.schema_mapping + if element.name == modelname + ), + None, + ) # Collect all relevant field mappings for identifiers new_identifiers = [] @@ -87,27 +97,39 @@ def model_loader(self, model_name: str, model: IpfabricsyncModel) -> None: if element.name != model_name: continue - table = self.client.fetch_all(element.mapping, filters=ipf_filters.get(element.mapping)) + table = self.client.fetch_all( + element.mapping, filters=ipf_filters.get(element.mapping) + ) print(f"{self.type}: Loading {len(table)} from `{element.mapping}`") total = len(table) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records(records=table, schema_mapping=element) - print(f"{self.type}: Loading {len(filtered_objs)}/{total} {element.mapping}") + filtered_objs = model.filter_records( + records=table, schema_mapping=element + ) + print( + f"{self.type}: Loading {len(filtered_objs)}/{total} {element.mapping}" + ) # Transform records - transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) + transformed_objs = model.transform_records( + records=filtered_objs, schema_mapping=element + ) else: print(f"{self.type}: Loading all {total} {element.mapping}") transformed_objs = table for obj in transformed_objs: - data = self.ipfabric_dict_to_diffsync(obj=obj, mapping=element, model=model) + data = self.ipfabric_dict_to_diffsync( + obj=obj, mapping=element, model=model + ) item = model(**data) self.update_or_add_model_instance(item) - def ipfabric_dict_to_diffsync(self, obj: dict, mapping: SchemaMappingModel, model: IpfabricsyncModel) -> dict: # pylint: disable=too-many-branches + def ipfabric_dict_to_diffsync( + self, obj: dict, mapping: SchemaMappingModel, model: IpfabricsyncModel + ) -> dict: # pylint: disable=too-many-branches data: dict[str, Any] = {"local_id": str(obj["id"])} for field in mapping.fields: # pylint: disable=too-many-nested-blocks diff --git a/infrahub_sync/adapters/librenms.py b/infrahub_sync/adapters/librenms.py index ae6e90e..4952234 100644 --- a/infrahub_sync/adapters/librenms.py +++ b/infrahub_sync/adapters/librenms.py @@ -1,13 +1,12 @@ -import sys from __future__ import annotations import os from typing import TYPE_CHECKING, Any -if sys.version_info.minor < 11: - from typing_extensions import Self -else: - from typing import Self +try: + from typing import Self +except ImportError: + from typing_extensions import Self from diffsync import Adapter, DiffSyncModel @@ -29,7 +28,9 @@ class LibrenmsAdapter(DiffSyncMixin, Adapter): type = "LibreNMS" - def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: + def __init__( + self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.target = target @@ -38,7 +39,11 @@ def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, def _create_rest_client(self, adapter: SyncAdapter) -> RestApiClient: settings = adapter.settings or {} - url = os.environ.get("LIBRENMS_ADDRESS") or os.environ.get("LIBRENMS_URL") or settings.get("url") + url = ( + os.environ.get("LIBRENMS_ADDRESS") + or os.environ.get("LIBRENMS_URL") + or settings.get("url") + ) api_endpoint = settings.get("api_endpoint", "/api/v0") auth_method = settings.get("auth_method", "x-auth-token") api_token = os.environ.get("LIBRENMS_TOKEN") or settings.get("token") @@ -53,7 +58,12 @@ def _create_rest_client(self, adapter: SyncAdapter) -> RestApiClient: raise ValueError(msg) full_base_url = f"{url.rstrip('/')}/{api_endpoint.strip('/')}" - return RestApiClient(base_url=full_base_url, auth_method=auth_method, api_token=api_token, timeout=timeout) + return RestApiClient( + base_url=full_base_url, + auth_method=auth_method, + api_token=api_token, + timeout=timeout, + ) def model_loader(self, model_name: str, model: LibrenmsModel) -> None: """ @@ -81,10 +91,16 @@ def model_loader(self, model_name: str, model: LibrenmsModel) -> None: total = len(objs) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records(records=objs, schema_mapping=element) - print(f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}") + filtered_objs = model.filter_records( + records=objs, schema_mapping=element + ) + print( + f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}" + ) # Transform records - transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) + transformed_objs = model.transform_records( + records=filtered_objs, schema_mapping=element + ) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = objs @@ -95,7 +111,9 @@ def model_loader(self, model_name: str, model: LibrenmsModel) -> None: item = model(**data) self.add(item) - def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, model: LibrenmsModel) -> dict: + def obj_to_diffsync( + self, obj: dict[str, Any], mapping: SchemaMappingModel, model: LibrenmsModel + ) -> dict: obj_id = derive_identifier_key(obj=obj) data: dict[str, Any] = {"local_id": str(obj_id)} @@ -126,7 +144,9 @@ def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, mode if isinstance(node, dict): matching_nodes = [] node_id = node.get("id", None) - matching_nodes = [item for item in nodes if item.local_id == str(node_id)] + matching_nodes = [ + item for item in nodes if item.local_id == str(node_id) + ] if len(matching_nodes) == 0: msg = f"Unable to locate the node {model} {node_id}" raise IndexError(msg) @@ -146,9 +166,13 @@ def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, mode node_id = node[1] if node[0] == "id" else None if not node_id: continue - matching_nodes = [item for item in nodes if item.local_id == str(node_id)] + matching_nodes = [ + item for item in nodes if item.local_id == str(node_id) + ] if len(matching_nodes) == 0: - msg = f"Unable to locate the node {field.reference} {node_id}" + msg = ( + f"Unable to locate the node {field.reference} {node_id}" + ) raise IndexError(msg) data[field.name].append(matching_nodes[0].get_unique_id()) data[field.name] = sorted(data[field.name]) diff --git a/infrahub_sync/adapters/nautobot.py b/infrahub_sync/adapters/nautobot.py index 92cffd4..01ff1c6 100644 --- a/infrahub_sync/adapters/nautobot.py +++ b/infrahub_sync/adapters/nautobot.py @@ -1,13 +1,12 @@ -import sys from __future__ import annotations # pylint: disable=R0801 import os -from typing import TYPE_CHECKING, Any -if sys.version_info.minor < 11: - from typing_extensions import Self -else: - from typing import Self + +try: + from typing import Self +except ImportError: + from typing_extensions import Self import pynautobot from diffsync import Adapter, DiffSyncModel @@ -29,7 +28,9 @@ class NautobotAdapter(DiffSyncMixin, Adapter): type = "Nautobot" - def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: + def __init__( + self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.target = target @@ -38,14 +39,20 @@ def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, def _create_nautobot_client(self, adapter: SyncAdapter) -> pynautobot.api: settings = adapter.settings or {} - url = os.environ.get("NAUTOBOT_ADDRESS") or os.environ.get("NAUTOBOT_URL") or settings.get("url") + url = ( + os.environ.get("NAUTOBOT_ADDRESS") + or os.environ.get("NAUTOBOT_URL") + or settings.get("url") + ) token = os.environ.get("NAUTOBOT_TOKEN") or settings.get("token") if not url or not token: msg = "Both url and token must be specified!" raise ValueError(msg) - return pynautobot.api(url, token=token, threading=True, max_workers=5, retries=3) + return pynautobot.api( + url, token=token, threading=True, max_workers=5, retries=3 + ) def model_loader(self, model_name: str, model: NautobotModel) -> None: """ @@ -75,21 +82,31 @@ def model_loader(self, model_name: str, model: NautobotModel) -> None: total = len(list_obj) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records(records=list_obj, schema_mapping=element) - print(f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}") + filtered_objs = model.filter_records( + records=list_obj, schema_mapping=element + ) + print( + f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}" + ) # Transform records - transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) + transformed_objs = model.transform_records( + records=filtered_objs, schema_mapping=element + ) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = list_obj # Create model instances after filtering and transforming for obj in transformed_objs: - data = self.nautobot_obj_to_diffsync(obj=obj, mapping=element, model=model) + data = self.nautobot_obj_to_diffsync( + obj=obj, mapping=element, model=model + ) item = model(**data) self.add(item) - def nautobot_obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, model: NautobotModel) -> dict: + def nautobot_obj_to_diffsync( + self, obj: dict[str, Any], mapping: SchemaMappingModel, model: NautobotModel + ) -> dict: obj_id = obj.get("id") data: dict[str, Any] = {"local_id": str(obj_id)} @@ -120,10 +137,14 @@ def nautobot_obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingMo matching_nodes = [] node_id = node.get("id", None) if node_id: - matching_nodes = [item for item in nodes if item.local_id == str(node_id)] + matching_nodes = [ + item for item in nodes if item.local_id == str(node_id) + ] if len(matching_nodes) == 0: # TODO: If the peer is a Node we are filtering, we could end up not finding it - print(f"Unable to locate the node {field.name} {node_id}") + print( + f"Unable to locate the node {field.name} {node_id}" + ) continue node = matching_nodes[0] data[field.name] = node.get_unique_id() @@ -138,7 +159,9 @@ def nautobot_obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingMo node_id = node[1] if node[0] == "id" else None if not node_id: continue - matching_nodes = [item for item in nodes if item.local_id == str(node_id)] + matching_nodes = [ + item for item in nodes if item.local_id == str(node_id) + ] if len(matching_nodes) == 0: # TODO: If the peer is a Node we are filtering, we could end up not finding it print(f"Unable to locate the node {field.name} {node_id}") diff --git a/infrahub_sync/adapters/netbox.py b/infrahub_sync/adapters/netbox.py index f14b830..3c5880d 100644 --- a/infrahub_sync/adapters/netbox.py +++ b/infrahub_sync/adapters/netbox.py @@ -1,13 +1,13 @@ -import sys from __future__ import annotations # pylint: disable=R0801 import os from typing import TYPE_CHECKING, Any -if sys.version_info.minor < 11: - from typing_extensions import Self -else: - from typing import Self + +try: + from typing import Self +except ImportError: + from typing_extensions import Self import pynetbox from diffsync import Adapter, DiffSyncModel @@ -28,7 +28,9 @@ class NetboxAdapter(DiffSyncMixin, Adapter): type = "Netbox" - def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: + def __init__( + self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.target = target @@ -37,7 +39,11 @@ def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, def _create_netbox_client(self, adapter: SyncAdapter) -> pynetbox.api: settings = adapter.settings or {} - url = os.environ.get("NETBOX_ADDRESS") or os.environ.get("NETBOX_URL") or settings.get("url") + url = ( + os.environ.get("NETBOX_ADDRESS") + or os.environ.get("NETBOX_URL") + or settings.get("url") + ) token = os.environ.get("NETBOX_TOKEN") or settings.get("token") if not url or not token: @@ -73,21 +79,31 @@ def model_loader(self, model_name: str, model: NetboxModel) -> None: total = len(list_obj) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records(records=list_obj, schema_mapping=element) - print(f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}") + filtered_objs = model.filter_records( + records=list_obj, schema_mapping=element + ) + print( + f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}" + ) # Transform records - transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) + transformed_objs = model.transform_records( + records=filtered_objs, schema_mapping=element + ) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = list_obj # Create model instances after filtering and transforming for obj in transformed_objs: - data = self.netbox_obj_to_diffsync(obj=obj, mapping=element, model=model) + data = self.netbox_obj_to_diffsync( + obj=obj, mapping=element, model=model + ) item = model(**data) self.add(item) - def netbox_obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, model: NetboxModel) -> dict: + def netbox_obj_to_diffsync( + self, obj: dict[str, Any], mapping: SchemaMappingModel, model: NetboxModel + ) -> dict: obj_id = obj.get("id") data: dict[str, Any] = {"local_id": str(obj_id)} @@ -117,9 +133,13 @@ def netbox_obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingMode if isinstance(node, dict): matching_nodes = [] node_id = node.get("id", None) - matching_nodes = [item for item in nodes if item.local_id == str(node_id)] + matching_nodes = [ + item for item in nodes if item.local_id == str(node_id) + ] if len(matching_nodes) == 0: - msg = f"Unable to locate the node {field.name} {node_id}" + msg = ( + f"Unable to locate the node {field.name} {node_id}" + ) raise IndexError(msg) node = matching_nodes[0] data[field.name] = node.get_unique_id() @@ -135,9 +155,13 @@ def netbox_obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingMode node_id = node[1] if node[0] == "id" else None if not node_id: continue - matching_nodes = [item for item in nodes if item.local_id == str(node_id)] + matching_nodes = [ + item for item in nodes if item.local_id == str(node_id) + ] if len(matching_nodes) == 0: - msg = f"Unable to locate the node {field.reference} {node_id}" + msg = ( + f"Unable to locate the node {field.reference} {node_id}" + ) raise IndexError(msg) data[field.name].append(matching_nodes[0].get_unique_id()) data[field.name] = sorted(data[field.name]) diff --git a/infrahub_sync/adapters/observium.py b/infrahub_sync/adapters/observium.py index 7b5a235..3361fe4 100644 --- a/infrahub_sync/adapters/observium.py +++ b/infrahub_sync/adapters/observium.py @@ -1,12 +1,12 @@ -import sys from __future__ import annotations import os from typing import TYPE_CHECKING, Any -if sys.version_info.minor < 11: - from typing_extensions import Self -else: - from typing import Self + +try: + from typing import Self +except ImportError: + from typing_extensions import Self from diffsync import Adapter, DiffSyncModel from infrahub_sync import ( @@ -27,7 +27,9 @@ class ObserviumAdapter(DiffSyncMixin, Adapter): type = "Observium" - def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: + def __init__( + self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.target = target @@ -36,7 +38,11 @@ def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, def _create_rest_client(self, adapter: SyncAdapter) -> RestApiClient: settings = adapter.settings or {} - url = os.environ.get("OBSERVIUM_ADDRESS") or os.environ.get("OBSERVIUM_URL") or settings.get("url") + url = ( + os.environ.get("OBSERVIUM_ADDRESS") + or os.environ.get("OBSERVIUM_URL") + or settings.get("url") + ) api_endpoint = settings.get("api_endpoint", "/api/v0") auth_method = settings.get("auth_method", "basic") api_token = os.environ.get("OBSERVIUM_TOKEN") or settings.get("token") @@ -86,10 +92,16 @@ def model_loader(self, model_name: str, model: ObserviumModel) -> None: total = len(objs) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records(records=objs, schema_mapping=element) - print(f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}") + filtered_objs = model.filter_records( + records=objs, schema_mapping=element + ) + print( + f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}" + ) # Transform records - transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) + transformed_objs = model.transform_records( + records=filtered_objs, schema_mapping=element + ) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = objs @@ -100,7 +112,9 @@ def model_loader(self, model_name: str, model: ObserviumModel) -> None: item = model(**data) self.add(item) - def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, model: ObserviumModel) -> dict: + def obj_to_diffsync( + self, obj: dict[str, Any], mapping: SchemaMappingModel, model: ObserviumModel + ) -> dict: obj_id = derive_identifier_key(obj=obj) data: dict[str, Any] = {"local_id": str(obj_id)} @@ -131,7 +145,9 @@ def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, mode if isinstance(node, dict): matching_nodes = [] node_id = node.get("id", None) - matching_nodes = [item for item in nodes if item.local_id == str(node_id)] + matching_nodes = [ + item for item in nodes if item.local_id == str(node_id) + ] if len(matching_nodes) == 0: msg = f"Unable to locate the node {model} {node_id}" raise IndexError(msg) @@ -151,9 +167,13 @@ def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, mode node_id = node[1] if node[0] == "id" else None if not node_id: continue - matching_nodes = [item for item in nodes if item.local_id == str(node_id)] + matching_nodes = [ + item for item in nodes if item.local_id == str(node_id) + ] if len(matching_nodes) == 0: - msg = f"Unable to locate the node {field.reference} {node_id}" + msg = ( + f"Unable to locate the node {field.reference} {node_id}" + ) raise IndexError(msg) data[field.name].append(matching_nodes[0].get_unique_id()) data[field.name] = sorted(data[field.name]) diff --git a/infrahub_sync/adapters/peeringmanager.py b/infrahub_sync/adapters/peeringmanager.py index b758283..bc7bdf6 100644 --- a/infrahub_sync/adapters/peeringmanager.py +++ b/infrahub_sync/adapters/peeringmanager.py @@ -1,12 +1,12 @@ -import sys from __future__ import annotations import os from typing import TYPE_CHECKING, Any -if sys.version_info.minor < 11: - from typing_extensions import Self -else: - from typing import Self + +try: + from typing import Self +except ImportError: + from typing_extensions import Self import requests from diffsync import Adapter, DiffSyncModel @@ -28,7 +28,9 @@ class PeeringmanagerAdapter(DiffSyncMixin, Adapter): type = "Peeringmanager" - def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: + def __init__( + self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.target = target @@ -37,8 +39,14 @@ def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, def _create_rest_client(self, adapter: SyncAdapter) -> RestApiClient: settings = adapter.settings or {} - url = os.environ.get("PEERING_MANAGER_ADDRESS") or os.environ.get("PEERING_MANAGER_URL") or settings.get("url") - api_endpoint = settings.get("api_endpoint", "/api") # Default endpoint, change if necessary + url = ( + os.environ.get("PEERING_MANAGER_ADDRESS") + or os.environ.get("PEERING_MANAGER_URL") + or settings.get("url") + ) + api_endpoint = settings.get( + "api_endpoint", "/api" + ) # Default endpoint, change if necessary auth_method = settings.get("auth_method", "token") api_token = os.environ.get("PEERING_MANAGER_TOKEN") or settings.get("token") timeout = settings.get("timeout", 30) @@ -52,7 +60,12 @@ def _create_rest_client(self, adapter: SyncAdapter) -> RestApiClient: raise ValueError(msg) full_base_url = f"{url.rstrip('/')}/{api_endpoint.strip('/')}" - return RestApiClient(base_url=full_base_url, auth_method=auth_method, api_token=api_token, timeout=timeout) + return RestApiClient( + base_url=full_base_url, + auth_method=auth_method, + api_token=api_token, + timeout=timeout, + ) def model_loader(self, model_name: str, model: PeeringmanagerModel) -> None: """ @@ -80,10 +93,16 @@ def model_loader(self, model_name: str, model: PeeringmanagerModel) -> None: total = len(objs) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records(records=objs, schema_mapping=element) - print(f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}") + filtered_objs = model.filter_records( + records=objs, schema_mapping=element + ) + print( + f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}" + ) # Transform records - transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) + transformed_objs = model.transform_records( + records=filtered_objs, schema_mapping=element + ) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = objs @@ -94,7 +113,12 @@ def model_loader(self, model_name: str, model: PeeringmanagerModel) -> None: item = model(**data) self.add(item) - def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, model: PeeringmanagerModel) -> dict: + def obj_to_diffsync( + self, + obj: dict[str, Any], + mapping: SchemaMappingModel, + model: PeeringmanagerModel, + ) -> dict: obj_id = derive_identifier_key(obj=obj) data: dict[str, Any] = {"local_id": str(obj_id)} @@ -124,9 +148,13 @@ def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, mode if isinstance(node, dict): matching_nodes = [] node_id = node.get("id", None) - matching_nodes = [item for item in nodes if item.local_id == str(node_id)] + matching_nodes = [ + item for item in nodes if item.local_id == str(node_id) + ] if len(matching_nodes) == 0: - msg = f"Unable to locate the node {field.name} {node_id}" + msg = ( + f"Unable to locate the node {field.name} {node_id}" + ) raise IndexError(msg) node = matching_nodes[0] data[field.name] = node.get_unique_id() @@ -142,9 +170,13 @@ def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, mode node_id = node[1] if node[0] == "id" else None if not node_id: continue - matching_nodes = [item for item in nodes if item.local_id == str(node_id)] + matching_nodes = [ + item for item in nodes if item.local_id == str(node_id) + ] if len(matching_nodes) == 0: - msg = f"Unable to locate the node {field.reference} {node_id}" + msg = ( + f"Unable to locate the node {field.reference} {node_id}" + ) raise IndexError(msg) data[field.name].append(matching_nodes[0].get_unique_id()) data[field.name] = sorted(data[field.name]) @@ -172,10 +204,14 @@ def update(self, attrs: dict) -> Self | None: to the API endpoint of the object. """ # Determine the resource name using the schema mapping - resource_name = self.__class__.get_resource_name(schema_mapping=self.adapter.config.schema_mapping) + resource_name = self.__class__.get_resource_name( + schema_mapping=self.adapter.config.schema_mapping + ) # Determine the unique identifier for the API request - unique_identifier = self.local_id if hasattr(self, "local_id") else self.get_unique_id() + unique_identifier = ( + self.local_id if hasattr(self, "local_id") else self.get_unique_id() + ) endpoint = f"{resource_name}/{unique_identifier}/" # Map incoming attributes to the target attributes based on schema mapping @@ -190,21 +226,34 @@ def update(self, attrs: dict) -> Self | None: # Check if the field is a relationship if field_mapping.reference: - all_nodes_for_reference = self.adapter.store.get_all(model=field_mapping.reference) + all_nodes_for_reference = self.adapter.store.get_all( + model=field_mapping.reference + ) if isinstance(value, list): # For lists, filter nodes to match the unique IDs in the attribute value filtered_nodes = [ - node for node in all_nodes_for_reference if node.get_unique_id() in value + node + for node in all_nodes_for_reference + if node.get_unique_id() in value + ] + mapped_attrs[target_field_name] = [ + node.local_id for node in filtered_nodes ] - mapped_attrs[target_field_name] = [node.local_id for node in filtered_nodes] else: # For single references, find the matching node filtered_node = next( - (node for node in all_nodes_for_reference if node.get_unique_id() == value), None + ( + node + for node in all_nodes_for_reference + if node.get_unique_id() == value + ), + None, ) if filtered_node: - mapped_attrs[target_field_name] = filtered_node.local_id + mapped_attrs[target_field_name] = ( + filtered_node.local_id + ) else: mapped_attrs[target_field_name] = value diff --git a/infrahub_sync/adapters/rest_api_client.py b/infrahub_sync/adapters/rest_api_client.py index fe9af31..2957d7c 100644 --- a/infrahub_sync/adapters/rest_api_client.py +++ b/infrahub_sync/adapters/rest_api_client.py @@ -45,7 +45,11 @@ def __init__( self.timeout = timeout def request( - self, method: str, endpoint: str, params: dict[str, Any] | None = None, data: dict[str, Any] | None = None + self, + method: str, + endpoint: str, + params: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, ) -> Any: """Make a request to the REST API.""" url = f"{self.base_url}/{endpoint.lstrip('/')}" @@ -63,7 +67,12 @@ def request( ) else: response = requests.request( - method=method, url=url, headers=self.headers, params=params, json=data, timeout=self.timeout + method=method, + url=url, + headers=self.headers, + params=params, + json=data, + timeout=self.timeout, ) response.raise_for_status() # Raise an HTTPError for bad responses @@ -71,7 +80,9 @@ def request( try: return response.json() except requests.exceptions.JSONDecodeError as exc: - print("Response content is not valid JSON:", response.text) # Print the response content + print( + "Response content is not valid JSON:", response.text + ) # Print the response content msg = "Response content is not valid JSON." raise ValueError(msg) from exc diff --git a/infrahub_sync/adapters/slurpitsync.py b/infrahub_sync/adapters/slurpitsync.py index 1d3d292..ce8b041 100644 --- a/infrahub_sync/adapters/slurpitsync.py +++ b/infrahub_sync/adapters/slurpitsync.py @@ -1,13 +1,13 @@ -import sys from __future__ import annotations import asyncio import ipaddress from typing import TYPE_CHECKING, Any -if sys.version_info.minor < 11: - from typing_extensions import Self -else: - from typing import Self + +try: + from typing import Self +except ImportError: + from typing_extensions import Self import slurpit from diffsync import Adapter, DiffSyncModel @@ -31,7 +31,9 @@ class SlurpitsyncAdapter(DiffSyncMixin, Adapter): type = "Slurpitsync" - def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: + def __init__( + self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.target = target self.client = self._create_slurpit_client(adapter=adapter) @@ -70,8 +72,13 @@ def unique_vendors(self) -> list[dict[str, Any]]: def unique_device_type(self) -> list[dict[str, Any]]: devices = self.run_async(self.client.device.get_devices()) - device_types = {(device.brand, device.device_type, device.device_os) for device in devices} - return [{"brand": item[0], "device_type": item[1], "device_os": item[2]} for item in device_types] + device_types = { + (device.brand, device.device_type, device.device_os) for device in devices + } + return [ + {"brand": item[0], "device_type": item[1], "device_os": item[2]} + for item in device_types + ] def filter_networks(self) -> list: """Filter out networks based on ignore prefixes and normalize network/mask fields.""" @@ -104,7 +111,10 @@ def should_ignore(network) -> bool: net = ipaddress.ip_network(network, strict=False) if net.prefixlen in {32, 128}: return True - return any(net == ipaddress.ip_network(ignore, strict=False) for ignore in ignore_prefixes) + return any( + net == ipaddress.ip_network(ignore, strict=False) + for ignore in ignore_prefixes + ) except ValueError: return False @@ -118,7 +128,12 @@ def should_ignore(network) -> bool: async def filter_interfaces(self, interfaces) -> list: precomputed_filtered_networks = [ - {"network": ipaddress.ip_network(prefix["normalized_prefix"], strict=False), "Vrf": prefix.get("Vrf", None)} + { + "network": ipaddress.ip_network( + prefix["normalized_prefix"], strict=False + ), + "Vrf": prefix.get("Vrf", None), + } for prefix in self.filtered_networks ] @@ -147,7 +162,9 @@ def normalize_and_find_prefix(entry): return entry # Concurrent execution of tasks - tasks = [normalize_and_find_prefix(entry) for entry in interfaces if entry.get("IP")] + tasks = [ + normalize_and_find_prefix(entry) for entry in interfaces if entry.get("IP") + ] # Run tasks concurrently filtered_interfaces = await asyncio.gather(*tasks) @@ -159,13 +176,17 @@ def normalize_and_find_prefix(entry): def planning_results(self, planning_name): plannings = self.run_async(self.client.planning.get_plannings()) - planning = next((plan.to_dict() for plan in plannings if plan.slug == planning_name), None) + planning = next( + (plan.to_dict() for plan in plannings if plan.slug == planning_name), None + ) if not planning: msg = f"No planning found for name: {planning_name}" raise IndexError(msg) search_data = {"planning_id": planning["id"], "unique_results": True} - results = self.run_async(self.client.planning.search_plannings(search_data, limit=30000)) + results = self.run_async( + self.client.planning.search_plannings(search_data, limit=30000) + ) return results or [] def model_loader(self, model_name: str, model: SlurpitsyncModel) -> None: @@ -198,16 +219,24 @@ def model_loader(self, model_name: str, model: SlurpitsyncModel) -> None: if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records(records=list_obj, schema_mapping=element) - print(f"{self.type}: Loading {len(filtered_objs)}/{total} {element.mapping}") + filtered_objs = model.filter_records( + records=list_obj, schema_mapping=element + ) + print( + f"{self.type}: Loading {len(filtered_objs)}/{total} {element.mapping}" + ) # Transform records - transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) + transformed_objs = model.transform_records( + records=filtered_objs, schema_mapping=element + ) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = list_obj for obj in transformed_objs: - if data := self.slurpit_obj_to_diffsync(obj=obj, mapping=element, model=model): + if data := self.slurpit_obj_to_diffsync( + obj=obj, mapping=element, model=model + ): item = model(**data) try: # noqa: SIM105 self.add(item) @@ -220,10 +249,19 @@ def model_loader(self, model_name: str, model: SlurpitsyncModel) -> None: # Reuse mapping from another adapter def build_mapping(self, reference, obj): # Get object class and model name from the store - object_class, modelname = self.store._get_object_class_and_model(model=reference) + object_class, modelname = self.store._get_object_class_and_model( + model=reference + ) # Find the schema element matching the model name - schema_element = next((element for element in self.config.schema_mapping if element.name == modelname), None) + schema_element = next( + ( + element + for element in self.config.schema_mapping + if element.name == modelname + ), + None, + ) # Collect all relevant field mappings for identifiers new_identifiers = [] @@ -271,7 +309,9 @@ def slurpit_obj_to_diffsync( if node := obj.get(field.mapping): matching_nodes = [] node_id = self.build_mapping(reference=field.reference, obj=obj) - matching_nodes = [item for item in nodes if str(item) == node_id] + matching_nodes = [ + item for item in nodes if str(item) == node_id + ] if len(matching_nodes) == 0: self.skipped.append(node) return None @@ -283,7 +323,9 @@ def slurpit_obj_to_diffsync( data[field.name] = [] if node := obj.get(field.mapping): node_id = self.build_mapping(reference=field.reference, obj=obj) - matching_nodes = [item for item in nodes if str(item) == node_id] + matching_nodes = [ + item for item in nodes if str(item) == node_id + ] if len(matching_nodes) == 0: self.skipped.append(node) continue diff --git a/infrahub_sync/adapters/utils.py b/infrahub_sync/adapters/utils.py index e890be1..b09b671 100644 --- a/infrahub_sync/adapters/utils.py +++ b/infrahub_sync/adapters/utils.py @@ -14,7 +14,9 @@ def get_value(obj: Any, name: str) -> Any | None: first_name, remaining_part = name.split(".", maxsplit=1) # Check if the object is a dictionary and use appropriate method to access the attribute. - sub_obj = obj.get(first_name) if isinstance(obj, dict) else getattr(obj, first_name, None) + sub_obj = ( + obj.get(first_name) if isinstance(obj, dict) else getattr(obj, first_name, None) + ) if not sub_obj: return None diff --git a/infrahub_sync/cli.py b/infrahub_sync/cli.py index 55ca115..06cb0cd 100644 --- a/infrahub_sync/cli.py +++ b/infrahub_sync/cli.py @@ -32,31 +32,45 @@ def print_error_and_abort(message: str) -> typer.Abort: @app.command(name="list") def list_projects( - directory: str = typer.Option(default=None, help="Base directory to search for sync configurations"), + directory: str = typer.Option( + default=None, help="Base directory to search for sync configurations" + ), ) -> None: """List all available SYNC projects.""" for item in get_all_sync(directory=directory): - console.print(f"{item.name} | {item.source.name} >> {item.destination.name} | {item.directory}") + console.print( + f"{item.name} | {item.source.name} >> {item.destination.name} | {item.directory}" + ) @app.command(name="diff") def diff_cmd( name: str = typer.Option(default=None, help="Name of the sync to use"), - config_file: str = typer.Option(default=None, help="File path to the sync configuration YAML file"), - directory: str = typer.Option(default=None, help="Base directory to search for sync configurations"), + config_file: str = typer.Option( + default=None, help="File path to the sync configuration YAML file" + ), + directory: str = typer.Option( + default=None, help="Base directory to search for sync configurations" + ), branch: str = typer.Option(default=None, help="Branch to use for the diff."), - show_progress: bool = typer.Option(default=True, help="Show a progress bar during diff"), + show_progress: bool = typer.Option( + default=True, help="Show a progress bar during diff" + ), ) -> None: """Calculate and print the differences between the source and the destination systems for a given project.""" if sum([bool(name), bool(config_file)]) != 1: print_error_and_abort("Please specify exactly one of 'name' or 'config-file'.") - sync_instance = get_instance(name=name, config_file=config_file, directory=directory) + sync_instance = get_instance( + name=name, config_file=config_file, directory=directory + ) if not sync_instance: print_error_and_abort("Failed to load sync instance.") try: - ptd = get_potenda_from_instance(sync_instance=sync_instance, branch=branch, show_progress=show_progress) + ptd = get_potenda_from_instance( + sync_instance=sync_instance, branch=branch, show_progress=show_progress + ) except ValueError as exc: print_error_and_abort(f"Failed to initialize the Sync Instance: {exc}") try: @@ -73,24 +87,35 @@ def diff_cmd( @app.command(name="sync") def sync_cmd( name: str = typer.Option(default=None, help="Name of the sync to use"), - config_file: str = typer.Option(default=None, help="File path to the sync configuration YAML file"), - directory: str = typer.Option(default=None, help="Base directory to search for sync configurations"), + config_file: str = typer.Option( + default=None, help="File path to the sync configuration YAML file" + ), + directory: str = typer.Option( + default=None, help="Base directory to search for sync configurations" + ), branch: str = typer.Option(default=None, help="Branch to use for the sync."), diff: bool = typer.Option( - default=True, help="Print the differences between the source and the destination before syncing" + default=True, + help="Print the differences between the source and the destination before syncing", + ), + show_progress: bool = typer.Option( + default=True, help="Show a progress bar during syncing" ), - show_progress: bool = typer.Option(default=True, help="Show a progress bar during syncing"), ) -> None: """Synchronize the data between source and the destination systems for a given project or configuration file.""" if sum([bool(name), bool(config_file)]) != 1: print_error_and_abort("Please specify exactly one of 'name' or 'config-file'.") - sync_instance = get_instance(name=name, config_file=config_file, directory=directory) + sync_instance = get_instance( + name=name, config_file=config_file, directory=directory + ) if not sync_instance: print_error_and_abort("Failed to load sync instance.") try: - ptd = get_potenda_from_instance(sync_instance=sync_instance, branch=branch, show_progress=show_progress) + ptd = get_potenda_from_instance( + sync_instance=sync_instance, branch=branch, show_progress=show_progress + ) except ValueError as exc: print_error_and_abort(f"Failed to initialize the Sync Instance: {exc}") try: @@ -115,8 +140,12 @@ def sync_cmd( @app.command(name="generate") def generate( name: str = typer.Option(default=None, help="Name of the sync to use"), - config_file: str = typer.Option(default=None, help="File path to the sync configuration YAML file"), - directory: str = typer.Option(default=None, help="Base directory to search for sync configurations"), + config_file: str = typer.Option( + default=None, help="File path to the sync configuration YAML file" + ), + directory: str = typer.Option( + default=None, help="Base directory to search for sync configurations" + ), branch: str = typer.Option(default=None, help="Branch to use for the sync."), ) -> None: """Generate all the python files for a given sync based on the configuration.""" @@ -124,21 +153,32 @@ def generate( if sum([bool(name), bool(config_file)]) != 1: print_error_and_abort("Please specify exactly one of 'name' or 'config_file'.") - sync_instance: SyncInstance = get_instance(name=name, config_file=config_file, directory=directory) + sync_instance: SyncInstance = get_instance( + name=name, config_file=config_file, directory=directory + ) if not sync_instance: - print_error_and_abort(f"Unable to find the sync {name}. Use the list command to see the sync available") + print_error_and_abort( + f"Unable to find the sync {name}. Use the list command to see the sync available" + ) # Check if the destination is infrahub infrahub_address = "" # Determine if infrahub is in source or destination # We are using the destination as the "constraint", if there is 2 infrahubs instance sdk_config = None - if sync_instance.destination.name == "infrahub" and sync_instance.destination.settings: + if ( + sync_instance.destination.name == "infrahub" + and sync_instance.destination.settings + ): infrahub_address = sync_instance.destination.settings.get("url") or "" - sdk_config = get_infrahub_config(settings=sync_instance.destination.settings, branch=branch) + sdk_config = get_infrahub_config( + settings=sync_instance.destination.settings, branch=branch + ) elif sync_instance.source.name == "infrahub" and sync_instance.source.settings: infrahub_address = sync_instance.source.settings.get("url") or "" - sdk_config = get_infrahub_config(settings=sync_instance.source.settings, branch=branch) + sdk_config = get_infrahub_config( + settings=sync_instance.source.settings, branch=branch + ) # Initialize InfrahubClientSync if address and config are available client = InfrahubClientSync(address=infrahub_address, config=sdk_config) @@ -148,9 +188,13 @@ def generate( except ServerNotResponsiveError as exc: print_error_and_abort(str(exc)) - missing_schema_models = find_missing_schema_model(sync_instance=sync_instance, schema=schema) + missing_schema_models = find_missing_schema_model( + sync_instance=sync_instance, schema=schema + ) if missing_schema_models: - print_error_and_abort(f"One or more model model are not present in the Schema - {missing_schema_models}") + print_error_and_abort( + f"One or more model model are not present in the Schema - {missing_schema_models}" + ) rendered_files = render_adapter(sync_instance=sync_instance, schema=schema) for template, output_path in rendered_files: diff --git a/infrahub_sync/generator/__init__.py b/infrahub_sync/generator/__init__.py index e234d63..304616a 100644 --- a/infrahub_sync/generator/__init__.py +++ b/infrahub_sync/generator/__init__.py @@ -62,14 +62,18 @@ def get_identifiers(node: NodeSchema, config: SyncConfig) -> list[str] | None: """Return the identifiers that should be used by DiffSync.""" config_identifiers = [ - item.identifiers for item in config.schema_mapping if item.name == node.kind and item.identifiers + item.identifiers + for item in config.schema_mapping + if item.name == node.kind and item.identifiers ] if config_identifiers: return config_identifiers[0] identifiers = [ - attr.name for attr in node.attributes if attr.unique and has_field(config, name=node.kind, field=attr.name) + attr.name + for attr in node.attributes + if attr.unique and has_field(config, name=node.kind, field=attr.name) ] if not identifiers: @@ -80,18 +84,25 @@ def get_identifiers(node: NodeSchema, config: SyncConfig) -> list[str] | None: def get_attributes(node: NodeSchema, config: SyncConfig) -> list[str] | None: """Return the attributes that should be used by DiffSync.""" - attrs_attributes = [attr.name for attr in node.attributes if has_field(config, name=node.kind, field=attr.name)] + attrs_attributes = [ + attr.name + for attr in node.attributes + if has_field(config, name=node.kind, field=attr.name) + ] rels_identifiers = [ rel.name for rel in node.relationships - if rel.kind != RelationshipKind.COMPONENT and has_field(config, name=node.kind, field=rel.name) + if rel.kind != RelationshipKind.COMPONENT + and has_field(config, name=node.kind, field=rel.name) ] identifiers = get_identifiers(node=node, config=config) if not identifiers: return None - attributes = [item for item in rels_identifiers + attrs_attributes if item not in identifiers] + attributes = [ + item for item in rels_identifiers + attrs_attributes if item not in identifiers + ] if not attributes: return None @@ -150,7 +161,9 @@ def has_children(node: NodeSchema, config: SyncConfig) -> bool: return bool(get_children(config=config, node=node)) -def render_template(template_file: Path, output_dir: Path, output_file: Path, context: dict[str, Any]) -> None: +def render_template( + template_file: Path, output_dir: Path, output_file: Path, context: dict[str, Any] +) -> None: template_loader = jinja2.PackageLoader("infrahub_sync", "generator/templates") template_env = jinja2.Environment( loader=template_loader, diff --git a/infrahub_sync/potenda/__init__.py b/infrahub_sync/potenda/__init__.py index 8d7452c..a6a3903 100644 --- a/infrahub_sync/potenda/__init__.py +++ b/infrahub_sync/potenda/__init__.py @@ -79,9 +79,15 @@ def load(self): def diff(self) -> Diff: print(f"Diff: Comparing data from {self.source} to {self.destination}") self.progress_bar = None - return self.destination.diff_from(self.source, flags=self.flags, callback=self._print_callback) + return self.destination.diff_from( + self.source, flags=self.flags, callback=self._print_callback + ) def sync(self, diff: Diff | None = None): - print(f"Sync: Importing data from {self.source} to {self.destination} based on Diff") + print( + f"Sync: Importing data from {self.source} to {self.destination} based on Diff" + ) self.progress_bar = None - return self.destination.sync_from(self.source, diff=diff, flags=self.flags, callback=self._print_callback) + return self.destination.sync_from( + self.source, diff=diff, flags=self.flags, callback=self._print_callback + ) diff --git a/infrahub_sync/utils.py b/infrahub_sync/utils.py index 216027b..30ef912 100644 --- a/infrahub_sync/utils.py +++ b/infrahub_sync/utils.py @@ -21,7 +21,8 @@ def find_missing_schema_model( - sync_instance: SyncInstance, schema: MutableMapping[str, Union[NodeSchema, GenericSchema]] + sync_instance: SyncInstance, + schema: MutableMapping[str, Union[NodeSchema, GenericSchema]], ) -> list[str]: missing_schema_models = [] for item in sync_instance.schema_mapping: @@ -34,7 +35,8 @@ def find_missing_schema_model( def render_adapter( - sync_instance: SyncInstance, schema: MutableMapping[str, Union[NodeSchema, GenericSchema]] + sync_instance: SyncInstance, + schema: MutableMapping[str, Union[NodeSchema, GenericSchema]], ) -> list[tuple[str, str]]: files_to_render = ( ("diffsync_models.j2", "sync_models.py"), @@ -70,7 +72,9 @@ def import_adapter(sync_instance: SyncInstance, adapter: SyncAdapter): try: adapter_name = f"{adapter.name.title()}Sync" - spec = importlib.util.spec_from_file_location(f"{adapter.name}.adapter", str(adapter_file_path)) + spec = importlib.util.spec_from_file_location( + f"{adapter.name}.adapter", str(adapter_file_path) + ) adapter_module = importlib.util.module_from_spec(spec) sys.modules[f"{adapter.name}.adapter"] = adapter_module spec.loader.exec_module(adapter_module) @@ -103,7 +107,9 @@ def get_all_sync(directory: str | None = None) -> list[SyncInstance]: def get_instance( - name: str | None = None, config_file: str | None = "config.yml", directory: str | None = None + name: str | None = None, + config_file: str | None = "config.yml", + directory: str | None = None, ) -> SyncInstance | None: if name: all_sync_instances = get_all_sync(directory=directory) @@ -133,25 +139,35 @@ def get_instance( def get_potenda_from_instance( - sync_instance: SyncInstance, branch: str | None = None, show_progress: bool | None = True + sync_instance: SyncInstance, + branch: str | None = None, + show_progress: bool | None = True, ) -> Potenda: source = import_adapter(sync_instance=sync_instance, adapter=sync_instance.source) - destination = import_adapter(sync_instance=sync_instance, adapter=sync_instance.destination) + destination = import_adapter( + sync_instance=sync_instance, adapter=sync_instance.destination + ) source_store = LocalStore() destination_store = LocalStore() if sync_instance.store and sync_instance.store.type == "redis": - if sync_instance.store.settings and isinstance(sync_instance.store.settings, dict): + if sync_instance.store.settings and isinstance( + sync_instance.store.settings, dict + ): redis_settings = sync_instance.store.settings source_store = RedisStore(**redis_settings, name=sync_instance.source.name) - destination_store = RedisStore(**redis_settings, name=sync_instance.destination.name) + destination_store = RedisStore( + **redis_settings, name=sync_instance.destination.name + ) else: source_store = RedisStore(name=sync_instance.source.name) destination_store = RedisStore(name=sync_instance.destination.name) try: if sync_instance.source.name == "infrahub": - settings_branch = sync_instance.source.settings.get("branch") or branch or "main" + settings_branch = ( + sync_instance.source.settings.get("branch") or branch or "main" + ) src: SyncInstance = source( config=sync_instance, target="source", @@ -171,7 +187,9 @@ def get_potenda_from_instance( raise ValueError(msg) from exc try: if sync_instance.destination.name == "infrahub": - settings_branch = sync_instance.source.settings.get("branch") or branch or "main" + settings_branch = ( + sync_instance.source.settings.get("branch") or branch or "main" + ) dst: SyncInstance = destination( config=sync_instance, target="destination", diff --git a/tasks/docs.py b/tasks/docs.py index 6e7f89e..447017a 100644 --- a/tasks/docs.py +++ b/tasks/docs.py @@ -28,7 +28,9 @@ def _generate_infrahubsync_documentation(context: Context) -> None: @task def markdownlint(context: Context) -> None: - has_markdownlint = check_if_command_available(context=context, command_name="markdownlint-cli2") + has_markdownlint = check_if_command_available( + context=context, command_name="markdownlint-cli2" + ) if not has_markdownlint: print("Warning, markdownlint-cli2 is not installed") From 50f3aea714d5c5f016a11f6463e14e031e49e4ca Mon Sep 17 00:00:00 2001 From: Tomasz Zajac Date: Wed, 12 Feb 2025 16:40:51 +0100 Subject: [PATCH 3/4] fixed ruff --- infrahub_sync/adapters/nautobot.py | 1 + infrahub_sync/adapters/slurpitsync.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/infrahub_sync/adapters/nautobot.py b/infrahub_sync/adapters/nautobot.py index 01ff1c6..533e1b5 100644 --- a/infrahub_sync/adapters/nautobot.py +++ b/infrahub_sync/adapters/nautobot.py @@ -2,6 +2,7 @@ # pylint: disable=R0801 import os +from typing import TYPE_CHECKING, Any try: from typing import Self diff --git a/infrahub_sync/adapters/slurpitsync.py b/infrahub_sync/adapters/slurpitsync.py index ce8b041..e32a404 100644 --- a/infrahub_sync/adapters/slurpitsync.py +++ b/infrahub_sync/adapters/slurpitsync.py @@ -5,9 +5,9 @@ from typing import TYPE_CHECKING, Any try: - from typing import Self + from typing import Any, Self except ImportError: - from typing_extensions import Self + from typing_extensions import Any, Self import slurpit from diffsync import Adapter, DiffSyncModel From 01d920c8e571c64c5905b53fbf82ce8375e6d432 Mon Sep 17 00:00:00 2001 From: Tomasz Zajac Date: Wed, 12 Feb 2025 16:45:02 +0100 Subject: [PATCH 4/4] fixed ruff format --- infrahub_sync/__init__.py | 36 +++------- infrahub_sync/adapters/infrahub.py | 82 ++++++---------------- infrahub_sync/adapters/ipfabricsync.py | 38 +++------- infrahub_sync/adapters/librenms.py | 38 +++------- infrahub_sync/adapters/nautobot.py | 46 +++--------- infrahub_sync/adapters/netbox.py | 46 +++--------- infrahub_sync/adapters/observium.py | 38 +++------- infrahub_sync/adapters/peeringmanager.py | 72 +++++-------------- infrahub_sync/adapters/rest_api_client.py | 4 +- infrahub_sync/adapters/slurpitsync.py | 68 +++++------------- infrahub_sync/adapters/utils.py | 4 +- infrahub_sync/cli.py | 85 ++++++----------------- infrahub_sync/generator/__init__.py | 25 ++----- infrahub_sync/potenda/__init__.py | 12 +--- infrahub_sync/utils.py | 24 ++----- tasks/docs.py | 4 +- 16 files changed, 150 insertions(+), 472 deletions(-) diff --git a/infrahub_sync/__init__.py b/infrahub_sync/__init__.py index 8ca6f9b..05b9876 100644 --- a/infrahub_sync/__init__.py +++ b/infrahub_sync/__init__.py @@ -79,12 +79,8 @@ def convert_to_int(value: Any) -> int: "!=": operator.ne, ">": lambda field, value: operator.gt(convert_to_int(field), convert_to_int(value)), "<": lambda field, value: operator.lt(convert_to_int(field), convert_to_int(value)), - ">=": lambda field, value: operator.ge( - convert_to_int(field), convert_to_int(value) - ), - "<=": lambda field, value: operator.le( - convert_to_int(field), convert_to_int(value) - ), + ">=": lambda field, value: operator.ge(convert_to_int(field), convert_to_int(value)), + "<=": lambda field, value: operator.le(convert_to_int(field), convert_to_int(value)), "in": lambda field, value: value and field in value, "not in": lambda field, value: field not in value, "contains": lambda field, value: field and value in field, @@ -93,9 +89,7 @@ def convert_to_int(value: Any) -> int: "is_not_empty": lambda field: field is not None and field, "regex": lambda field, pattern: re.match(pattern, field) is not None, # Netutils - "is_ip_within": lambda field, value: is_ip_within_filter( - ip=field, ip_compare=value - ), + "is_ip_within": lambda field, value: is_ip_within_filter(ip=field, ip_compare=value), } @@ -131,9 +125,7 @@ def apply_filter(cls, field_value: Any, operation: str, value: Any) -> bool: return operation_func(field_value, value) @classmethod - def apply_filters( - cls, item: dict[str, Any], filters: list[SchemaMappingFilter] - ) -> bool: + def apply_filters(cls, item: dict[str, Any], filters: list[SchemaMappingFilter]) -> bool: """Apply filters to an item and return True if it passes all filters.""" for filter_obj in filters: # Use dot notation to access attributes @@ -147,9 +139,7 @@ def apply_filters( return True @classmethod - def apply_transform( - cls, item: dict[str, Any], transform_expr: str, field: str - ) -> None: + def apply_transform(cls, item: dict[str, Any], transform_expr: str, field: str) -> None: """Apply a transformation expression using Jinja2 to a specified field in the item.""" try: # Create a Jinja2 template from the transformation expression @@ -166,9 +156,7 @@ def apply_transform( raise ValueError(msg) from exc @classmethod - def apply_transforms( - cls, item: dict[str, Any], transforms: list[SchemaMappingTransform] - ) -> dict[str, Any]: + def apply_transforms(cls, item: dict[str, Any], transforms: list[SchemaMappingTransform]) -> dict[str, Any]: """Apply a list of structured transformations to an item.""" for transform_obj in transforms: field = transform_obj.field @@ -177,9 +165,7 @@ def apply_transforms( return item @classmethod - def filter_records( - cls, records: list[dict], schema_mapping: SchemaMappingModel - ) -> list[dict]: + def filter_records(cls, records: list[dict], schema_mapping: SchemaMappingModel) -> list[dict]: """ Apply filters to the records based on the schema mapping configuration. """ @@ -193,9 +179,7 @@ def filter_records( return filtered_records @classmethod - def transform_records( - cls, records: list[dict], schema_mapping: SchemaMappingModel - ) -> list[dict]: + def transform_records(cls, records: list[dict], schema_mapping: SchemaMappingModel) -> list[dict]: """ Apply transformations to the records based on the schema mapping configuration. """ @@ -204,9 +188,7 @@ def transform_records( return records transformed_records = [] for record in records: - transformed_record = cls.apply_transforms( - item=record, transforms=transforms - ) + transformed_record = cls.apply_transforms(item=record, transforms=transforms) transformed_records.append(transformed_record) return transformed_records diff --git a/infrahub_sync/adapters/infrahub.py b/infrahub_sync/adapters/infrahub.py index a43d8f0..735df4c 100644 --- a/infrahub_sync/adapters/infrahub.py +++ b/infrahub_sync/adapters/infrahub.py @@ -51,13 +51,9 @@ def update_node(node: InfrahubNodeSync, attrs: dict) -> InfrahubNodeSync: raise_when_missing=False, ) else: - peer = node._client.store.get( - key=attr_value, raise_when_missing=False - ) + peer = node._client.store.get(key=attr_value, raise_when_missing=False) if not peer: - print( - f"Unable to find {rel_schema.peer} [{attr_value}] in the Store - Ignored" - ) + print(f"Unable to find {rel_schema.peer} [{attr_value}] in the Store - Ignored") continue setattr(node, attr_name, peer) else: @@ -68,12 +64,9 @@ def update_node(node: InfrahubNodeSync, attrs: dict) -> InfrahubNodeSync: attr = getattr(node, attr_name) existing_peer_ids = attr.peer_ids new_peer_ids = [ - node._client.store.get(key=value, kind=rel_schema.peer).id - for value in list(attr_value) + node._client.store.get(key=value, kind=rel_schema.peer).id for value in list(attr_value) ] - _, existing_only, new_only = compare_lists( - existing_peer_ids, new_peer_ids - ) + _, existing_only, new_only = compare_lists(existing_peer_ids, new_peer_ids) for existing_id in existing_only: attr.remove(existing_id) @@ -110,9 +103,7 @@ def diffsync_to_infrahub( else: peer = store.get(key=data[key], raise_when_missing=False) if not peer: - print( - f"Unable to find {rel_schema.peer} [{data[key]}] in the Store - Ignored" - ) + print(f"Unable to find {rel_schema.peer} [{data[key]}] in the Store - Ignored") continue data[key] = peer.id @@ -120,10 +111,7 @@ def diffsync_to_infrahub( if data[key] is None: del data[key] continue - new_values = [ - store.get(key=value, kind=rel_schema.peer).id - for value in list(data[key]) - ] + new_values = [store.get(key=value, kind=rel_schema.peer).id for value in list(data[key])] data[key] = new_values return data @@ -146,11 +134,7 @@ def __init__( self.config = config settings = adapter.settings or {} - infrahub_url = ( - os.environ.get("INFRAHUB_ADDRESS") - or os.environ.get("INFRAHUB_URL") - or settings.get("url") - ) + infrahub_url = os.environ.get("INFRAHUB_ADDRESS") or os.environ.get("INFRAHUB_URL") or settings.get("url") infrahub_token = os.environ.get("INFRAHUB_API_TOKEN") or settings.get("token") infrahub_branch = settings.get("branch") or branch @@ -159,9 +143,7 @@ def __init__( raise ValueError(msg) if infrahub_branch: - sdk_config = Config( - timeout=60, default_branch=infrahub_branch, api_token=infrahub_token - ) + sdk_config = Config(timeout=60, default_branch=infrahub_branch, api_token=infrahub_token) else: sdk_config = Config(timeout=60, api_token=infrahub_token) @@ -170,9 +152,7 @@ def __init__( # We need to identify with an account until we have some auth in place remote_account = config.source.name try: - self.account = self.client.get( - kind="CoreAccount", name__value=remote_account - ) + self.account = self.client.get(kind="CoreAccount", name__value=remote_account) except NodeNotFoundError: self.account = None @@ -183,17 +163,13 @@ def model_loader(self, model_name: str, model: InfrahubModel) -> None: This method retrieves data from Infrahub, applies filters and transformations as specified in the schema mapping, and loads the processed data into the adapter. """ - element = next( - (el for el in self.config.schema_mapping if el.name == model_name), None - ) + element = next((el for el in self.config.schema_mapping if el.name == model_name), None) if element: # Retrieve all nodes corresponding to model_name (list of InfrahubNodeSync) nodes = self.client.all(kind=model_name, populate_store=True) # Transform the list of InfrahubNodeSync into a list of (node, dict) tuples - node_dict_pairs = [ - (node, self.infrahub_node_to_diffsync(node=node)) for node in nodes - ] + node_dict_pairs = [(node, self.infrahub_node_to_diffsync(node=node)) for node in nodes] total = len(node_dict_pairs) # Extract the list of dicts for filtering and transforming @@ -201,23 +177,17 @@ def model_loader(self, model_name: str, model: InfrahubModel) -> None: if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records( - records=list_obj, schema_mapping=element - ) + filtered_objs = model.filter_records(records=list_obj, schema_mapping=element) print(f"{self.type}: Loading {len(filtered_objs)}/{total} {model_name}") # Transform records - transformed_objs = model.transform_records( - records=filtered_objs, schema_mapping=element - ) + transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) else: print(f"{self.type}: Loading all {total} {model_name}") transformed_objs = list_obj # Create model instances after filtering and transforming for transformed_obj in transformed_objs: - original_node = next( - node for node, obj in node_dict_pairs if obj == transformed_obj - ) + original_node = next(node for node, obj in node_dict_pairs if obj == transformed_obj) item = model(**transformed_obj) unique_id = item.get_unique_id() self.client.store.set(key=unique_id, node=original_node) @@ -239,30 +209,20 @@ def infrahub_node_to_diffsync(self, node: InfrahubNodeSync) -> dict: data[attr_name] = attr.value for rel_schema in node._schema.relationships: - if not has_field( - config=self.config, name=node._schema.kind, field=rel_schema.name - ): + if not has_field(config=self.config, name=node._schema.kind, field=rel_schema.name): continue if rel_schema.cardinality == "one": rel = getattr(node, rel_schema.name) if not rel.id: continue if rel_schema.kind != "Generic": - peer_node = self.client.store.get( - key=rel.id, kind=rel_schema.peer, raise_when_missing=False - ) + peer_node = self.client.store.get(key=rel.id, kind=rel_schema.peer, raise_when_missing=False) else: - peer_node = self.client.store.get( - key=rel.id, raise_when_missing=False - ) + peer_node = self.client.store.get(key=rel.id, raise_when_missing=False) if not peer_node: # I am not sure if we should end up here "normaly" - print( - f"Debug Unable to find {rel_schema.peer} [{rel.id}] in the Store - Pulling from Infrahub" - ) - peer_node = self.client.get( - id=rel.id, kind=rel_schema.peer, populate_store=True - ) + print(f"Debug Unable to find {rel_schema.peer} [{rel.id}] in the Store - Pulling from Infrahub") + peer_node = self.client.get(id=rel.id, kind=rel_schema.peer, populate_store=True) if not peer_node: print(f"Unable to find {rel_schema.peer} [{rel.id}]") continue @@ -300,9 +260,7 @@ def create( attrs: Mapping[Any, Any], ) -> Self | None: schema = adapter.client.schema.get(kind=cls.__name__) - data = diffsync_to_infrahub( - ids=ids, attrs=attrs, schema=schema, store=adapter.client.store - ) + data = diffsync_to_infrahub(ids=ids, attrs=attrs, schema=schema, store=adapter.client.store) unique_id = cls(**ids, **attrs).get_unique_id() source_id = None if adapter.account: diff --git a/infrahub_sync/adapters/ipfabricsync.py b/infrahub_sync/adapters/ipfabricsync.py index 2d18b22..01c1bb1 100644 --- a/infrahub_sync/adapters/ipfabricsync.py +++ b/infrahub_sync/adapters/ipfabricsync.py @@ -35,9 +35,7 @@ class IpfabricsyncAdapter(DiffSyncMixin, Adapter): type = "IPFabricsync" - def __init__( - self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs - ) -> None: + def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.target = target @@ -58,17 +56,11 @@ def _create_ipfabric_client(self, adapter: SyncAdapter) -> IPFClient: def build_mapping(self, reference, obj) -> str: # Get object class and model name from the store - object_class, modelname = self.store._get_object_class_and_model( - model=reference - ) + object_class, modelname = self.store._get_object_class_and_model(model=reference) # Find the schema element matching the model name schema_element = next( - ( - element - for element in self.config.schema_mapping - if element.name == modelname - ), + (element for element in self.config.schema_mapping if element.name == modelname), None, ) @@ -97,39 +89,27 @@ def model_loader(self, model_name: str, model: IpfabricsyncModel) -> None: if element.name != model_name: continue - table = self.client.fetch_all( - element.mapping, filters=ipf_filters.get(element.mapping) - ) + table = self.client.fetch_all(element.mapping, filters=ipf_filters.get(element.mapping)) print(f"{self.type}: Loading {len(table)} from `{element.mapping}`") total = len(table) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records( - records=table, schema_mapping=element - ) - print( - f"{self.type}: Loading {len(filtered_objs)}/{total} {element.mapping}" - ) + filtered_objs = model.filter_records(records=table, schema_mapping=element) + print(f"{self.type}: Loading {len(filtered_objs)}/{total} {element.mapping}") # Transform records - transformed_objs = model.transform_records( - records=filtered_objs, schema_mapping=element - ) + transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) else: print(f"{self.type}: Loading all {total} {element.mapping}") transformed_objs = table for obj in transformed_objs: - data = self.ipfabric_dict_to_diffsync( - obj=obj, mapping=element, model=model - ) + data = self.ipfabric_dict_to_diffsync(obj=obj, mapping=element, model=model) item = model(**data) self.update_or_add_model_instance(item) - def ipfabric_dict_to_diffsync( - self, obj: dict, mapping: SchemaMappingModel, model: IpfabricsyncModel - ) -> dict: # pylint: disable=too-many-branches + def ipfabric_dict_to_diffsync(self, obj: dict, mapping: SchemaMappingModel, model: IpfabricsyncModel) -> dict: # pylint: disable=too-many-branches data: dict[str, Any] = {"local_id": str(obj["id"])} for field in mapping.fields: # pylint: disable=too-many-nested-blocks diff --git a/infrahub_sync/adapters/librenms.py b/infrahub_sync/adapters/librenms.py index 4952234..81efed1 100644 --- a/infrahub_sync/adapters/librenms.py +++ b/infrahub_sync/adapters/librenms.py @@ -28,9 +28,7 @@ class LibrenmsAdapter(DiffSyncMixin, Adapter): type = "LibreNMS" - def __init__( - self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs - ) -> None: + def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.target = target @@ -39,11 +37,7 @@ def __init__( def _create_rest_client(self, adapter: SyncAdapter) -> RestApiClient: settings = adapter.settings or {} - url = ( - os.environ.get("LIBRENMS_ADDRESS") - or os.environ.get("LIBRENMS_URL") - or settings.get("url") - ) + url = os.environ.get("LIBRENMS_ADDRESS") or os.environ.get("LIBRENMS_URL") or settings.get("url") api_endpoint = settings.get("api_endpoint", "/api/v0") auth_method = settings.get("auth_method", "x-auth-token") api_token = os.environ.get("LIBRENMS_TOKEN") or settings.get("token") @@ -91,16 +85,10 @@ def model_loader(self, model_name: str, model: LibrenmsModel) -> None: total = len(objs) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records( - records=objs, schema_mapping=element - ) - print( - f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}" - ) + filtered_objs = model.filter_records(records=objs, schema_mapping=element) + print(f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}") # Transform records - transformed_objs = model.transform_records( - records=filtered_objs, schema_mapping=element - ) + transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = objs @@ -111,9 +99,7 @@ def model_loader(self, model_name: str, model: LibrenmsModel) -> None: item = model(**data) self.add(item) - def obj_to_diffsync( - self, obj: dict[str, Any], mapping: SchemaMappingModel, model: LibrenmsModel - ) -> dict: + def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, model: LibrenmsModel) -> dict: obj_id = derive_identifier_key(obj=obj) data: dict[str, Any] = {"local_id": str(obj_id)} @@ -144,9 +130,7 @@ def obj_to_diffsync( if isinstance(node, dict): matching_nodes = [] node_id = node.get("id", None) - matching_nodes = [ - item for item in nodes if item.local_id == str(node_id) - ] + matching_nodes = [item for item in nodes if item.local_id == str(node_id)] if len(matching_nodes) == 0: msg = f"Unable to locate the node {model} {node_id}" raise IndexError(msg) @@ -166,13 +150,9 @@ def obj_to_diffsync( node_id = node[1] if node[0] == "id" else None if not node_id: continue - matching_nodes = [ - item for item in nodes if item.local_id == str(node_id) - ] + matching_nodes = [item for item in nodes if item.local_id == str(node_id)] if len(matching_nodes) == 0: - msg = ( - f"Unable to locate the node {field.reference} {node_id}" - ) + msg = f"Unable to locate the node {field.reference} {node_id}" raise IndexError(msg) data[field.name].append(matching_nodes[0].get_unique_id()) data[field.name] = sorted(data[field.name]) diff --git a/infrahub_sync/adapters/nautobot.py b/infrahub_sync/adapters/nautobot.py index 533e1b5..3049c59 100644 --- a/infrahub_sync/adapters/nautobot.py +++ b/infrahub_sync/adapters/nautobot.py @@ -29,9 +29,7 @@ class NautobotAdapter(DiffSyncMixin, Adapter): type = "Nautobot" - def __init__( - self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs - ) -> None: + def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.target = target @@ -40,20 +38,14 @@ def __init__( def _create_nautobot_client(self, adapter: SyncAdapter) -> pynautobot.api: settings = adapter.settings or {} - url = ( - os.environ.get("NAUTOBOT_ADDRESS") - or os.environ.get("NAUTOBOT_URL") - or settings.get("url") - ) + url = os.environ.get("NAUTOBOT_ADDRESS") or os.environ.get("NAUTOBOT_URL") or settings.get("url") token = os.environ.get("NAUTOBOT_TOKEN") or settings.get("token") if not url or not token: msg = "Both url and token must be specified!" raise ValueError(msg) - return pynautobot.api( - url, token=token, threading=True, max_workers=5, retries=3 - ) + return pynautobot.api(url, token=token, threading=True, max_workers=5, retries=3) def model_loader(self, model_name: str, model: NautobotModel) -> None: """ @@ -83,31 +75,21 @@ def model_loader(self, model_name: str, model: NautobotModel) -> None: total = len(list_obj) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records( - records=list_obj, schema_mapping=element - ) - print( - f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}" - ) + filtered_objs = model.filter_records(records=list_obj, schema_mapping=element) + print(f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}") # Transform records - transformed_objs = model.transform_records( - records=filtered_objs, schema_mapping=element - ) + transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = list_obj # Create model instances after filtering and transforming for obj in transformed_objs: - data = self.nautobot_obj_to_diffsync( - obj=obj, mapping=element, model=model - ) + data = self.nautobot_obj_to_diffsync(obj=obj, mapping=element, model=model) item = model(**data) self.add(item) - def nautobot_obj_to_diffsync( - self, obj: dict[str, Any], mapping: SchemaMappingModel, model: NautobotModel - ) -> dict: + def nautobot_obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, model: NautobotModel) -> dict: obj_id = obj.get("id") data: dict[str, Any] = {"local_id": str(obj_id)} @@ -138,14 +120,10 @@ def nautobot_obj_to_diffsync( matching_nodes = [] node_id = node.get("id", None) if node_id: - matching_nodes = [ - item for item in nodes if item.local_id == str(node_id) - ] + matching_nodes = [item for item in nodes if item.local_id == str(node_id)] if len(matching_nodes) == 0: # TODO: If the peer is a Node we are filtering, we could end up not finding it - print( - f"Unable to locate the node {field.name} {node_id}" - ) + print(f"Unable to locate the node {field.name} {node_id}") continue node = matching_nodes[0] data[field.name] = node.get_unique_id() @@ -160,9 +138,7 @@ def nautobot_obj_to_diffsync( node_id = node[1] if node[0] == "id" else None if not node_id: continue - matching_nodes = [ - item for item in nodes if item.local_id == str(node_id) - ] + matching_nodes = [item for item in nodes if item.local_id == str(node_id)] if len(matching_nodes) == 0: # TODO: If the peer is a Node we are filtering, we could end up not finding it print(f"Unable to locate the node {field.name} {node_id}") diff --git a/infrahub_sync/adapters/netbox.py b/infrahub_sync/adapters/netbox.py index 3c5880d..12e4182 100644 --- a/infrahub_sync/adapters/netbox.py +++ b/infrahub_sync/adapters/netbox.py @@ -28,9 +28,7 @@ class NetboxAdapter(DiffSyncMixin, Adapter): type = "Netbox" - def __init__( - self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs - ) -> None: + def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.target = target @@ -39,11 +37,7 @@ def __init__( def _create_netbox_client(self, adapter: SyncAdapter) -> pynetbox.api: settings = adapter.settings or {} - url = ( - os.environ.get("NETBOX_ADDRESS") - or os.environ.get("NETBOX_URL") - or settings.get("url") - ) + url = os.environ.get("NETBOX_ADDRESS") or os.environ.get("NETBOX_URL") or settings.get("url") token = os.environ.get("NETBOX_TOKEN") or settings.get("token") if not url or not token: @@ -79,31 +73,21 @@ def model_loader(self, model_name: str, model: NetboxModel) -> None: total = len(list_obj) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records( - records=list_obj, schema_mapping=element - ) - print( - f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}" - ) + filtered_objs = model.filter_records(records=list_obj, schema_mapping=element) + print(f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}") # Transform records - transformed_objs = model.transform_records( - records=filtered_objs, schema_mapping=element - ) + transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = list_obj # Create model instances after filtering and transforming for obj in transformed_objs: - data = self.netbox_obj_to_diffsync( - obj=obj, mapping=element, model=model - ) + data = self.netbox_obj_to_diffsync(obj=obj, mapping=element, model=model) item = model(**data) self.add(item) - def netbox_obj_to_diffsync( - self, obj: dict[str, Any], mapping: SchemaMappingModel, model: NetboxModel - ) -> dict: + def netbox_obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, model: NetboxModel) -> dict: obj_id = obj.get("id") data: dict[str, Any] = {"local_id": str(obj_id)} @@ -133,13 +117,9 @@ def netbox_obj_to_diffsync( if isinstance(node, dict): matching_nodes = [] node_id = node.get("id", None) - matching_nodes = [ - item for item in nodes if item.local_id == str(node_id) - ] + matching_nodes = [item for item in nodes if item.local_id == str(node_id)] if len(matching_nodes) == 0: - msg = ( - f"Unable to locate the node {field.name} {node_id}" - ) + msg = f"Unable to locate the node {field.name} {node_id}" raise IndexError(msg) node = matching_nodes[0] data[field.name] = node.get_unique_id() @@ -155,13 +135,9 @@ def netbox_obj_to_diffsync( node_id = node[1] if node[0] == "id" else None if not node_id: continue - matching_nodes = [ - item for item in nodes if item.local_id == str(node_id) - ] + matching_nodes = [item for item in nodes if item.local_id == str(node_id)] if len(matching_nodes) == 0: - msg = ( - f"Unable to locate the node {field.reference} {node_id}" - ) + msg = f"Unable to locate the node {field.reference} {node_id}" raise IndexError(msg) data[field.name].append(matching_nodes[0].get_unique_id()) data[field.name] = sorted(data[field.name]) diff --git a/infrahub_sync/adapters/observium.py b/infrahub_sync/adapters/observium.py index 3361fe4..ec00e21 100644 --- a/infrahub_sync/adapters/observium.py +++ b/infrahub_sync/adapters/observium.py @@ -27,9 +27,7 @@ class ObserviumAdapter(DiffSyncMixin, Adapter): type = "Observium" - def __init__( - self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs - ) -> None: + def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.target = target @@ -38,11 +36,7 @@ def __init__( def _create_rest_client(self, adapter: SyncAdapter) -> RestApiClient: settings = adapter.settings or {} - url = ( - os.environ.get("OBSERVIUM_ADDRESS") - or os.environ.get("OBSERVIUM_URL") - or settings.get("url") - ) + url = os.environ.get("OBSERVIUM_ADDRESS") or os.environ.get("OBSERVIUM_URL") or settings.get("url") api_endpoint = settings.get("api_endpoint", "/api/v0") auth_method = settings.get("auth_method", "basic") api_token = os.environ.get("OBSERVIUM_TOKEN") or settings.get("token") @@ -92,16 +86,10 @@ def model_loader(self, model_name: str, model: ObserviumModel) -> None: total = len(objs) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records( - records=objs, schema_mapping=element - ) - print( - f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}" - ) + filtered_objs = model.filter_records(records=objs, schema_mapping=element) + print(f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}") # Transform records - transformed_objs = model.transform_records( - records=filtered_objs, schema_mapping=element - ) + transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = objs @@ -112,9 +100,7 @@ def model_loader(self, model_name: str, model: ObserviumModel) -> None: item = model(**data) self.add(item) - def obj_to_diffsync( - self, obj: dict[str, Any], mapping: SchemaMappingModel, model: ObserviumModel - ) -> dict: + def obj_to_diffsync(self, obj: dict[str, Any], mapping: SchemaMappingModel, model: ObserviumModel) -> dict: obj_id = derive_identifier_key(obj=obj) data: dict[str, Any] = {"local_id": str(obj_id)} @@ -145,9 +131,7 @@ def obj_to_diffsync( if isinstance(node, dict): matching_nodes = [] node_id = node.get("id", None) - matching_nodes = [ - item for item in nodes if item.local_id == str(node_id) - ] + matching_nodes = [item for item in nodes if item.local_id == str(node_id)] if len(matching_nodes) == 0: msg = f"Unable to locate the node {model} {node_id}" raise IndexError(msg) @@ -167,13 +151,9 @@ def obj_to_diffsync( node_id = node[1] if node[0] == "id" else None if not node_id: continue - matching_nodes = [ - item for item in nodes if item.local_id == str(node_id) - ] + matching_nodes = [item for item in nodes if item.local_id == str(node_id)] if len(matching_nodes) == 0: - msg = ( - f"Unable to locate the node {field.reference} {node_id}" - ) + msg = f"Unable to locate the node {field.reference} {node_id}" raise IndexError(msg) data[field.name].append(matching_nodes[0].get_unique_id()) data[field.name] = sorted(data[field.name]) diff --git a/infrahub_sync/adapters/peeringmanager.py b/infrahub_sync/adapters/peeringmanager.py index bc7bdf6..5050a00 100644 --- a/infrahub_sync/adapters/peeringmanager.py +++ b/infrahub_sync/adapters/peeringmanager.py @@ -28,9 +28,7 @@ class PeeringmanagerAdapter(DiffSyncMixin, Adapter): type = "Peeringmanager" - def __init__( - self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs - ) -> None: + def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.target = target @@ -39,14 +37,8 @@ def __init__( def _create_rest_client(self, adapter: SyncAdapter) -> RestApiClient: settings = adapter.settings or {} - url = ( - os.environ.get("PEERING_MANAGER_ADDRESS") - or os.environ.get("PEERING_MANAGER_URL") - or settings.get("url") - ) - api_endpoint = settings.get( - "api_endpoint", "/api" - ) # Default endpoint, change if necessary + url = os.environ.get("PEERING_MANAGER_ADDRESS") or os.environ.get("PEERING_MANAGER_URL") or settings.get("url") + api_endpoint = settings.get("api_endpoint", "/api") # Default endpoint, change if necessary auth_method = settings.get("auth_method", "token") api_token = os.environ.get("PEERING_MANAGER_TOKEN") or settings.get("token") timeout = settings.get("timeout", 30) @@ -93,16 +85,10 @@ def model_loader(self, model_name: str, model: PeeringmanagerModel) -> None: total = len(objs) if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records( - records=objs, schema_mapping=element - ) - print( - f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}" - ) + filtered_objs = model.filter_records(records=objs, schema_mapping=element) + print(f"{self.type}: Loading {len(filtered_objs)}/{total} {resource_name}") # Transform records - transformed_objs = model.transform_records( - records=filtered_objs, schema_mapping=element - ) + transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = objs @@ -148,13 +134,9 @@ def obj_to_diffsync( if isinstance(node, dict): matching_nodes = [] node_id = node.get("id", None) - matching_nodes = [ - item for item in nodes if item.local_id == str(node_id) - ] + matching_nodes = [item for item in nodes if item.local_id == str(node_id)] if len(matching_nodes) == 0: - msg = ( - f"Unable to locate the node {field.name} {node_id}" - ) + msg = f"Unable to locate the node {field.name} {node_id}" raise IndexError(msg) node = matching_nodes[0] data[field.name] = node.get_unique_id() @@ -170,13 +152,9 @@ def obj_to_diffsync( node_id = node[1] if node[0] == "id" else None if not node_id: continue - matching_nodes = [ - item for item in nodes if item.local_id == str(node_id) - ] + matching_nodes = [item for item in nodes if item.local_id == str(node_id)] if len(matching_nodes) == 0: - msg = ( - f"Unable to locate the node {field.reference} {node_id}" - ) + msg = f"Unable to locate the node {field.reference} {node_id}" raise IndexError(msg) data[field.name].append(matching_nodes[0].get_unique_id()) data[field.name] = sorted(data[field.name]) @@ -204,14 +182,10 @@ def update(self, attrs: dict) -> Self | None: to the API endpoint of the object. """ # Determine the resource name using the schema mapping - resource_name = self.__class__.get_resource_name( - schema_mapping=self.adapter.config.schema_mapping - ) + resource_name = self.__class__.get_resource_name(schema_mapping=self.adapter.config.schema_mapping) # Determine the unique identifier for the API request - unique_identifier = ( - self.local_id if hasattr(self, "local_id") else self.get_unique_id() - ) + unique_identifier = self.local_id if hasattr(self, "local_id") else self.get_unique_id() endpoint = f"{resource_name}/{unique_identifier}/" # Map incoming attributes to the target attributes based on schema mapping @@ -226,34 +200,22 @@ def update(self, attrs: dict) -> Self | None: # Check if the field is a relationship if field_mapping.reference: - all_nodes_for_reference = self.adapter.store.get_all( - model=field_mapping.reference - ) + all_nodes_for_reference = self.adapter.store.get_all(model=field_mapping.reference) if isinstance(value, list): # For lists, filter nodes to match the unique IDs in the attribute value filtered_nodes = [ - node - for node in all_nodes_for_reference - if node.get_unique_id() in value - ] - mapped_attrs[target_field_name] = [ - node.local_id for node in filtered_nodes + node for node in all_nodes_for_reference if node.get_unique_id() in value ] + mapped_attrs[target_field_name] = [node.local_id for node in filtered_nodes] else: # For single references, find the matching node filtered_node = next( - ( - node - for node in all_nodes_for_reference - if node.get_unique_id() == value - ), + (node for node in all_nodes_for_reference if node.get_unique_id() == value), None, ) if filtered_node: - mapped_attrs[target_field_name] = ( - filtered_node.local_id - ) + mapped_attrs[target_field_name] = filtered_node.local_id else: mapped_attrs[target_field_name] = value diff --git a/infrahub_sync/adapters/rest_api_client.py b/infrahub_sync/adapters/rest_api_client.py index 2957d7c..de315b0 100644 --- a/infrahub_sync/adapters/rest_api_client.py +++ b/infrahub_sync/adapters/rest_api_client.py @@ -80,9 +80,7 @@ def request( try: return response.json() except requests.exceptions.JSONDecodeError as exc: - print( - "Response content is not valid JSON:", response.text - ) # Print the response content + print("Response content is not valid JSON:", response.text) # Print the response content msg = "Response content is not valid JSON." raise ValueError(msg) from exc diff --git a/infrahub_sync/adapters/slurpitsync.py b/infrahub_sync/adapters/slurpitsync.py index e32a404..c4aed05 100644 --- a/infrahub_sync/adapters/slurpitsync.py +++ b/infrahub_sync/adapters/slurpitsync.py @@ -31,9 +31,7 @@ class SlurpitsyncAdapter(DiffSyncMixin, Adapter): type = "Slurpitsync" - def __init__( - self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs - ) -> None: + def __init__(self, target: str, adapter: SyncAdapter, config: SyncConfig, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.target = target self.client = self._create_slurpit_client(adapter=adapter) @@ -72,13 +70,8 @@ def unique_vendors(self) -> list[dict[str, Any]]: def unique_device_type(self) -> list[dict[str, Any]]: devices = self.run_async(self.client.device.get_devices()) - device_types = { - (device.brand, device.device_type, device.device_os) for device in devices - } - return [ - {"brand": item[0], "device_type": item[1], "device_os": item[2]} - for item in device_types - ] + device_types = {(device.brand, device.device_type, device.device_os) for device in devices} + return [{"brand": item[0], "device_type": item[1], "device_os": item[2]} for item in device_types] def filter_networks(self) -> list: """Filter out networks based on ignore prefixes and normalize network/mask fields.""" @@ -111,10 +104,7 @@ def should_ignore(network) -> bool: net = ipaddress.ip_network(network, strict=False) if net.prefixlen in {32, 128}: return True - return any( - net == ipaddress.ip_network(ignore, strict=False) - for ignore in ignore_prefixes - ) + return any(net == ipaddress.ip_network(ignore, strict=False) for ignore in ignore_prefixes) except ValueError: return False @@ -129,9 +119,7 @@ def should_ignore(network) -> bool: async def filter_interfaces(self, interfaces) -> list: precomputed_filtered_networks = [ { - "network": ipaddress.ip_network( - prefix["normalized_prefix"], strict=False - ), + "network": ipaddress.ip_network(prefix["normalized_prefix"], strict=False), "Vrf": prefix.get("Vrf", None), } for prefix in self.filtered_networks @@ -162,9 +150,7 @@ def normalize_and_find_prefix(entry): return entry # Concurrent execution of tasks - tasks = [ - normalize_and_find_prefix(entry) for entry in interfaces if entry.get("IP") - ] + tasks = [normalize_and_find_prefix(entry) for entry in interfaces if entry.get("IP")] # Run tasks concurrently filtered_interfaces = await asyncio.gather(*tasks) @@ -176,17 +162,13 @@ def normalize_and_find_prefix(entry): def planning_results(self, planning_name): plannings = self.run_async(self.client.planning.get_plannings()) - planning = next( - (plan.to_dict() for plan in plannings if plan.slug == planning_name), None - ) + planning = next((plan.to_dict() for plan in plannings if plan.slug == planning_name), None) if not planning: msg = f"No planning found for name: {planning_name}" raise IndexError(msg) search_data = {"planning_id": planning["id"], "unique_results": True} - results = self.run_async( - self.client.planning.search_plannings(search_data, limit=30000) - ) + results = self.run_async(self.client.planning.search_plannings(search_data, limit=30000)) return results or [] def model_loader(self, model_name: str, model: SlurpitsyncModel) -> None: @@ -219,24 +201,16 @@ def model_loader(self, model_name: str, model: SlurpitsyncModel) -> None: if self.config.source.name.title() == self.type.title(): # Filter records - filtered_objs = model.filter_records( - records=list_obj, schema_mapping=element - ) - print( - f"{self.type}: Loading {len(filtered_objs)}/{total} {element.mapping}" - ) + filtered_objs = model.filter_records(records=list_obj, schema_mapping=element) + print(f"{self.type}: Loading {len(filtered_objs)}/{total} {element.mapping}") # Transform records - transformed_objs = model.transform_records( - records=filtered_objs, schema_mapping=element - ) + transformed_objs = model.transform_records(records=filtered_objs, schema_mapping=element) else: print(f"{self.type}: Loading all {total} {resource_name}") transformed_objs = list_obj for obj in transformed_objs: - if data := self.slurpit_obj_to_diffsync( - obj=obj, mapping=element, model=model - ): + if data := self.slurpit_obj_to_diffsync(obj=obj, mapping=element, model=model): item = model(**data) try: # noqa: SIM105 self.add(item) @@ -249,17 +223,11 @@ def model_loader(self, model_name: str, model: SlurpitsyncModel) -> None: # Reuse mapping from another adapter def build_mapping(self, reference, obj): # Get object class and model name from the store - object_class, modelname = self.store._get_object_class_and_model( - model=reference - ) + object_class, modelname = self.store._get_object_class_and_model(model=reference) # Find the schema element matching the model name schema_element = next( - ( - element - for element in self.config.schema_mapping - if element.name == modelname - ), + (element for element in self.config.schema_mapping if element.name == modelname), None, ) @@ -309,9 +277,7 @@ def slurpit_obj_to_diffsync( if node := obj.get(field.mapping): matching_nodes = [] node_id = self.build_mapping(reference=field.reference, obj=obj) - matching_nodes = [ - item for item in nodes if str(item) == node_id - ] + matching_nodes = [item for item in nodes if str(item) == node_id] if len(matching_nodes) == 0: self.skipped.append(node) return None @@ -323,9 +289,7 @@ def slurpit_obj_to_diffsync( data[field.name] = [] if node := obj.get(field.mapping): node_id = self.build_mapping(reference=field.reference, obj=obj) - matching_nodes = [ - item for item in nodes if str(item) == node_id - ] + matching_nodes = [item for item in nodes if str(item) == node_id] if len(matching_nodes) == 0: self.skipped.append(node) continue diff --git a/infrahub_sync/adapters/utils.py b/infrahub_sync/adapters/utils.py index b09b671..e890be1 100644 --- a/infrahub_sync/adapters/utils.py +++ b/infrahub_sync/adapters/utils.py @@ -14,9 +14,7 @@ def get_value(obj: Any, name: str) -> Any | None: first_name, remaining_part = name.split(".", maxsplit=1) # Check if the object is a dictionary and use appropriate method to access the attribute. - sub_obj = ( - obj.get(first_name) if isinstance(obj, dict) else getattr(obj, first_name, None) - ) + sub_obj = obj.get(first_name) if isinstance(obj, dict) else getattr(obj, first_name, None) if not sub_obj: return None diff --git a/infrahub_sync/cli.py b/infrahub_sync/cli.py index 06cb0cd..8bd4549 100644 --- a/infrahub_sync/cli.py +++ b/infrahub_sync/cli.py @@ -32,45 +32,31 @@ def print_error_and_abort(message: str) -> typer.Abort: @app.command(name="list") def list_projects( - directory: str = typer.Option( - default=None, help="Base directory to search for sync configurations" - ), + directory: str = typer.Option(default=None, help="Base directory to search for sync configurations"), ) -> None: """List all available SYNC projects.""" for item in get_all_sync(directory=directory): - console.print( - f"{item.name} | {item.source.name} >> {item.destination.name} | {item.directory}" - ) + console.print(f"{item.name} | {item.source.name} >> {item.destination.name} | {item.directory}") @app.command(name="diff") def diff_cmd( name: str = typer.Option(default=None, help="Name of the sync to use"), - config_file: str = typer.Option( - default=None, help="File path to the sync configuration YAML file" - ), - directory: str = typer.Option( - default=None, help="Base directory to search for sync configurations" - ), + config_file: str = typer.Option(default=None, help="File path to the sync configuration YAML file"), + directory: str = typer.Option(default=None, help="Base directory to search for sync configurations"), branch: str = typer.Option(default=None, help="Branch to use for the diff."), - show_progress: bool = typer.Option( - default=True, help="Show a progress bar during diff" - ), + show_progress: bool = typer.Option(default=True, help="Show a progress bar during diff"), ) -> None: """Calculate and print the differences between the source and the destination systems for a given project.""" if sum([bool(name), bool(config_file)]) != 1: print_error_and_abort("Please specify exactly one of 'name' or 'config-file'.") - sync_instance = get_instance( - name=name, config_file=config_file, directory=directory - ) + sync_instance = get_instance(name=name, config_file=config_file, directory=directory) if not sync_instance: print_error_and_abort("Failed to load sync instance.") try: - ptd = get_potenda_from_instance( - sync_instance=sync_instance, branch=branch, show_progress=show_progress - ) + ptd = get_potenda_from_instance(sync_instance=sync_instance, branch=branch, show_progress=show_progress) except ValueError as exc: print_error_and_abort(f"Failed to initialize the Sync Instance: {exc}") try: @@ -87,35 +73,25 @@ def diff_cmd( @app.command(name="sync") def sync_cmd( name: str = typer.Option(default=None, help="Name of the sync to use"), - config_file: str = typer.Option( - default=None, help="File path to the sync configuration YAML file" - ), - directory: str = typer.Option( - default=None, help="Base directory to search for sync configurations" - ), + config_file: str = typer.Option(default=None, help="File path to the sync configuration YAML file"), + directory: str = typer.Option(default=None, help="Base directory to search for sync configurations"), branch: str = typer.Option(default=None, help="Branch to use for the sync."), diff: bool = typer.Option( default=True, help="Print the differences between the source and the destination before syncing", ), - show_progress: bool = typer.Option( - default=True, help="Show a progress bar during syncing" - ), + show_progress: bool = typer.Option(default=True, help="Show a progress bar during syncing"), ) -> None: """Synchronize the data between source and the destination systems for a given project or configuration file.""" if sum([bool(name), bool(config_file)]) != 1: print_error_and_abort("Please specify exactly one of 'name' or 'config-file'.") - sync_instance = get_instance( - name=name, config_file=config_file, directory=directory - ) + sync_instance = get_instance(name=name, config_file=config_file, directory=directory) if not sync_instance: print_error_and_abort("Failed to load sync instance.") try: - ptd = get_potenda_from_instance( - sync_instance=sync_instance, branch=branch, show_progress=show_progress - ) + ptd = get_potenda_from_instance(sync_instance=sync_instance, branch=branch, show_progress=show_progress) except ValueError as exc: print_error_and_abort(f"Failed to initialize the Sync Instance: {exc}") try: @@ -140,12 +116,8 @@ def sync_cmd( @app.command(name="generate") def generate( name: str = typer.Option(default=None, help="Name of the sync to use"), - config_file: str = typer.Option( - default=None, help="File path to the sync configuration YAML file" - ), - directory: str = typer.Option( - default=None, help="Base directory to search for sync configurations" - ), + config_file: str = typer.Option(default=None, help="File path to the sync configuration YAML file"), + directory: str = typer.Option(default=None, help="Base directory to search for sync configurations"), branch: str = typer.Option(default=None, help="Branch to use for the sync."), ) -> None: """Generate all the python files for a given sync based on the configuration.""" @@ -153,32 +125,21 @@ def generate( if sum([bool(name), bool(config_file)]) != 1: print_error_and_abort("Please specify exactly one of 'name' or 'config_file'.") - sync_instance: SyncInstance = get_instance( - name=name, config_file=config_file, directory=directory - ) + sync_instance: SyncInstance = get_instance(name=name, config_file=config_file, directory=directory) if not sync_instance: - print_error_and_abort( - f"Unable to find the sync {name}. Use the list command to see the sync available" - ) + print_error_and_abort(f"Unable to find the sync {name}. Use the list command to see the sync available") # Check if the destination is infrahub infrahub_address = "" # Determine if infrahub is in source or destination # We are using the destination as the "constraint", if there is 2 infrahubs instance sdk_config = None - if ( - sync_instance.destination.name == "infrahub" - and sync_instance.destination.settings - ): + if sync_instance.destination.name == "infrahub" and sync_instance.destination.settings: infrahub_address = sync_instance.destination.settings.get("url") or "" - sdk_config = get_infrahub_config( - settings=sync_instance.destination.settings, branch=branch - ) + sdk_config = get_infrahub_config(settings=sync_instance.destination.settings, branch=branch) elif sync_instance.source.name == "infrahub" and sync_instance.source.settings: infrahub_address = sync_instance.source.settings.get("url") or "" - sdk_config = get_infrahub_config( - settings=sync_instance.source.settings, branch=branch - ) + sdk_config = get_infrahub_config(settings=sync_instance.source.settings, branch=branch) # Initialize InfrahubClientSync if address and config are available client = InfrahubClientSync(address=infrahub_address, config=sdk_config) @@ -188,13 +149,9 @@ def generate( except ServerNotResponsiveError as exc: print_error_and_abort(str(exc)) - missing_schema_models = find_missing_schema_model( - sync_instance=sync_instance, schema=schema - ) + missing_schema_models = find_missing_schema_model(sync_instance=sync_instance, schema=schema) if missing_schema_models: - print_error_and_abort( - f"One or more model model are not present in the Schema - {missing_schema_models}" - ) + print_error_and_abort(f"One or more model model are not present in the Schema - {missing_schema_models}") rendered_files = render_adapter(sync_instance=sync_instance, schema=schema) for template, output_path in rendered_files: diff --git a/infrahub_sync/generator/__init__.py b/infrahub_sync/generator/__init__.py index 304616a..e234d63 100644 --- a/infrahub_sync/generator/__init__.py +++ b/infrahub_sync/generator/__init__.py @@ -62,18 +62,14 @@ def get_identifiers(node: NodeSchema, config: SyncConfig) -> list[str] | None: """Return the identifiers that should be used by DiffSync.""" config_identifiers = [ - item.identifiers - for item in config.schema_mapping - if item.name == node.kind and item.identifiers + item.identifiers for item in config.schema_mapping if item.name == node.kind and item.identifiers ] if config_identifiers: return config_identifiers[0] identifiers = [ - attr.name - for attr in node.attributes - if attr.unique and has_field(config, name=node.kind, field=attr.name) + attr.name for attr in node.attributes if attr.unique and has_field(config, name=node.kind, field=attr.name) ] if not identifiers: @@ -84,25 +80,18 @@ def get_identifiers(node: NodeSchema, config: SyncConfig) -> list[str] | None: def get_attributes(node: NodeSchema, config: SyncConfig) -> list[str] | None: """Return the attributes that should be used by DiffSync.""" - attrs_attributes = [ - attr.name - for attr in node.attributes - if has_field(config, name=node.kind, field=attr.name) - ] + attrs_attributes = [attr.name for attr in node.attributes if has_field(config, name=node.kind, field=attr.name)] rels_identifiers = [ rel.name for rel in node.relationships - if rel.kind != RelationshipKind.COMPONENT - and has_field(config, name=node.kind, field=rel.name) + if rel.kind != RelationshipKind.COMPONENT and has_field(config, name=node.kind, field=rel.name) ] identifiers = get_identifiers(node=node, config=config) if not identifiers: return None - attributes = [ - item for item in rels_identifiers + attrs_attributes if item not in identifiers - ] + attributes = [item for item in rels_identifiers + attrs_attributes if item not in identifiers] if not attributes: return None @@ -161,9 +150,7 @@ def has_children(node: NodeSchema, config: SyncConfig) -> bool: return bool(get_children(config=config, node=node)) -def render_template( - template_file: Path, output_dir: Path, output_file: Path, context: dict[str, Any] -) -> None: +def render_template(template_file: Path, output_dir: Path, output_file: Path, context: dict[str, Any]) -> None: template_loader = jinja2.PackageLoader("infrahub_sync", "generator/templates") template_env = jinja2.Environment( loader=template_loader, diff --git a/infrahub_sync/potenda/__init__.py b/infrahub_sync/potenda/__init__.py index a6a3903..8d7452c 100644 --- a/infrahub_sync/potenda/__init__.py +++ b/infrahub_sync/potenda/__init__.py @@ -79,15 +79,9 @@ def load(self): def diff(self) -> Diff: print(f"Diff: Comparing data from {self.source} to {self.destination}") self.progress_bar = None - return self.destination.diff_from( - self.source, flags=self.flags, callback=self._print_callback - ) + return self.destination.diff_from(self.source, flags=self.flags, callback=self._print_callback) def sync(self, diff: Diff | None = None): - print( - f"Sync: Importing data from {self.source} to {self.destination} based on Diff" - ) + print(f"Sync: Importing data from {self.source} to {self.destination} based on Diff") self.progress_bar = None - return self.destination.sync_from( - self.source, diff=diff, flags=self.flags, callback=self._print_callback - ) + return self.destination.sync_from(self.source, diff=diff, flags=self.flags, callback=self._print_callback) diff --git a/infrahub_sync/utils.py b/infrahub_sync/utils.py index 30ef912..9eaf1d0 100644 --- a/infrahub_sync/utils.py +++ b/infrahub_sync/utils.py @@ -72,9 +72,7 @@ def import_adapter(sync_instance: SyncInstance, adapter: SyncAdapter): try: adapter_name = f"{adapter.name.title()}Sync" - spec = importlib.util.spec_from_file_location( - f"{adapter.name}.adapter", str(adapter_file_path) - ) + spec = importlib.util.spec_from_file_location(f"{adapter.name}.adapter", str(adapter_file_path)) adapter_module = importlib.util.module_from_spec(spec) sys.modules[f"{adapter.name}.adapter"] = adapter_module spec.loader.exec_module(adapter_module) @@ -144,30 +142,22 @@ def get_potenda_from_instance( show_progress: bool | None = True, ) -> Potenda: source = import_adapter(sync_instance=sync_instance, adapter=sync_instance.source) - destination = import_adapter( - sync_instance=sync_instance, adapter=sync_instance.destination - ) + destination = import_adapter(sync_instance=sync_instance, adapter=sync_instance.destination) source_store = LocalStore() destination_store = LocalStore() if sync_instance.store and sync_instance.store.type == "redis": - if sync_instance.store.settings and isinstance( - sync_instance.store.settings, dict - ): + if sync_instance.store.settings and isinstance(sync_instance.store.settings, dict): redis_settings = sync_instance.store.settings source_store = RedisStore(**redis_settings, name=sync_instance.source.name) - destination_store = RedisStore( - **redis_settings, name=sync_instance.destination.name - ) + destination_store = RedisStore(**redis_settings, name=sync_instance.destination.name) else: source_store = RedisStore(name=sync_instance.source.name) destination_store = RedisStore(name=sync_instance.destination.name) try: if sync_instance.source.name == "infrahub": - settings_branch = ( - sync_instance.source.settings.get("branch") or branch or "main" - ) + settings_branch = sync_instance.source.settings.get("branch") or branch or "main" src: SyncInstance = source( config=sync_instance, target="source", @@ -187,9 +177,7 @@ def get_potenda_from_instance( raise ValueError(msg) from exc try: if sync_instance.destination.name == "infrahub": - settings_branch = ( - sync_instance.source.settings.get("branch") or branch or "main" - ) + settings_branch = sync_instance.source.settings.get("branch") or branch or "main" dst: SyncInstance = destination( config=sync_instance, target="destination", diff --git a/tasks/docs.py b/tasks/docs.py index 447017a..6e7f89e 100644 --- a/tasks/docs.py +++ b/tasks/docs.py @@ -28,9 +28,7 @@ def _generate_infrahubsync_documentation(context: Context) -> None: @task def markdownlint(context: Context) -> None: - has_markdownlint = check_if_command_available( - context=context, command_name="markdownlint-cli2" - ) + has_markdownlint = check_if_command_available(context=context, command_name="markdownlint-cli2") if not has_markdownlint: print("Warning, markdownlint-cli2 is not installed")