diff --git a/moto/rds/models.py b/moto/rds/models.py index ce14152bd3d5..af34de951297 100644 --- a/moto/rds/models.py +++ b/moto/rds/models.py @@ -1649,9 +1649,21 @@ def __init__(self, region_name: str, account_id: str): self.subnet_groups: Dict[str, DBSubnetGroup] = {} self._db_cluster_options: Optional[List[Dict[str, Any]]] = None self.db_proxies: Dict[str, DBProxy] = OrderedDict() - - def reset(self) -> None: - super().reset() + self.resource_map = { + DBCluster: self.clusters, + DBClusterParameterGroup: self.db_cluster_parameter_groups, + DBClusterSnapshot: self.cluster_snapshots, + DBInstance: self.databases, + DBParameterGroup: self.db_parameter_groups, + DBProxy: self.db_proxies, + DBSecurityGroup: self.security_groups, + DBSnapshot: self.database_snapshots, + DBSubnetGroup: self.subnet_groups, + EventSubscription: self.event_subscriptions, + ExportTask: self.export_tasks, + GlobalCluster: self.global_clusters, + OptionGroup: self.option_groups, + } @lru_cache() def db_cluster_options(self, engine) -> List[Dict[str, Any]]: # type: ignore @@ -2533,92 +2545,33 @@ def describe_event_subscriptions( raise SubscriptionNotFoundError(subscription_name) return self.event_subscriptions.values() + def _find_resource(self, resource_type: str, resource_name: str) -> Any: + for resource_class, resources in self.resource_map.items(): + if resource_type == getattr(resource_class, "resource_type", ""): + if resource_name in resources: # type: ignore + return resources[resource_name] # type: ignore + def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]: if self.arn_regex.match(arn): arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] resource_name = arn_breakdown[len(arn_breakdown) - 1] - if resource_type == "db": # Database - if resource_name in self.databases: - return self.databases[resource_name].get_tags() - elif resource_type == "cluster": # Cluster - if resource_name in self.clusters: - return self.clusters[resource_name].get_tags() - elif resource_type == "es": # Event Subscription - if resource_name in self.event_subscriptions: - return self.event_subscriptions[resource_name].get_tags() - elif resource_type == "og": # Option Group - if resource_name in self.option_groups: - return self.option_groups[resource_name].get_tags() - elif resource_type == "pg": # Parameter Group - if resource_name in self.db_parameter_groups: - return self.db_parameter_groups[resource_name].get_tags() - elif resource_type == "ri": # Reserved DB instance - # TODO: Complete call to tags on resource type Reserved DB - # instance - return [] - elif resource_type == "secgrp": # DB security group - if resource_name in self.security_groups: - return self.security_groups[resource_name].get_tags() - elif resource_type == "snapshot": # DB Snapshot - if resource_name in self.database_snapshots: - return self.database_snapshots[resource_name].get_tags() - elif resource_type == "cluster-snapshot": # DB Cluster Snapshot - if resource_name in self.cluster_snapshots: - return self.cluster_snapshots[resource_name].get_tags() - elif resource_type == "subgrp": # DB subnet group - if resource_name in self.subnet_groups: - return self.subnet_groups[resource_name].get_tags() - elif resource_type == "db-proxy": # DB Proxy - if resource_name in self.db_proxies: - return self.db_proxies[resource_name].get_tags() - else: - raise RDSClientError( - "InvalidParameterValue", f"Invalid resource name: {arn}" - ) - return [] + resource = self._find_resource(resource_type, resource_name) + if resource: + return resource.get_tags() + return [] + raise RDSClientError("InvalidParameterValue", f"Invalid resource name: {arn}") def remove_tags_from_resource(self, arn: str, tag_keys: List[str]) -> None: if self.arn_regex.match(arn): arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] resource_name = arn_breakdown[len(arn_breakdown) - 1] - if resource_type == "db": # Database - if resource_name in self.databases: - self.databases[resource_name].remove_tags(tag_keys) - elif resource_type == "es": # Event Subscription - if resource_name in self.event_subscriptions: - self.event_subscriptions[resource_name].remove_tags(tag_keys) - elif resource_type == "og": # Option Group - if resource_name in self.option_groups: - self.option_groups[resource_name].remove_tags(tag_keys) - elif resource_type == "pg": # Parameter Group - if resource_name in self.db_parameter_groups: - self.db_parameter_groups[resource_name].remove_tags(tag_keys) - elif resource_type == "ri": # Reserved DB instance - return None - elif resource_type == "secgrp": # DB security group - if resource_name in self.security_groups: - self.security_groups[resource_name].remove_tags(tag_keys) - elif resource_type == "snapshot": # DB Snapshot - if resource_name in self.database_snapshots: - self.database_snapshots[resource_name].remove_tags(tag_keys) - elif resource_type == "cluster": - if resource_name in self.clusters: - self.clusters[resource_name].remove_tags(tag_keys) - elif resource_type == "cluster-snapshot": # DB Cluster Snapshot - if resource_name in self.cluster_snapshots: - self.cluster_snapshots[resource_name].remove_tags(tag_keys) - elif resource_type == "subgrp": # DB subnet group - if resource_name in self.subnet_groups: - self.subnet_groups[resource_name].remove_tags(tag_keys) - elif resource_type == "db-proxy": # DB Proxy - if resource_name in self.db_proxies: - self.db_proxies[resource_name].remove_tags(tag_keys) - else: - raise RDSClientError( - "InvalidParameterValue", f"Invalid resource name: {arn}" - ) + resource = self._find_resource(resource_type, resource_name) + if resource: + resource.remove_tags(tag_keys) + return + raise RDSClientError("InvalidParameterValue", f"Invalid resource name: {arn}") def add_tags_to_resource( # type: ignore[return] self, arn: str, tags: List[Dict[str, str]] @@ -2627,42 +2580,11 @@ def add_tags_to_resource( # type: ignore[return] arn_breakdown = arn.split(":") resource_type = arn_breakdown[-2] resource_name = arn_breakdown[-1] - if resource_type == "db": # Database - if resource_name in self.databases: - return self.databases[resource_name].add_tags(tags) - elif resource_type == "es": # Event Subscription - if resource_name in self.event_subscriptions: - return self.event_subscriptions[resource_name].add_tags(tags) - elif resource_type == "og": # Option Group - if resource_name in self.option_groups: - return self.option_groups[resource_name].add_tags(tags) - elif resource_type == "pg": # Parameter Group - if resource_name in self.db_parameter_groups: - return self.db_parameter_groups[resource_name].add_tags(tags) - elif resource_type == "ri": # Reserved DB instance - return [] - elif resource_type == "secgrp": # DB security group - if resource_name in self.security_groups: - return self.security_groups[resource_name].add_tags(tags) - elif resource_type == "snapshot": # DB Snapshot - if resource_name in self.database_snapshots: - return self.database_snapshots[resource_name].add_tags(tags) - elif resource_type == "cluster": - if resource_name in self.clusters: - return self.clusters[resource_name].add_tags(tags) - elif resource_type == "cluster-snapshot": # DB Cluster Snapshot - if resource_name in self.cluster_snapshots: - return self.cluster_snapshots[resource_name].add_tags(tags) - elif resource_type == "subgrp": # DB subnet group - if resource_name in self.subnet_groups: - return self.subnet_groups[resource_name].add_tags(tags) - elif resource_type == "db-proxy": # DB Proxy - if resource_name in self.db_proxies: - return self.db_proxies[resource_name].add_tags(tags) - else: - raise RDSClientError( - "InvalidParameterValue", f"Invalid resource name: {arn}" - ) + resource = self._find_resource(resource_type, resource_name) + if resource: + return resource.add_tags(tags) + return [] + raise RDSClientError("InvalidParameterValue", f"Invalid resource name: {arn}") @staticmethod def _filter_resources(resources: Any, filters: Any, resource_class: Any) -> Any: # type: ignore[misc]