diff --git a/abcd/__init__.py b/abcd/__init__.py index 994afcd8..081e868c 100644 --- a/abcd/__init__.py +++ b/abcd/__init__.py @@ -40,7 +40,8 @@ def from_url(cls, url, **kwargs): return MongoDatabase(db_name=db, **conn_settings, **kwargs) - if r.scheme == "opensearch": + r.scheme = ConnectionType[r.scheme] + if r.scheme is ConnectionType.opensearch: conn_settings = { "host": r.hostname, "port": r.port, diff --git a/abcd/backends/atoms_opensearch.py b/abcd/backends/atoms_opensearch.py index bb1eb1af..78af1b08 100644 --- a/abcd/backends/atoms_opensearch.py +++ b/abcd/backends/atoms_opensearch.py @@ -196,7 +196,7 @@ def save(self): if not self._id: self._client.index(index=self._index_name, body=body) else: - body.pop("_id", None) + del body["_id"] body = {"doc": body} self._client.update(index=self._index_name, id=self._id, body=body) @@ -281,11 +281,11 @@ def __init__( info = self.client.info() logger.info("DB info: %s", info) - except AuthenticationException: - raise abcd.errors.AuthenticationError() + except AuthenticationException as err: + raise abcd.errors.AuthenticationError() from err - except ConnectionTimeout: - raise abcd.errors.TimeoutError() + except ConnectionTimeout as err: + raise abcd.errors.TimeoutError() from err self.db = db_name self.index_name = index_name @@ -677,9 +677,9 @@ def count_property(self, name, query: Union[dict, str, None] = None) -> dict: prop = {} - for val in self.client.search(index=self.index_name, body=body,)[ - "aggregations" - ][format(name)]["buckets"]: + for val in self.client.search( + index=self.index_name, body=body + )["aggregations"][format(name)]["buckets"]: prop[val["key"]] = val["doc_count"] return prop @@ -728,8 +728,7 @@ def properties(self, query: Union[dict, str, None] = None) -> dict: body=body, ) - derived = ["info_keys", "derived_keys", "arrays_keys"] - for label in derived: + for label in ("info_keys", "derived_keys", "arrays_keys"): count = res["aggregations"][label]["doc_count"] if count > 0: key = label.split("_", maxsplit=1)[0] @@ -824,8 +823,7 @@ def count_properties(self, query: Union[dict, str, None] = None) -> dict: body=body, ) - derived = ["info_keys", "derived_keys", "arrays_keys"] - for label in derived: + for label in ("info_keys", "derived_keys", "arrays_keys"): count = res["aggregations"][label]["doc_count"] if count > 0: properties[key] = { @@ -980,9 +978,9 @@ def __repr__(self): host, port = None, None return ( - "{}(".format(self.__class__.__name__) - + "url={}:{}, ".format(host, port) - + "index={}) ".format(self.index_name) + f"{self.__class__.__name__}(" + f"url={host}:{port}, " + f"index={self.index_name}) " ) def _repr_html_(self): diff --git a/abcd/backends/atoms_properties.py b/abcd/backends/atoms_properties.py index 5503c398..5ce1ccb2 100644 --- a/abcd/backends/atoms_properties.py +++ b/abcd/backends/atoms_properties.py @@ -95,7 +95,7 @@ def __init__( self.df.replace({np.nan: None}, inplace=True) if units is not None: - for key in units.keys(): + for key in units: if key not in self.df.columns.values: raise ValueError( f"Invalid field name: {key}. Keys in `units` must " @@ -109,20 +109,19 @@ def __init__( self.store_struct_file = store_struct_file if self.store_struct_file: - if struct_file_template is not None: - self.struct_file_template = struct_file_template - else: + if struct_file_template is None: raise ValueError( "`struct_file_template` must be specified if " "store_struct_file is True." ) - if struct_name_label is not None: - self.struct_name_label = struct_name_label - else: + self.struct_file_template = struct_file_template + + if struct_name_label is None: raise ValueError( "`struct_name_label` must be specified if store_struct_file is" " True." ) + self.struct_name_label = struct_name_label self.set_struct_files() def _separate_units(self):