Skip to content

Commit

Permalink
Apply ruff format and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Nov 13, 2024
1 parent 1fc6b81 commit c06ef3c
Show file tree
Hide file tree
Showing 31 changed files with 296 additions and 384 deletions.
18 changes: 8 additions & 10 deletions abcd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
import logging
from urllib import parse
from enum import Enum

logger = logging.getLogger(__name__)

Expand All @@ -10,7 +10,7 @@ class ConnectionType(Enum):
http = 2


class ABCD(object):
class ABCD:
@classmethod
def from_config(cls, config):
# Factory method
Expand All @@ -24,7 +24,6 @@ def from_url(cls, url, **kwargs):
logger.info(r)

if r.scheme == "mongodb":

conn_settings = {
"host": r.hostname,
"port": r.port,
Expand All @@ -39,20 +38,19 @@ def from_url(cls, url, **kwargs):
from abcd.backends.atoms_pymongo import MongoDatabase

return MongoDatabase(db_name=db, **conn_settings, **kwargs)
elif r.scheme == "mongodb+srv":
if r.scheme == "mongodb+srv":
db = r.path.split("/")[1] if r.path else None
db = db if db else "abcd"
from abcd.backends.atoms_pymongo import MongoDatabase

return MongoDatabase(db_name=db, host=r.geturl(), uri_mode=True, **kwargs)
elif r.scheme == "http" or r.scheme == "https":
if r.scheme == "http" or r.scheme == "https":
raise NotImplementedError("http not yet supported! soon...")
elif r.scheme == "ssh":
if r.scheme == "ssh":
raise NotImplementedError("ssh not yet supported! soon...")
else:
raise NotImplementedError(
"Unable to recognise the type of connection. (url: {})".format(url)
)
raise NotImplementedError(
f"Unable to recognise the type of connection. (url: {url})"
)


if __name__ == "__main__":
Expand Down
11 changes: 5 additions & 6 deletions abcd/backends/atoms_http.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import logging
import requests
from os import linesep
from typing import List

import ase
import requests

from abcd.backends.abstract import Database

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -49,12 +50,12 @@ def search(self, query_string: str) -> List[str]:
return results

def get_atoms(self, id: str) -> Atoms:
data = requests.get(self.url + "/calculation/{}".format(id)).json()
data = requests.get(self.url + f"/calculation/{id}").json()
atoms = Atoms.from_dict(data)
return atoms

def __repr__(self):
return "ABCD(type={}, url={}, ...)".format(self.__class__.__name__, self.url)
return f"ABCD(type={self.__class__.__name__}, url={self.url}, ...)"

def _repr_html_(self):
"""jupyter notebook representation"""
Expand All @@ -67,9 +68,7 @@ def print_info(self):
[
"{:=^50}".format(" ABCD Database "),
"{:>10}: {}".format("type", "remote (http/https)"),
linesep.join(
"{:>10}: {}".format(k, v) for k, v in self.db.info().items()
),
linesep.join(f"{k:>10}: {v}" for k, v in self.db.info().items()),
]
)

Expand Down
108 changes: 43 additions & 65 deletions abcd/backends/atoms_pymongo.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
import types
import logging
import numpy as np

from typing import Union, Iterable
from os import linesep
from operator import itemgetter
from collections import Counter
from collections.abc import Iterable
from datetime import datetime
import logging
from operator import itemgetter
from os import linesep
from pathlib import Path
import types
from typing import Union

from ase import Atoms
from ase.io import iread
from bson import ObjectId
import numpy as np
from pymongo import MongoClient
import pymongo.errors

from abcd.database import AbstractABCD
import abcd.errors
from abcd.model import AbstractModel
from abcd.database import AbstractABCD
from abcd.queryset import AbstractQuerySet
from abcd.parsers import extras

import pymongo.errors
from pymongo import MongoClient
from bson import ObjectId

from pathlib import Path
from abcd.queryset import AbstractQuerySet

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -125,11 +123,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
pass

def __call__(self, ast):
logger.info("parsed ast: {}".format(ast))
logger.info(f"parsed ast: {ast}")

if isinstance(ast, dict):
return ast
elif isinstance(ast, str):
if isinstance(ast, str):
from abcd.parsers.queries import parser

p = parser(ast)
Expand Down Expand Up @@ -165,7 +163,7 @@ def __init__(
password=None,
authSource="admin",
uri_mode=False,
**kwargs
**kwargs,
):
super().__init__()

Expand Down Expand Up @@ -195,7 +193,7 @@ def __init__(

try:
info = self.client.server_info() # Forces a call.
logger.info("DB info: {}".format(info))
logger.info(f"DB info: {info}")

except pymongo.errors.OperationFailure:
raise abcd.errors.AuthenticationError()
Expand Down Expand Up @@ -226,7 +224,6 @@ def destroy(self):
self.collection.drop()

def push(self, atoms: Union[Atoms, Iterable], extra_info=None, store_calc=True):

if extra_info and isinstance(extra_info, str):
extra_info = extras.parser.parse(extra_info)

Expand All @@ -238,15 +235,13 @@ def push(self, atoms: Union[Atoms, Iterable], extra_info=None, store_calc=True):
# self.collection.insert_one(data)

elif isinstance(atoms, types.GeneratorType) or isinstance(atoms, list):

for item in atoms:
data = AtomsModel.from_atoms(
self.collection, item, extra_info=extra_info, store_calc=store_calc
)
data.save()

def upload(self, file: Path, extra_infos=None, store_calc=True):

if isinstance(file, str):
file = Path(file)

Expand All @@ -273,7 +268,7 @@ def get_atoms(self, query=None):

def count(self, query=None):
query = parser(query)
logger.info("query; {}".format(query))
logger.info(f"query; {query}")

if not query:
query = {}
Expand All @@ -285,8 +280,8 @@ def property(self, name, query=None):

pipeline = [
{"$match": query},
{"$match": {"{}".format(name): {"$exists": True}}},
{"$project": {"_id": False, "data": "${}".format(name)}},
{"$match": {f"{name}": {"$exists": True}}},
{"$project": {"_id": False, "data": f"${name}"}},
]

return [val["data"] for val in self.db.atoms.aggregate(pipeline)]
Expand Down Expand Up @@ -331,22 +326,16 @@ def get_type_of_property(self, prop, category):

if category == "arrays":
if type(data[0]) == list:
return "array({}, N x {})".format(
map_types[type(data[0][0])], len(data[0])
)
else:
return "vector({}, N)".format(map_types[type(data[0])])
return f"array({map_types[type(data[0][0])]}, N x {len(data[0])})"
return f"vector({map_types[type(data[0])]}, N)"

if type(data) == list:
if type(data[0]) == list:
if type(data[0][0]) == list:
return "list(list(...)"
else:
return "array({})".format(map_types[type(data[0][0])])
else:
return "vector({})".format(map_types[type(data[0])])
else:
return "scalar({})".format(map_types[type(data)])
return f"array({map_types[type(data[0][0])]})"
return f"vector({map_types[type(data[0])]})"
return f"scalar({map_types[type(data)]})"

def count_properties(self, query=None):
query = parser(query)
Expand Down Expand Up @@ -396,7 +385,7 @@ def count_properties(self, query=None):
return properties

def add_property(self, data, query=None):
logger.info("add: data={}, query={}".format(data, query))
logger.info(f"add: data={data}, query={query}")

self.collection.update_many(
parser(query),
Expand All @@ -407,7 +396,7 @@ def add_property(self, data, query=None):
)

def rename_property(self, name, new_name, query=None):
logger.info("rename: query={}, old={}, new={}".format(query, name, new_name))
logger.info(f"rename: query={query}, old={name}, new={new_name}")
# TODO name in derived.info_keys OR name in derived.arrays_keys OR name in derived.derived_keys
self.collection.update_many(
parser(query), {"$push": {"derived.info_keys": new_name}}
Expand All @@ -428,7 +417,7 @@ def rename_property(self, name, new_name, query=None):
# '$rename': {'arrays.{}'.format(name): 'arrays.{}'.format(new_name)}})

def delete_property(self, name, query=None):
logger.info("delete: query={}, porperty={}".format(name, query))
logger.info(f"delete: query={name}, porperty={query}")

self.collection.update_many(
parser(query),
Expand All @@ -439,7 +428,6 @@ def delete_property(self, name, query=None):
)

def hist(self, name, query=None, **kwargs):

data = self.property(name, query)
return histogram(name, data, **kwargs)

Expand All @@ -454,10 +442,10 @@ def __repr__(self):
host, port = self.client.address

return (
"{}(".format(self.__class__.__name__)
+ "url={}:{}, ".format(host, port)
+ "db={}, ".format(self.db.name)
+ "collection={})".format(self.collection.name)
f"{self.__class__.__name__}("
+ f"url={host}:{port}, "
+ f"db={self.db.name}, "
+ f"collection={self.collection.name})"
)

def _repr_html_(self):
Expand All @@ -471,7 +459,7 @@ def print_info(self):
[
"{:=^50}".format(" ABCD MongoDB "),
"{:>10}: {}".format("type", "mongodb"),
linesep.join("{:>10}: {}".format(k, v) for k, v in self.info().items()),
linesep.join(f"{k:>10}: {v}" for k, v in self.info().items()),
]
)

Expand All @@ -488,45 +476,35 @@ def histogram(name, data, **kwargs):
if not data:
return None

elif data and isinstance(data, list):

if data and isinstance(data, list):
ptype = type(data[0])

if not all(isinstance(x, ptype) for x in data):
print("Mixed type error of the {} property!".format(name))
print(f"Mixed type error of the {name} property!")
return None

if ptype == float:
bins = kwargs.get("bins", 10)
return _hist_float(name, data, bins)

elif ptype == int:
if ptype == int:
bins = kwargs.get("bins", 10)
return _hist_int(name, data, bins)

elif ptype == str:
if ptype == str:
return _hist_str(name, data, **kwargs)

elif ptype == datetime:
if ptype == datetime:
bins = kwargs.get("bins", 10)
return _hist_date(name, data, bins)

else:
print(
"{}: Histogram for list of {} types are not supported!".format(
name, type(data[0])
)
)
logger.info(
"{}: Histogram for list of {} types are not supported!".format(
name, type(data[0])
)
)

else:
print(f"{name}: Histogram for list of {type(data[0])} types are not supported!")
logger.info(
"{}: Histogram for {} types are not supported!".format(name, type(data))
f"{name}: Histogram for list of {type(data[0])} types are not supported!"
)

else:
logger.info(f"{name}: Histogram for {type(data)} types are not supported!")
return None


Expand Down
2 changes: 1 addition & 1 deletion abcd/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from abc import ABCMeta, abstractmethod
import logging

logger = logging.getLogger(__name__)

Expand Down
Loading

0 comments on commit c06ef3c

Please sign in to comment.