From 065398348427920c8bfb96473686dc7cda21ad68 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Tue, 11 Jun 2024 14:25:31 +0100 Subject: [PATCH] Tidy code --- abcd/backends/atoms_opensearch.py | 60 ++++++++++++++------------ abcd/backends/atoms_properties.py | 7 +-- abcd/backends/utils.py | 2 +- abcd/frontends/commandline/commands.py | 4 +- 4 files changed, 36 insertions(+), 37 deletions(-) diff --git a/abcd/backends/atoms_opensearch.py b/abcd/backends/atoms_opensearch.py index eb51b91e..c535bd37 100644 --- a/abcd/backends/atoms_opensearch.py +++ b/abcd/backends/atoms_opensearch.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Generator +from collections.abc import Iterator from datetime import datetime from typing import Iterable, Optional, Union import logging @@ -40,20 +40,20 @@ class OpenSearchQuery(AbstractQuerySet): """Class to parse and build queries for OpenSearch.""" - def __call__(self, query: Union[dict, str, list, None]) -> Union[dict, None]: + def __call__(self, query: Optional[Union[dict, str, list]]) -> Optional[dict]: """ Parses and builds queries for OpenSearch. Parameters ---------- - query: Union[dict, str, list, None] + query: Optional[Union[dict, str, list]] Query to be parsed for OpenSearch. If passed as a dictionary, the query is left unchanged. If passed a string or list, the query is treated as a query string, based on Lucene query syntax. Returns ------- - Union[dict, None] + Optional[dict] The parsed query for OpenSearch. """ if not query: @@ -327,6 +327,7 @@ def delete(self, query: Optional[Union[dict, str]] = None): Query to filter documents to be deleted. Default is `None`. """ query = self.parser(query) + logger.info("parsed query: %s", query) body = {"query": query} self.client.delete_by_query( @@ -354,7 +355,7 @@ def refresh(self): """ self.client.indices.refresh(index=self.index_name) - def save_bulk(self, actions: Iterable, **kwargs): + def save_bulk(self, actions: Iterable[dict], **kwargs): """ Save a collection of documents in bulk. @@ -410,7 +411,7 @@ def push( ) data.save() - elif isinstance(atoms, Generator) or isinstance(atoms, list): + elif isinstance(atoms, Iterator) or isinstance(atoms, list): actions = [] for i, item in enumerate(atoms): if isinstance(extra_info, list): @@ -431,7 +432,7 @@ def push( def upload( self, file: Path, - extra_infos: Optional[Union[Iterable, dict]] = None, + extra_infos: Union[Iterable, dict] = (), store_calc: bool = True, ): """ @@ -441,9 +442,9 @@ def upload( ---------- file: Path Path to file to be uploaded - extra_infos: Optional[Union[Iterable, dict]] + extra_infos: Union[Iterable, dict] Extra information to store in the document with the atoms data. - Default is `None`. + Default is `()`. store_calc: bool, optional Whether to store data from the calculator attached to atoms. Default is `True`. @@ -452,19 +453,14 @@ def upload( if isinstance(file, str): file = Path(file) - extra_info = {} - if extra_infos: - for info in extra_infos: - extra_info.update(extras.parser.parse(info)) # type: ignore + extra_info = dict(map(extras.parser.parse, extra_infos)) extra_info["filename"] = str(file) data = iread(str(file)) self.push(data, extra_info, store_calc=store_calc) - def get_items( - self, query: Optional[Union[dict, str]] = None - ) -> Generator[dict, None, None]: + def get_items(self, query: Optional[Union[dict, str]] = None) -> Iterator[dict]: """ Get data as a dictionary from documents in the database. @@ -475,9 +471,11 @@ def get_items( Returns ------- - Generator for dictionary of data. + Iterator[dict] + Iterator for dictionary of data. """ query = self.parser(query) + logger.info("parsed query: %s", query) query = { "query": query, } @@ -489,9 +487,7 @@ def get_items( ): yield {"_id": hit["_id"], **hit["_source"]} - def get_atoms( - self, query: Optional[Union[dict, str]] = None - ) -> Generator[Atoms, None, None]: + def get_atoms(self, query: Optional[Union[dict, str]] = None) -> Iterator[Atoms]: """ Get data as Atoms object from documents in the database. @@ -502,9 +498,11 @@ def get_atoms( Returns ------- - Generator for AtomsModel object of data. + Iterator[Atoms] + Generator for AtomsModel object of data. """ query = self.parser(query) + logger.info("parsed query: %s", query) query = { "query": query, } @@ -514,7 +512,7 @@ def get_atoms( index=self.index_name, query=query, ): - yield AtomsModel(None, None, hit["_source"]).to_ase() + yield AtomsModel(dict=hit["_source"]).to_ase() def count(self, query: Optional[Union[dict, str]] = None, timeout=30.0) -> int: """ @@ -603,6 +601,7 @@ def property( if only one property is given. """ query = self.parser(query) + logger.info("parsed query: %s", query) query = { "query": query, } @@ -662,6 +661,7 @@ def count_property(self, name, query: Optional[Union[dict, str]] = None) -> dict matching documents. """ query = self.parser(query) + logger.info("parsed query: %s", query) body = { "size": 0, @@ -678,9 +678,9 @@ def count_property(self, name, query: Optional[Union[dict, str]] = None) -> dict prop = {} - for val in self.client.search( - index=self.index_name, body=body - )["aggregations"][format(name)]["buckets"]: + for val in self.client.search(index=self.index_name, body=body)["aggregations"][ + format(name) + ]["buckets"]: prop[val["key"]] = val["doc_count"] return prop @@ -702,6 +702,7 @@ def properties(self, query: Optional[Union[dict, str]] = None) -> dict: the properties of that type. """ query = self.parser(query) + logger.info("parsed query: %s", query) properties = {} @@ -793,6 +794,7 @@ def count_properties(self, query: Optional[Union[dict, str]] = None) -> dict: corresponding to their counts, categories and data types. """ query = self.parser(query) + logger.info("parsed query: %s", query) properties = {} try: @@ -942,7 +944,7 @@ def delete_property(self, name: str, query: Optional[Union[dict, str]] = None): def hist( self, name: str, query: Optional[Union[dict, str]] = None, **kwargs - ) -> Union[dict, None]: + ) -> Optional[dict]: """ Calculate histogram statistics for a property from all matching documents. @@ -955,10 +957,12 @@ def hist( Returns ------- - Dictionary containing histogram statistics, including the number of - bins, edges, counts, min, max, and standard deviation. + Optional[dict] + Dictionary containing histogram statistics, including the number of + bins, edges, counts, min, max, and standard deviation. """ query = self.parser(query) + logger.info("parsed query: %s", query) data = self.property(name, query) return utils.histogram(name, data, **kwargs) diff --git a/abcd/backends/atoms_properties.py b/abcd/backends/atoms_properties.py index 5ce1ccb2..7372e797 100644 --- a/abcd/backends/atoms_properties.py +++ b/abcd/backends/atoms_properties.py @@ -121,7 +121,7 @@ def __init__( "`struct_name_label` must be specified if store_struct_file is" " True." ) - self.struct_name_label = struct_name_label + self.struct_name_label = struct_name_label self.set_struct_files() def _separate_units(self): @@ -175,15 +175,12 @@ def get_struct_file(self, struct_name: str) -> str: ------- Filename for the current structure. """ - if struct_name is None: - raise ValueError("`struct_name` must be specified") if "{struct_name}" not in self.struct_file_template: raise ValueError( "'struct_name' must be a variable in the template file: " f"{self.struct_file_template}" ) - else: - return eval(f"f'{self.struct_file_template}'") + return eval(f"f'{self.struct_file_template}'") def to_list(self) -> list[dict]: """ diff --git a/abcd/backends/utils.py b/abcd/backends/utils.py index dd6f18aa..e55471eb 100644 --- a/abcd/backends/utils.py +++ b/abcd/backends/utils.py @@ -12,7 +12,7 @@ def histogram(name, data, **kwargs): if not data: return None - if data and isinstance(data, list): + if isinstance(data, list): ptype = type(data[0]) if not all(isinstance(x, ptype) for x in data): diff --git a/abcd/frontends/commandline/commands.py b/abcd/frontends/commandline/commands.py index 36dc8415..6c3e0edf 100644 --- a/abcd/frontends/commandline/commands.py +++ b/abcd/frontends/commandline/commands.py @@ -267,9 +267,7 @@ def server(*, abcd_url, url, api_only, **kwargs): from urllib.parse import urlparse from abcd.server.app import create_app - logger.info( - "SERVER - abcd: %s, url: %s, api_only: %s", abcd_url, url, api_only - ) + logger.info("SERVER - abcd: %s, url: %s, api_only: %s", abcd_url, url, api_only) if api_only: print("Not implemented yet!")