From c06ef3c6dbafef9fffe7327d565bc5ed83e87752 Mon Sep 17 00:00:00 2001 From: ElliottKasoar Date: Wed, 13 Nov 2024 15:23:48 +0000 Subject: [PATCH] Apply ruff format and fixes --- abcd/__init__.py | 18 +-- abcd/backends/atoms_http.py | 11 +- abcd/backends/atoms_pymongo.py | 108 +++++-------- abcd/database.py | 2 +- abcd/frontends/commandline/commands.py | 52 +++--- abcd/frontends/commandline/config.py | 9 +- abcd/frontends/commandline/decorators.py | 1 - abcd/frontends/commandline/parser.py | 5 +- abcd/model.py | 19 +-- abcd/parsers/extras.py | 2 - abcd/parsers/queries.py | 6 +- abcd/parsers/queries_new.py | 4 +- abcd/queryset.py | 2 +- abcd/server/app/__init__.py | 4 +- abcd/server/app/db.py | 1 - abcd/server/app/nav.py | 5 +- abcd/server/app/views/api.py | 2 +- abcd/server/app/views/database.py | 3 +- docs_src/conf.py | 64 ++++---- tests/test_database.py | 8 +- tests/test_parsers.py | 10 +- tutorials/abcd_uploading.ipynb | 58 +++---- tutorials/abcd_usage.ipynb | 39 ++--- tutorials/gb_upload.py | 6 +- tutorials/grain_boundaries_tutorial.ipynb | 188 +++++++++++----------- tutorials/scripts/Preprocess.py | 13 +- tutorials/scripts/Reader.py | 14 +- tutorials/scripts/Visualise.py | 10 +- tutorials/scripts/Visualise_quip.py | 10 +- tutorials/test_db.py | 3 +- tutorials/test_upload.py | 3 - 31 files changed, 296 insertions(+), 384 deletions(-) diff --git a/abcd/__init__.py b/abcd/__init__.py index 04857244..e5de8376 100644 --- a/abcd/__init__.py +++ b/abcd/__init__.py @@ -1,6 +1,6 @@ +from enum import Enum import logging from urllib import parse -from enum import Enum logger = logging.getLogger(__name__) @@ -10,7 +10,7 @@ class ConnectionType(Enum): http = 2 -class ABCD(object): +class ABCD: @classmethod def from_config(cls, config): # Factory method @@ -24,7 +24,6 @@ def from_url(cls, url, **kwargs): logger.info(r) if r.scheme == "mongodb": - conn_settings = { "host": r.hostname, "port": r.port, @@ -39,20 +38,19 @@ def from_url(cls, url, **kwargs): from abcd.backends.atoms_pymongo import MongoDatabase return MongoDatabase(db_name=db, **conn_settings, **kwargs) - elif r.scheme == "mongodb+srv": + if r.scheme == "mongodb+srv": db = r.path.split("/")[1] if r.path else None db = db if db else "abcd" from abcd.backends.atoms_pymongo import MongoDatabase return MongoDatabase(db_name=db, host=r.geturl(), uri_mode=True, **kwargs) - elif r.scheme == "http" or r.scheme == "https": + if r.scheme == "http" or r.scheme == "https": raise NotImplementedError("http not yet supported! soon...") - elif r.scheme == "ssh": + if r.scheme == "ssh": raise NotImplementedError("ssh not yet supported! soon...") - else: - raise NotImplementedError( - "Unable to recognise the type of connection. (url: {})".format(url) - ) + raise NotImplementedError( + f"Unable to recognise the type of connection. (url: {url})" + ) if __name__ == "__main__": diff --git a/abcd/backends/atoms_http.py b/abcd/backends/atoms_http.py index ee62e61a..ba78c80a 100644 --- a/abcd/backends/atoms_http.py +++ b/abcd/backends/atoms_http.py @@ -1,10 +1,11 @@ import json import logging -import requests from os import linesep from typing import List import ase +import requests + from abcd.backends.abstract import Database logger = logging.getLogger(__name__) @@ -49,12 +50,12 @@ def search(self, query_string: str) -> List[str]: return results def get_atoms(self, id: str) -> Atoms: - data = requests.get(self.url + "/calculation/{}".format(id)).json() + data = requests.get(self.url + f"/calculation/{id}").json() atoms = Atoms.from_dict(data) return atoms def __repr__(self): - return "ABCD(type={}, url={}, ...)".format(self.__class__.__name__, self.url) + return f"ABCD(type={self.__class__.__name__}, url={self.url}, ...)" def _repr_html_(self): """jupyter notebook representation""" @@ -67,9 +68,7 @@ def print_info(self): [ "{:=^50}".format(" ABCD Database "), "{:>10}: {}".format("type", "remote (http/https)"), - linesep.join( - "{:>10}: {}".format(k, v) for k, v in self.db.info().items() - ), + linesep.join(f"{k:>10}: {v}" for k, v in self.db.info().items()), ] ) diff --git a/abcd/backends/atoms_pymongo.py b/abcd/backends/atoms_pymongo.py index c2baf7ad..1009efec 100644 --- a/abcd/backends/atoms_pymongo.py +++ b/abcd/backends/atoms_pymongo.py @@ -1,27 +1,25 @@ -import types -import logging -import numpy as np - -from typing import Union, Iterable -from os import linesep -from operator import itemgetter from collections import Counter +from collections.abc import Iterable from datetime import datetime +import logging +from operator import itemgetter +from os import linesep +from pathlib import Path +import types +from typing import Union from ase import Atoms from ase.io import iread +from bson import ObjectId +import numpy as np +from pymongo import MongoClient +import pymongo.errors +from abcd.database import AbstractABCD import abcd.errors from abcd.model import AbstractModel -from abcd.database import AbstractABCD -from abcd.queryset import AbstractQuerySet from abcd.parsers import extras - -import pymongo.errors -from pymongo import MongoClient -from bson import ObjectId - -from pathlib import Path +from abcd.queryset import AbstractQuerySet logger = logging.getLogger(__name__) @@ -125,11 +123,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass def __call__(self, ast): - logger.info("parsed ast: {}".format(ast)) + logger.info(f"parsed ast: {ast}") if isinstance(ast, dict): return ast - elif isinstance(ast, str): + if isinstance(ast, str): from abcd.parsers.queries import parser p = parser(ast) @@ -165,7 +163,7 @@ def __init__( password=None, authSource="admin", uri_mode=False, - **kwargs + **kwargs, ): super().__init__() @@ -195,7 +193,7 @@ def __init__( try: info = self.client.server_info() # Forces a call. - logger.info("DB info: {}".format(info)) + logger.info(f"DB info: {info}") except pymongo.errors.OperationFailure: raise abcd.errors.AuthenticationError() @@ -226,7 +224,6 @@ def destroy(self): self.collection.drop() def push(self, atoms: Union[Atoms, Iterable], extra_info=None, store_calc=True): - if extra_info and isinstance(extra_info, str): extra_info = extras.parser.parse(extra_info) @@ -238,7 +235,6 @@ def push(self, atoms: Union[Atoms, Iterable], extra_info=None, store_calc=True): # self.collection.insert_one(data) elif isinstance(atoms, types.GeneratorType) or isinstance(atoms, list): - for item in atoms: data = AtomsModel.from_atoms( self.collection, item, extra_info=extra_info, store_calc=store_calc @@ -246,7 +242,6 @@ def push(self, atoms: Union[Atoms, Iterable], extra_info=None, store_calc=True): data.save() def upload(self, file: Path, extra_infos=None, store_calc=True): - if isinstance(file, str): file = Path(file) @@ -273,7 +268,7 @@ def get_atoms(self, query=None): def count(self, query=None): query = parser(query) - logger.info("query; {}".format(query)) + logger.info(f"query; {query}") if not query: query = {} @@ -285,8 +280,8 @@ def property(self, name, query=None): pipeline = [ {"$match": query}, - {"$match": {"{}".format(name): {"$exists": True}}}, - {"$project": {"_id": False, "data": "${}".format(name)}}, + {"$match": {f"{name}": {"$exists": True}}}, + {"$project": {"_id": False, "data": f"${name}"}}, ] return [val["data"] for val in self.db.atoms.aggregate(pipeline)] @@ -331,22 +326,16 @@ def get_type_of_property(self, prop, category): if category == "arrays": if type(data[0]) == list: - return "array({}, N x {})".format( - map_types[type(data[0][0])], len(data[0]) - ) - else: - return "vector({}, N)".format(map_types[type(data[0])]) + return f"array({map_types[type(data[0][0])]}, N x {len(data[0])})" + return f"vector({map_types[type(data[0])]}, N)" if type(data) == list: if type(data[0]) == list: if type(data[0][0]) == list: return "list(list(...)" - else: - return "array({})".format(map_types[type(data[0][0])]) - else: - return "vector({})".format(map_types[type(data[0])]) - else: - return "scalar({})".format(map_types[type(data)]) + return f"array({map_types[type(data[0][0])]})" + return f"vector({map_types[type(data[0])]})" + return f"scalar({map_types[type(data)]})" def count_properties(self, query=None): query = parser(query) @@ -396,7 +385,7 @@ def count_properties(self, query=None): return properties def add_property(self, data, query=None): - logger.info("add: data={}, query={}".format(data, query)) + logger.info(f"add: data={data}, query={query}") self.collection.update_many( parser(query), @@ -407,7 +396,7 @@ def add_property(self, data, query=None): ) def rename_property(self, name, new_name, query=None): - logger.info("rename: query={}, old={}, new={}".format(query, name, new_name)) + logger.info(f"rename: query={query}, old={name}, new={new_name}") # TODO name in derived.info_keys OR name in derived.arrays_keys OR name in derived.derived_keys self.collection.update_many( parser(query), {"$push": {"derived.info_keys": new_name}} @@ -428,7 +417,7 @@ def rename_property(self, name, new_name, query=None): # '$rename': {'arrays.{}'.format(name): 'arrays.{}'.format(new_name)}}) def delete_property(self, name, query=None): - logger.info("delete: query={}, porperty={}".format(name, query)) + logger.info(f"delete: query={name}, porperty={query}") self.collection.update_many( parser(query), @@ -439,7 +428,6 @@ def delete_property(self, name, query=None): ) def hist(self, name, query=None, **kwargs): - data = self.property(name, query) return histogram(name, data, **kwargs) @@ -454,10 +442,10 @@ def __repr__(self): host, port = self.client.address return ( - "{}(".format(self.__class__.__name__) - + "url={}:{}, ".format(host, port) - + "db={}, ".format(self.db.name) - + "collection={})".format(self.collection.name) + f"{self.__class__.__name__}(" + + f"url={host}:{port}, " + + f"db={self.db.name}, " + + f"collection={self.collection.name})" ) def _repr_html_(self): @@ -471,7 +459,7 @@ def print_info(self): [ "{:=^50}".format(" ABCD MongoDB "), "{:>10}: {}".format("type", "mongodb"), - linesep.join("{:>10}: {}".format(k, v) for k, v in self.info().items()), + linesep.join(f"{k:>10}: {v}" for k, v in self.info().items()), ] ) @@ -488,45 +476,35 @@ def histogram(name, data, **kwargs): if not data: return None - elif data and isinstance(data, list): - + if data and isinstance(data, list): ptype = type(data[0]) if not all(isinstance(x, ptype) for x in data): - print("Mixed type error of the {} property!".format(name)) + print(f"Mixed type error of the {name} property!") return None if ptype == float: bins = kwargs.get("bins", 10) return _hist_float(name, data, bins) - elif ptype == int: + if ptype == int: bins = kwargs.get("bins", 10) return _hist_int(name, data, bins) - elif ptype == str: + if ptype == str: return _hist_str(name, data, **kwargs) - elif ptype == datetime: + if ptype == datetime: bins = kwargs.get("bins", 10) return _hist_date(name, data, bins) - else: - print( - "{}: Histogram for list of {} types are not supported!".format( - name, type(data[0]) - ) - ) - logger.info( - "{}: Histogram for list of {} types are not supported!".format( - name, type(data[0]) - ) - ) - - else: + print(f"{name}: Histogram for list of {type(data[0])} types are not supported!") logger.info( - "{}: Histogram for {} types are not supported!".format(name, type(data)) + f"{name}: Histogram for list of {type(data[0])} types are not supported!" ) + + else: + logger.info(f"{name}: Histogram for {type(data)} types are not supported!") return None diff --git a/abcd/database.py b/abcd/database.py index 2553583b..ddfc1ee1 100644 --- a/abcd/database.py +++ b/abcd/database.py @@ -1,5 +1,5 @@ -import logging from abc import ABCMeta, abstractmethod +import logging logger = logging.getLogger(__name__) diff --git a/abcd/frontends/commandline/commands.py b/abcd/frontends/commandline/commands.py index de158a5c..6dabba5a 100644 --- a/abcd/frontends/commandline/commands.py +++ b/abcd/frontends/commandline/commands.py @@ -11,9 +11,7 @@ @init_config def login(*, config, name, url, **kwargs): logger.info( - "login args: \nconfig:{}, name:{}, url:{}, kwargs:{}".format( - config, name, url, kwargs - ) + f"login args: \nconfig:{config}, name:{name}, url:{url}, kwargs:{kwargs}" ) from abcd import ABCD @@ -36,7 +34,7 @@ def login(*, config, name, url, **kwargs): @init_config @init_db def download(*, db, query, fileformat, filename, **kwargs): - logger.info("download\n kwargs: {}".format(kwargs)) + logger.info(f"download\n kwargs: {kwargs}") from ase.io import write @@ -51,18 +49,14 @@ def download(*, db, query, fileformat, filename, **kwargs): @init_db @check_remote def delete(*, db, query, yes, **kwargs): - logger.info("delete\n kwargs: {}".format(kwargs)) + logger.info(f"delete\n kwargs: {kwargs}") if not yes: - print( - "Please use --yes for deleting {} configurations".format( - db.count(query=query) - ) - ) + print(f"Please use --yes for deleting {db.count(query=query)} configurations") exit(1) count = db.delete(query=query) - print("{} configuration has been deleted".format(count)) + print(f"{count} configuration has been deleted") @init_config @@ -79,10 +73,10 @@ def upload(*, db, path, extra_infos, ignore_calc_results, **kwargs): elif path.is_dir(): for file in path.glob(".xyz"): - logger.info("Uploaded file: {}".format(file)) + logger.info(f"Uploaded file: {file}") db.upload(file, extra_infos, store_calc=calculator) else: - logger.info("No file found: {}".format(path)) + logger.info(f"No file found: {path}") raise FileNotFoundError() else: @@ -92,8 +86,8 @@ def upload(*, db, path, extra_infos, ignore_calc_results, **kwargs): @init_config @init_db def summary(*, db, query, print_all, bins, truncate, props, **kwargs): - logger.info("summary\n kwargs: {}".format(kwargs)) - logger.info("query: {}".format(query)) + logger.info(f"summary\n kwargs: {kwargs}") + logger.info(f"query: {query}") if print_all: truncate = None @@ -111,17 +105,16 @@ def summary(*, db, query, print_all, bins, truncate, props, **kwargs): if "*" in props_list: props_list = "*" - logging.info("property list: {}".format(props_list)) + logging.info(f"property list: {props_list}") total = db.count(query) - print("Total number of configurations: {}".format(total)) + print(f"Total number of configurations: {total}") if total == 0: return f = Formater() if props_list is None: - props = db.count_properties(query=query) labels, categories, dtypes, counts = [], [], [], [] @@ -158,8 +151,8 @@ def summary(*, db, query, print_all, bins, truncate, props, **kwargs): @init_config @init_db def show(*, db, query, print_all, props, **kwargs): - logger.info("show\n kwargs: {}".format(kwargs)) - logger.info("query: {}".format(query)) + logger.info(f"show\n kwargs: {kwargs}") + logger.info(f"query: {query}") if not props: print("Please define at least on property by using the -p option!") @@ -171,7 +164,7 @@ def show(*, db, query, print_all, props, **kwargs): for dct in islice(db.get_items(query), 0, limit): print(" | ".join(str(dct.get(prop, None)) for prop in props)) - logging.info("property list: {}".format(props)) + logging.info(f"property list: {props}") @check_remote @@ -225,9 +218,7 @@ def key_delete(*, db, query, yes, keys, **kwargs): if not yes: print( - "Please use --yes for deleting keys from {} configurations".format( - db.count(query=query) - ) + f"Please use --yes for deleting keys from {db.count(query=query)} configurations" ) exit(1) @@ -241,9 +232,7 @@ def key_delete(*, db, query, yes, keys, **kwargs): def execute(*, db, query, yes, python_code, **kwargs): if not yes: print( - "Please use --yes for executing code on {} configurations".format( - db.count(query=query) - ) + f"Please use --yes for executing code on {db.count(query=query)} configurations" ) exit(1) @@ -253,11 +242,10 @@ def execute(*, db, query, yes, python_code, **kwargs): @check_remote def server(*, abcd_url, url, api_only, **kwargs): from urllib.parse import urlparse + from abcd.server.app import create_app - logger.info( - "SERVER - abcd: {}, url: {}, api_only:{}".format(abcd_url, url, api_only) - ) + logger.info(f"SERVER - abcd: {abcd_url}, url: {url}, api_only:{api_only}") if api_only: print("Not implemented yet!") @@ -269,7 +257,7 @@ def server(*, abcd_url, url, api_only, **kwargs): app.run(host=o.hostname, port=o.port) -class Formater(object): +class Formater: partialBlocks = ["▏", "▎", "▍", "▌", "▋", "▊", "▉", "█"] # char=pb def title(self, title): @@ -321,7 +309,6 @@ def hist_float(self, bin_edges, counts, width_hist=40): ) def hist_int(self, bin_edges, counts, width_hist=40): - ratio = width_hist / max(counts) width_count = len(str(max(counts))) @@ -373,7 +360,6 @@ def hist_str(self, total, counts, labels, width_hist=40): ) def hist_labels(self, counts, categories, dtypes, labels, width_hist=40): - width_count = len(str(max(counts))) ratio = width_hist / max(counts) for label, count, dtype in zip(labels, counts, dtypes): diff --git a/abcd/frontends/commandline/config.py b/abcd/frontends/commandline/config.py index 3aa21bea..7a13b7c1 100644 --- a/abcd/frontends/commandline/config.py +++ b/abcd/frontends/commandline/config.py @@ -1,6 +1,6 @@ -import os import json import logging +import os from pathlib import Path logger = logging.getLogger(__name__) @@ -20,7 +20,6 @@ def from_json(cls, filename): @classmethod def load(cls): - if ( os.environ.get("ABCD_CONFIG") and Path(os.environ.get("ABCD_CONFIG")).is_file() @@ -31,7 +30,7 @@ def load(cls): else: return cls() - logger.info("Using config file: {}".format(file)) + logger.info(f"Using config file: {file}") config = cls.from_json(file) @@ -44,10 +43,10 @@ def save(self): else Path.home() / ".abcd" ) - logger.info("The saved config's file: {}".format(file)) + logger.info(f"The saved config's file: {file}") with open(str(file), "w") as file: json.dump(self, file) def __repr__(self): - return "<{} {}>".format(self.__class__.__name__, dict.__repr__(self)) + return f"<{self.__class__.__name__} {dict.__repr__(self)}>" diff --git a/abcd/frontends/commandline/decorators.py b/abcd/frontends/commandline/decorators.py index c2439be7..8fb37499 100644 --- a/abcd/frontends/commandline/decorators.py +++ b/abcd/frontends/commandline/decorators.py @@ -1,7 +1,6 @@ import logging from abcd import ABCD - from abcd.frontends.commandline.config import Config from abcd.parsers.queries import parser diff --git a/abcd/frontends/commandline/parser.py b/abcd/frontends/commandline/parser.py index 9b2c1af2..ab9a9d54 100644 --- a/abcd/frontends/commandline/parser.py +++ b/abcd/frontends/commandline/parser.py @@ -1,7 +1,8 @@ -import logging from argparse import ArgumentParser +import logging + +from abcd.errors import AuthenticationError, TimeoutError, URLError from abcd.frontends.commandline import commands -from abcd.errors import URLError, AuthenticationError, TimeoutError logger = logging.getLogger(__name__) diff --git a/abcd/model.py b/abcd/model.py index f4c87b61..89ae6567 100644 --- a/abcd/model.py +++ b/abcd/model.py @@ -1,22 +1,21 @@ +from collections import Counter, UserDict import datetime import getpass -import logging from hashlib import md5 -from collections import Counter, UserDict -from ase.calculators.singlepoint import SinglePointCalculator +import logging -import numpy as np from ase import Atoms +from ase.calculators.singlepoint import SinglePointCalculator +import numpy as np logger = logging.getLogger(__name__) -class Hasher(object): +class Hasher: def __init__(self, method=md5()): self.method = method def update(self, value): - if isinstance(value, int): self.update(str(value).encode("ascii")) @@ -24,7 +23,7 @@ def update(self, value): self.update(value.encode("utf-8")) elif isinstance(value, float): - self.update("{:.8e}".format(value).encode("ascii")) + self.update(f"{value:.8e}".encode("ascii")) elif isinstance(value, (tuple, list)): for e in value: @@ -80,14 +79,12 @@ def derived(self): } def __getitem__(self, key): - if key == "derived": return self.derived return super().__getitem__(key) def __setitem__(self, key, value): - if key == "derived": # raise KeyError('Please do not use "derived" as key because it is protected!') # Silent return to avoid raising error in pymongo package @@ -107,7 +104,6 @@ def convert(self, value): return value def update_key_category(self, key, value): - if key == "_id": # raise KeyError('Please do not use "derived" as key because it is protected!') return @@ -199,7 +195,6 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True): info_keys.update({"calculator_name", "calculator_parameters"}) for key, value in atoms.calc.results.items(): - if isinstance(value, np.ndarray): if value.shape[0] == n_atoms: arrays_keys.update(key) @@ -294,8 +289,8 @@ def pre_save(self): if __name__ == "__main__": - import io from pprint import pprint + from ase.io import read logging.basicConfig(level=logging.INFO) diff --git a/abcd/parsers/extras.py b/abcd/parsers/extras.py index 47e18774..0dc90f46 100644 --- a/abcd/parsers/extras.py +++ b/abcd/parsers/extras.py @@ -1,6 +1,4 @@ -import sys from lark import Lark, Transformer, v_args -from lark.exceptions import LarkError grammar = r""" start: ( key | key_value )* diff --git a/abcd/parsers/queries.py b/abcd/parsers/queries.py index 7071fe24..ebb18701 100644 --- a/abcd/parsers/queries.py +++ b/abcd/parsers/queries.py @@ -1,7 +1,7 @@ import logging + from lark import Lark, Transformer, v_args from lark.exceptions import LarkError -from abcd.queryset import Query logger = logging.getLogger(__name__) @@ -176,7 +176,7 @@ def __call__(self, string): # print(parser.parse(query).pretty()) try: tree = parser.parse(query) - logger.info("=> tree: {}".format(tree)) - logger.info("==> ast: {}".format(parser(query))) + logger.info(f"=> tree: {tree}") + logger.info(f"==> ast: {parser(query)}") except LarkError: raise NotImplementedError diff --git a/abcd/parsers/queries_new.py b/abcd/parsers/queries_new.py index 22004d11..9020f74e 100644 --- a/abcd/parsers/queries_new.py +++ b/abcd/parsers/queries_new.py @@ -1,6 +1,6 @@ import logging -from lark import Lark, Transformer, v_args -from lark.lexer import Token + +from lark import Lark, Transformer from lark.exceptions import LarkError logger = logging.getLogger(__name__) diff --git a/abcd/queryset.py b/abcd/queryset.py index 82b8ea03..40b0bb0b 100644 --- a/abcd/queryset.py +++ b/abcd/queryset.py @@ -1,5 +1,5 @@ -import logging from abc import ABCMeta +import logging logger = logging.getLogger(__name__) diff --git a/abcd/server/app/__init__.py b/abcd/server/app/__init__.py index d7176a42..2f91de0a 100644 --- a/abcd/server/app/__init__.py +++ b/abcd/server/app/__init__.py @@ -4,8 +4,8 @@ from flask_nav import register_renderer from abcd.server.app.db import db -from abcd.server.app.nav import nav, BootstrapRenderer, DatabaseNav -from abcd.server.app.views import index, database, api +from abcd.server.app.nav import BootstrapRenderer, DatabaseNav, nav +from abcd.server.app.views import api, database, index def create_app(abcd_url=None): diff --git a/abcd/server/app/db.py b/abcd/server/app/db.py index 65bc9c70..eadc1f45 100644 --- a/abcd/server/app/db.py +++ b/abcd/server/app/db.py @@ -1,6 +1,5 @@ from abcd import ABCD - # from flask_paginate import Pagination, get_page_args diff --git a/abcd/server/app/nav.py b/abcd/server/app/nav.py index ec8ea2d1..31c8edeb 100644 --- a/abcd/server/app/nav.py +++ b/abcd/server/app/nav.py @@ -1,8 +1,8 @@ -from flask_nav import Nav -from flask_nav.elements import Navbar, View, Separator, Subgroup, Link from hashlib import sha1 from dominate import tags +from flask_nav import Nav +from flask_nav.elements import Link, Navbar, View from flask_nav.renderers import Renderer nav = Nav() @@ -164,7 +164,6 @@ def visit_View(self, node): return item def visit_Subgroup(self, node): - if self._in_dropdown: raise RuntimeError("Cannot render nested Subgroups") diff --git a/abcd/server/app/views/api.py b/abcd/server/app/views/api.py index b3aa6d25..26f5b99b 100644 --- a/abcd/server/app/views/api.py +++ b/abcd/server/app/views/api.py @@ -1,4 +1,4 @@ -from flask import Blueprint, Response, make_response, jsonify, request +from flask import Blueprint, Response, jsonify, request bp = Blueprint("api", __name__) diff --git a/abcd/server/app/views/database.py b/abcd/server/app/views/database.py index 8569dcd2..607dc8b8 100644 --- a/abcd/server/app/views/database.py +++ b/abcd/server/app/views/database.py @@ -1,5 +1,4 @@ -from flask import Blueprint, render_template, request -from flask import Response +from flask import Blueprint, Response, render_template, request bp = Blueprint("database", __name__) diff --git a/docs_src/conf.py b/docs_src/conf.py index 97f1e35f..4eddd064 100644 --- a/docs_src/conf.py +++ b/docs_src/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # @@ -15,18 +14,18 @@ import os import sys -sys.path.insert(0, os.path.abspath('.')) +sys.path.insert(0, os.path.abspath(".")) # -- Project information ----------------------------------------------------- -project = 'abcd' -copyright = '2019, adam' -author = 'adam' +project = "abcd" +copyright = "2019, adam" +author = "adam" # The short X.Y version -version = '' +version = "" # The full version, including alpha/beta/rc tags -release = '0.1' +release = "0.1" # -- General configuration --------------------------------------------------- @@ -39,11 +38,11 @@ # ones. extensions = [ - 'sphinx.ext.napoleon', - 'sphinx.ext.autodoc', - 'sphinx.ext.viewcode', + "sphinx.ext.napoleon", + "sphinx.ext.autodoc", + "sphinx.ext.viewcode", # 'sphinx.ext.mathjax', - 'sphinx.ext.githubpages', + "sphinx.ext.githubpages", ] # Napoleon settings @@ -60,16 +59,16 @@ napoleon_use_rtype = True # Add any paths that contain templates here, relative to this directory. -templates_path = ['templates'] +templates_path = ["templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -81,7 +80,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None @@ -93,7 +92,7 @@ # # html_theme = 'alabaster' -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -110,8 +109,8 @@ # 'style_external_links': False, # 'vcs_pageview_mode': '', # Toc options - 'collapse_navigation': False, - 'sticky_navigation': False, + "collapse_navigation": False, + "sticky_navigation": False, # 'navigation_depth': 4, # 'includehidden': True, # 'titles_only': False @@ -120,7 +119,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['static'] +html_static_path = ["static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -136,7 +135,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'abcddoc' +htmlhelp_basename = "abcddoc" # -- Options for LaTeX output ------------------------------------------------ @@ -144,15 +143,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -162,18 +158,14 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'abcd.tex', 'abcd Documentation', - 'adam', 'manual'), + (master_doc, "abcd.tex", "abcd Documentation", "adam", "manual"), ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'abcd', 'abcd Documentation', - [author], 1) -] +man_pages = [(master_doc, "abcd", "abcd Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -181,9 +173,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'abcd', 'abcd Documentation', - author, 'abcd', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "abcd", + "abcd Documentation", + author, + "abcd", + "One line description of project.", + "Miscellaneous", + ), ] # -- Options for Epub output ------------------------------------------------- @@ -201,6 +199,6 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- diff --git a/tests/test_database.py b/tests/test_database.py index 82cddfab..864ca5b2 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,11 +1,11 @@ -import pytest +from io import StringIO + +from ase.io import read import mongomock +import pytest from abcd import ABCD -from io import StringIO -from ase.io import read - @pytest.fixture @mongomock.patch(servers=(("localhost", 27017),)) diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 5a2211e0..df52fbc6 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -1,4 +1,5 @@ import pytest + from abcd.parsers.extras import parser as extras_parser from abcd.parsers.queries import parser as queries_parser @@ -143,8 +144,7 @@ def test_composite(self, parser): "not_bool_array=[T F S]", ], ) - def test_missing(self, string): - ... + def test_missing(self, string): ... class TestParsingQueries: @@ -188,8 +188,7 @@ def test_combination(self, parser, string, expected): ("any(aa) > 3", {}), ], ) - def test_expressions(self, case): - ... + def test_expressions(self, case): ... @pytest.mark.skip("known issues / future features") @pytest.mark.parametrize( @@ -202,5 +201,4 @@ def test_expressions(self, case): ("aa and (bb > 23.54 or (22 in cc and dd))", {}), ], ) - def test_expressions(self, case): - ... + def test_expressions(self, case): ... diff --git a/tutorials/abcd_uploading.ipynb b/tutorials/abcd_uploading.ipynb index 80a774c8..0b1640c4 100644 --- a/tutorials/abcd_uploading.ipynb +++ b/tutorials/abcd_uploading.ipynb @@ -43,13 +43,11 @@ }, "outputs": [], "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", "from pathlib import Path\n", + "\n", "from ase.io import iread, read\n", - "# from utils.ext_xyz import XYZReader\n", "\n", + "# from utils.ext_xyz import XYZReader\n", "from abcd import ABCD" ] }, @@ -85,10 +83,10 @@ } ], "source": [ - "url = 'mongodb://localhost:27017'\n", - "url = 'mongodb://mongoadmin:secret@localhost:27017'\n", + "url = \"mongodb://localhost:27017\"\n", + "url = \"mongodb://mongoadmin:secret@localhost:27017\"\n", "# url = 'mongodb://2ef35d3635e9dc5a922a6a42:ac6ce72e259f5ddcc8dd5178@localhost:27017/?authSource=admin'\n", - "abcd = ABCD.from_url(url)\n", + "abcd = ABCD.from_url(url)\n", "\n", "print(abcd)" ] @@ -192,9 +190,9 @@ }, "outputs": [], "source": [ - "direcotry = Path('data/')\n", + "direcotry = Path(\"data/\")\n", "# file = direcotry / 'bcc_bulk_54_expanded_high.xyz'\n", - "file = direcotry / 'GAP_1.xyz'" + "file = direcotry / \"GAP_1.xyz\"" ] }, { @@ -228,17 +226,14 @@ "source": [ "%%time\n", "with abcd as db:\n", - "\n", " for atoms in iread(file.as_posix(), index=slice(None)):\n", - " \n", " # Hack to fix the representation of forces\n", - " \n", - "# atoms.calc.results['forces'] = atoms.arrays['force']\n", - "# del(atoms.info['energy'])\n", - " \n", + "\n", + " # atoms.calc.results['forces'] = atoms.arrays['force']\n", + " # del(atoms.info['energy'])\n", + "\n", " db.push(atoms, store_calc=False)\n", - "# break\n", - " " + "# break" ] }, { @@ -731,25 +726,23 @@ } ], "source": [ - "# wrong format or reader \n", + "# wrong format or reader\n", "# traj = read(\"libatoms/DataRepository/Graphene_GAP_Final/Graphene_GAP_Validation.xyz\", index=\":\")\n", "# InvalidDocument: key '18.36' must not contain '.'\n", "# traj = read(\"libatoms/DataRepository/bulk-methane-models-main/init-tiny.xyz\", index=\":\")\n", - "excludes =[\n", + "excludes = [\n", " Path(\"libatoms/DataRepository/Graphene_GAP_Final/Graphene_GAP_Validation.xyz\"),\n", - " Path(\"libatoms/DataRepository/bulk-methane-models-main/init-tiny.xyz\")\n", + " Path(\"libatoms/DataRepository/bulk-methane-models-main/init-tiny.xyz\"),\n", "]\n", "\n", - "for file in Path('libatoms/').glob('**/*.xyz'):\n", + "for file in Path(\"libatoms/\").glob(\"**/*.xyz\"):\n", " print(file)\n", "\n", - " if file in excludes :\n", + " if file in excludes:\n", " continue\n", "\n", " for atoms in iread(file.as_posix(), index=slice(None)):\n", - " abcd.push(atoms, store_calc=False)\n", - "\n", - " \n" + " abcd.push(atoms, store_calc=False)" ] }, { @@ -797,7 +790,7 @@ } ], "source": [ - "for file in Path('data/').glob('*.xyz'):\n", + "for file in Path(\"data/\").glob(\"*.xyz\"):\n", " print(file)\n", "\n", " for atoms in iread(file.as_posix(), index=slice(None)):\n", @@ -813,7 +806,7 @@ "metadata": {}, "outputs": [], "source": [ - "for file in Path('data/').glob('*.xyz'):\n", + "for file in Path(\"data/\").glob(\"*.xyz\"):\n", " for atoms in iread(file.as_posix(), index=slice(None)):\n", " abcd.push(atoms, store_calc=False)" ] @@ -896,17 +889,12 @@ } ], "source": [ - "for file in Path('GB_alphaFe_001/tilt/').glob('*.xyz'):\n", + "for file in Path(\"GB_alphaFe_001/tilt/\").glob(\"*.xyz\"):\n", " print(file)\n", - " gb_params = {\n", - " 'name': 'alphaFe',\n", - " 'type': 'tilt',\n", - " 'params': file.name[:-4]\n", - " \n", - " } \n", + " gb_params = {\"name\": \"alphaFe\", \"type\": \"tilt\", \"params\": file.name[:-4]}\n", "\n", " traj = read(file.as_posix(), index=slice(None))\n", - " db.push(traj, extra_info={'GB_params': gb_params}, store_calc=False)" + " db.push(traj, extra_info={\"GB_params\": gb_params}, store_calc=False)" ] }, { diff --git a/tutorials/abcd_usage.ipynb b/tutorials/abcd_usage.ipynb index 5bb44deb..f98880c5 100644 --- a/tutorials/abcd_usage.ipynb +++ b/tutorials/abcd_usage.ipynb @@ -25,10 +25,10 @@ }, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", "from collections import Counter\n", "\n", + "import matplotlib.pyplot as plt\n", + "\n", "from abcd import ABCD" ] }, @@ -56,9 +56,9 @@ } ], "source": [ - "url = 'mongodb://localhost:27017'\n", - "url = 'mongodb://mongoadmin:secret@localhost:27017'\n", - "abcd = ABCD.from_url(url)\n", + "url = \"mongodb://localhost:27017\"\n", + "url = \"mongodb://mongoadmin:secret@localhost:27017\"\n", + "abcd = ABCD.from_url(url)\n", "\n", "print(abcd)" ] @@ -299,7 +299,7 @@ } ], "source": [ - "Counter(abcd.property('config_type'))" + "Counter(abcd.property(\"config_type\"))" ] }, { @@ -319,9 +319,7 @@ } ], "source": [ - "query = {\n", - " 'config_type': 'bcc_bulk_54_high'\n", - "}\n", + "query = {\"config_type\": \"bcc_bulk_54_high\"}\n", "# query = 'config_type=\"bcc_bulk_54_high\"'\n", "abcd.count(query)" ] @@ -358,7 +356,7 @@ } ], "source": [ - "Counter(abcd.property('config_name', query))" + "Counter(abcd.property(\"config_name\", query))" ] }, { @@ -385,10 +383,7 @@ } ], "source": [ - "query = {\n", - " 'config_type': 'bcc_bulk_54_high',\n", - " 'pbc': [True, True, True]\n", - "}\n", + "query = {\"config_type\": \"bcc_bulk_54_high\", \"pbc\": [True, True, True]}\n", "# query = 'config_type=\"bcc_bulk_54_high\" and pbc=[True, True, True]'\n", "abcd.count(query)" ] @@ -424,7 +419,7 @@ } ], "source": [ - "data = abcd.property('energy', query)\n", + "data = abcd.property(\"energy\", query)\n", "hist, bins, ax = plt.hist(data)\n", "plt.show()\n", "min(data), max(data)" @@ -469,9 +464,9 @@ ], "source": [ "query = {\n", - " 'config_type': 'bcc_bulk_54_high',\n", - " 'energy': {'$gt': -186885.0},\n", - " 'pbc': [True, True, True],\n", + " \"config_type\": \"bcc_bulk_54_high\",\n", + " \"energy\": {\"$gt\": -186885.0},\n", + " \"pbc\": [True, True, True],\n", "}\n", "abcd.count(query)" ] @@ -493,7 +488,7 @@ } ], "source": [ - "abcd.count_properties(query)['arrays']" + "abcd.count_properties(query)[\"arrays\"]" ] }, { @@ -521,9 +516,9 @@ ], "source": [ "query = {\n", - " 'config_type': 'bcc_bulk_54_high',\n", - " 'energy': {'$gt': -186885.0},\n", - " 'pbc': [True, True, True],\n", + " \"config_type\": \"bcc_bulk_54_high\",\n", + " \"energy\": {\"$gt\": -186885.0},\n", + " \"pbc\": [True, True, True],\n", "}\n", "abcd.count(query)" ] diff --git a/tutorials/gb_upload.py b/tutorials/gb_upload.py index 3d276de2..37eb8484 100644 --- a/tutorials/gb_upload.py +++ b/tutorials/gb_upload.py @@ -1,13 +1,13 @@ -import sys from pathlib import Path +import sys sys.path.append("..") -from abcd import ABCD from utils.ext_xyz import XYZReader -if __name__ == "__main__": +from abcd import ABCD +if __name__ == "__main__": url = "mongodb://localhost:27017" abcd = ABCD(url) diff --git a/tutorials/grain_boundaries_tutorial.ipynb b/tutorials/grain_boundaries_tutorial.ipynb index a2582c55..0d259c5c 100644 --- a/tutorials/grain_boundaries_tutorial.ipynb +++ b/tutorials/grain_boundaries_tutorial.ipynb @@ -68,7 +68,8 @@ "outputs": [], "source": [ "import sys\n", - "sys.path.append('..')" + "\n", + "sys.path.append(\"..\")" ] }, { @@ -83,24 +84,15 @@ "source": [ "%matplotlib inline\n", "\n", - "from os import listdir\n", - "from pathlib import Path\n", "\n", + "from asap3 import FullNeighborList\n", + "from asap3.analysis import PTM, CoordinationNumbers\n", + "from ase.io import read\n", + "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from scipy.spatial import cKDTree\n", - "\n", - "from ase import Atoms\n", - "from ase.io import read\n", - "\n", "from scripts.Visualise import AtomViewer\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "from asap3 import FullNeighborList\n", - "from asap3.analysis import CoordinationNumbers, FullCNA, PTM\n", - "\n", "from sklearn.cluster import KMeans\n", - "from sklearn.mixture import GaussianMixture\n", "\n", "from abcd import ABCD" ] @@ -119,7 +111,7 @@ "outputs": [], "source": [ "# potential energy of the perfect crystal according to a specific potential\n", - "Fe_BCC_energy_per_atom = -4.01298214176 # alpha-Fe PotBH\n", + "Fe_BCC_energy_per_atom = -4.01298214176 # alpha-Fe PotBH\n", "Fe_BCC_lattice_constant = 2.856" ] }, @@ -131,11 +123,11 @@ }, "outputs": [], "source": [ - "abcd = ABCD(url='mongodb://localhost:27017')\n", + "abcd = ABCD(url=\"mongodb://localhost:27017\")\n", "\n", "query = {\n", - " 'info.GB_params.name': 'alphaFe',\n", - " 'info.GB_params.type': 'tilt',\n", + " \"info.GB_params.name\": \"alphaFe\",\n", + " \"info.GB_params.type\": \"tilt\",\n", "}\n", "\n", "traj = list(abcd.get_atoms(query))" @@ -168,7 +160,7 @@ "outputs": [], "source": [ "for atoms in traj:\n", - " atoms.calc.results={'energy': atoms.info['energy']}" + " atoms.calc.results = {\"energy\": atoms.info[\"energy\"]}" ] }, { @@ -251,10 +243,12 @@ } ], "source": [ - "print('number of atoms: {:d}\\n'.format(atoms.get_number_of_atoms()),\n", - " 'total_energy: {:.4f} eV\\n'.format(atoms.get_total_energy()),\n", - " 'cell voluem: {:.4f} A^3\\n'.format(atoms.get_volume()),\n", - " 'periodic boundary: {}'.format(atoms.get_pbc()))" + "print(\n", + " f\"number of atoms: {atoms.get_number_of_atoms():d}\\n\",\n", + " f\"total_energy: {atoms.get_total_energy():.4f} eV\\n\",\n", + " f\"cell voluem: {atoms.get_volume():.4f} A^3\\n\",\n", + " f\"periodic boundary: {atoms.get_pbc()}\",\n", + ")" ] }, { @@ -282,12 +276,15 @@ ], "source": [ "def gb_energy(total_energy, n_atoms, area):\n", - " \n", " eV = 1.6021766208e-19\n", - " Angstrom = 1.e-10\n", + " Angstrom = 1.0e-10\n", "\n", - " return 1 / (2 * area * Angstrom**2) * \\\n", - " (total_energy - Fe_BCC_energy_per_atom * n_atoms) * eV\n", + " return (\n", + " 1\n", + " / (2 * area * Angstrom**2)\n", + " * (total_energy - Fe_BCC_energy_per_atom * n_atoms)\n", + " * eV\n", + " )\n", "\n", "\n", "cell = atoms.get_cell_lengths_and_angles()\n", @@ -295,8 +292,7 @@ "\n", "E_gb = gb_energy(atoms.get_total_energy(), len(atoms), area)\n", "\n", - "print('energy of grain boundary: {:.4f} J/m^2\\n'.format(E_gb),\n", - " 'area: {:.4f} A^2'.format(area))" + "print(f\"energy of grain boundary: {E_gb:.4f} J/m^2\\n\", f\"area: {area:.4f} A^2\")" ] }, { @@ -391,7 +387,7 @@ "metadata": {}, "outputs": [], "source": [ - "coord_num = CoordinationNumbers(atoms, rCut=0.93*Fe_BCC_lattice_constant) " + "coord_num = CoordinationNumbers(atoms, rCut=0.93 * Fe_BCC_lattice_constant)" ] }, { @@ -421,12 +417,11 @@ }, "outputs": [], "source": [ - "nblist = FullNeighborList(0.93*Fe_BCC_lattice_constant, atoms=atoms)\n", + "nblist = FullNeighborList(0.93 * Fe_BCC_lattice_constant, atoms=atoms)\n", "\n", "coord_num = np.zeros(len(atoms))\n", "for i, (atom, neighbor) in enumerate(zip(atoms, nblist)):\n", - " coord_num[i] = len(neighbor)\n", - " " + " coord_num[i] = len(neighbor)" ] }, { @@ -442,12 +437,12 @@ "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots(figsize=(12,5))\n", + "fig, ax = plt.subplots(figsize=(12, 5))\n", "\n", "bins = np.arange(-0.5, max(coord_num) + 1)\n", "ax.hist(coord_num, bins)\n", "\n", - "ax.set_title('Coordination number')\n", + "ax.set_title(\"Coordination number\")\n", "ax.set_xlabel(\"Number of nearest neighbors\")\n", "ax.set_ylabel(\"Number of atoms\")\n", "plt.show()" @@ -491,8 +486,8 @@ "for r_scale in r_scale_list:\n", " coord_num = CoordinationNumbers(atoms=atoms, rCut=r_scale * Fe_BCC_lattice_constant)\n", " avg_neihbour.append(np.average(coord_num))\n", - " \n", - "fig, ax = plt.subplots(figsize=(12,5))\n", + "\n", + "fig, ax = plt.subplots(figsize=(12, 5))\n", "\n", "ax.plot(r_scale_list, avg_neihbour)\n", "ax.set_xlabel(\"r_scale\")\n", @@ -568,40 +563,39 @@ "outputs": [], "source": [ "def CentroSymmentryParameter(atoms, n):\n", - " \n", " atoms.wrap()\n", " coordinates = atoms.get_positions()\n", " box = np.diag(atoms.get_cell())\n", - " \n", + "\n", " # Building the nearest neighbor list\n", - " \n", + "\n", " nblist = cKDTree(coordinates, boxsize=box)\n", - " distances, nblist = nblist.query(coordinates, k=n+1)\n", + " distances, nblist = nblist.query(coordinates, k=n + 1)\n", "\n", - " \n", - " csp=np.zeros(len(atoms))\n", + " csp = np.zeros(len(atoms))\n", " for neighbors in nblist:\n", " atom_index = neighbors[0]\n", " n_indecies = neighbors[1:]\n", " N = len(n_indecies)\n", - " \n", + "\n", " r = atoms.positions[n_indecies] - atoms.positions[atom_index]\n", - " \n", - " # fixing periodic boundary \n", + "\n", + " # fixing periodic boundary\n", " r = np.where(abs(r) < abs(r - box), r, r - box)\n", " r = np.where(abs(r) < abs(r + box), r, r + box)\n", - " \n", + "\n", " pairs = []\n", " for i, r_i in enumerate(r):\n", - " pairs.append(np.linalg.norm(r_i + r[i+1:,:], axis=1))\n", + " pairs.append(np.linalg.norm(r_i + r[i + 1 :, :], axis=1))\n", "\n", " pairs = np.hstack(pairs)\n", "\n", " pairs.sort()\n", - " csp[atom_index] = np.sum(pairs[:N//2])\n", + " csp[atom_index] = np.sum(pairs[: N // 2])\n", "\n", " return csp\n", "\n", + "\n", "csp = CentroSymmentryParameter(atoms, n=8)" ] }, @@ -620,11 +614,11 @@ }, "outputs": [], "source": [ - "fig, ax = plt.subplots(figsize=(12,5))\n", + "fig, ax = plt.subplots(figsize=(12, 5))\n", "\n", "ax.hist(csp, bins=20)\n", "\n", - "ax.set_title('Distribution of Centro Symmentry Parameter')\n", + "ax.set_title(\"Distribution of Centro Symmentry Parameter\")\n", "ax.set_xlabel(\"Centro Symmentry Parameter\")\n", "ax.set_ylabel(\"Number of atoms\")\n", "# ax.set_yscale('symlog')\n", @@ -669,14 +663,14 @@ "outputs": [], "source": [ "# help(PTM)\n", - "# Imprtant key names of returned data: \n", - "# 'structure': The local crystal structure around atom i, if any. \n", + "# Imprtant key names of returned data:\n", + "# 'structure': The local crystal structure around atom i, if any.\n", "# 0 = none; 1 = FCC; 2 = HCP; 3 = BCC; 4 = Icosahedral; 5 = SC.\n", "# 'rmsd': The RMSD error in the fitting to the template, or INF if no structure was identified.\n", - "# 'scale': The average distance to the nearest neighbors for structures 1-4; \n", - "# or the average distance to nearest and next-nearest neighbors for structure 5; \n", + "# 'scale': The average distance to the nearest neighbors for structures 1-4;\n", + "# or the average distance to nearest and next-nearest neighbors for structure 5;\n", "# or INF if no structure was identified.\n", - "# 'orientation': The orientation of the crystal lattice, expressed as a unit quaternion. \n", + "# 'orientation': The orientation of the crystal lattice, expressed as a unit quaternion.\n", "# If no structure was found, the illegal value (0, 0, 0, 0) is returned.\n", "\n", "# ?PTM" @@ -690,7 +684,7 @@ }, "outputs": [], "source": [ - "ptm = PTM(atoms=atoms, cutoff=8.)\n", + "ptm = PTM(atoms=atoms, cutoff=8.0)\n", "ptm.keys()" ] }, @@ -705,24 +699,24 @@ }, "outputs": [], "source": [ - "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16,5))\n", + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))\n", "\n", - "ax1.hist(ptm['structure'], range(0, 7))\n", + "ax1.hist(ptm[\"structure\"], range(0, 7))\n", "# ax1.set_yscale('symlog')\n", - "ax1.set_xticks([x + .5 for x in range(6)])\n", - "ax1.set_xticklabels(['None', 'FCC', 'HCP', 'BCC', 'Ic', 'SC'])\n", - "ax1.set_ylabel('# of atoms')\n", + "ax1.set_xticks([x + 0.5 for x in range(6)])\n", + "ax1.set_xticklabels([\"None\", \"FCC\", \"HCP\", \"BCC\", \"Ic\", \"SC\"])\n", + "ax1.set_ylabel(\"# of atoms\")\n", "\n", - "ax1.set_title('Structure')\n", + "ax1.set_title(\"Structure\")\n", "\n", - "ax2.hist(ptm['scale'])\n", + "ax2.hist(ptm[\"scale\"])\n", "# ax2.set_yscale('symlog')\n", - "ax1.set_xlabel('distance scale')\n", - "ax1.set_ylabel('# of atoms')\n", - "ax2.set_title('Distribution of distance scale')\n", + "ax1.set_xlabel(\"distance scale\")\n", + "ax1.set_ylabel(\"# of atoms\")\n", + "ax2.set_title(\"Distribution of distance scale\")\n", "\n", "\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -733,7 +727,7 @@ }, "outputs": [], "source": [ - "view = AtomViewer(atoms, ptm['scale'])\n", + "view = AtomViewer(atoms, ptm[\"scale\"])\n", "\n", "view.view.center()\n", "view.gui" @@ -802,15 +796,15 @@ "source": [ "# Feature space\n", "# X = np.hstack([ptm['scale'][:, np.newaxis], csp[:, np.newaxis]])\n", - "X = np.hstack([ptm['orientation'], csp[:, np.newaxis]])\n", + "X = np.hstack([ptm[\"orientation\"], csp[:, np.newaxis]])\n", "# X = ptm['orientation']\n", "\n", "# Number of clusters\n", - "n_clusters=10\n", + "n_clusters = 10\n", "\n", "# Clustering method\n", "pred = KMeans(n_clusters=n_clusters).fit_predict(X)\n", - "# pred = GaussianMixture(n_components=n_clusters, covariance_type='full').fit(X).predict(X)\n" + "# pred = GaussianMixture(n_components=n_clusters, covariance_type='full').fit(X).predict(X)" ] }, { @@ -829,10 +823,10 @@ "outputs": [], "source": [ "count = np.bincount(pred)\n", - "fig, ax = plt.subplots(figsize=(12,5))\n", + "fig, ax = plt.subplots(figsize=(12, 5))\n", "\n", - "ax.bar(range(n_clusters),count)\n", - "ax.set_title('Histogram')\n", + "ax.bar(range(n_clusters), count)\n", + "ax.set_title(\"Histogram\")\n", "ax.set_xlabel(\"classes\")\n", "ax.set_ylabel(\"# of atoms\")\n", "\n", @@ -875,12 +869,14 @@ "source": [ "# select the 2 largest\n", "index = np.argsort(count)[-2:]\n", - "orientation0 = np.average(ptm['orientation'][pred==index[0], :], axis=0)\n", - "orientation1 = np.average(ptm['orientation'][pred==index[1], :], axis=0)\n", + "orientation0 = np.average(ptm[\"orientation\"][pred == index[0], :], axis=0)\n", + "orientation1 = np.average(ptm[\"orientation\"][pred == index[1], :], axis=0)\n", "\n", - "angle_difference = 2 * np.arccos(np.dot(orientation0, np.conj(orientation1)))*180/np.pi\n", + "angle_difference = (\n", + " 2 * np.arccos(np.dot(orientation0, np.conj(orientation1))) * 180 / np.pi\n", + ")\n", "\n", - "print('The angle difference between the grains is: {:.3f} degree'.format(angle_difference))" + "print(f\"The angle difference between the grains is: {angle_difference:.3f} degree\")" ] }, { @@ -900,28 +896,30 @@ "source": [ "def diff_angle(filepath):\n", " atoms = read(str(filepath))\n", - " \n", + "\n", " cell = atoms.get_cell_lengths_and_angles()\n", " area = cell[0] * cell[1]\n", "\n", " E_gb = gb_energy(atoms.get_total_energy(), len(atoms), area)\n", "\n", - " ptm = PTM(atoms=atoms, cutoff=8.)\n", + " ptm = PTM(atoms=atoms, cutoff=8.0)\n", "\n", " # X = ptm['orientation']\n", - " X = np.hstack([ptm['orientation'], ptm['scale'][:, np.newaxis]])\n", - " \n", - " n_clusters=20\n", + " X = np.hstack([ptm[\"orientation\"], ptm[\"scale\"][:, np.newaxis]])\n", + "\n", + " n_clusters = 20\n", "\n", " pred = KMeans(n_clusters=n_clusters).fit_predict(X)\n", - " \n", + "\n", " count = np.bincount(pred)\n", " index = np.argsort(count)[-2:]\n", - " \n", - " orientation0 = np.average(ptm['orientation'][pred==index[0], :], axis=0)\n", - " orientation1 = np.average(ptm['orientation'][pred==index[1], :], axis=0)\n", "\n", - " angle_difference = 2 * np.arccos(np.dot(orientation0, np.conj(orientation1)))*180/np.pi\n", + " orientation0 = np.average(ptm[\"orientation\"][pred == index[0], :], axis=0)\n", + " orientation1 = np.average(ptm[\"orientation\"][pred == index[1], :], axis=0)\n", + "\n", + " angle_difference = (\n", + " 2 * np.arccos(np.dot(orientation0, np.conj(orientation1))) * 180 / np.pi\n", + " )\n", "\n", " return angle_difference, E_gb" ] @@ -934,11 +932,11 @@ }, "outputs": [], "source": [ - "filelist = list(dir_path.glob('**/*.xyz'))\n", + "filelist = list(dir_path.glob(\"**/*.xyz\"))\n", "result = np.zeros((len(filelist), 2))\n", "for i, file in enumerate(filelist):\n", - " result[i,:] = np.array(diff_angle(file))\n", - " print('{:6.3f} deg {:7.4f} J/m^2 {}'.format(result[i,0], result[i,1], file.name))" + " result[i, :] = np.array(diff_angle(file))\n", + " print(f\"{result[i,0]:6.3f} deg {result[i,1]:7.4f} J/m^2 {file.name}\")" ] }, { @@ -956,9 +954,9 @@ }, "outputs": [], "source": [ - "result_sorted = result[result[:,0].argsort(),:]\n", - "fig, ax = plt.subplots(figsize=(10,6))\n", - "ax.plot(result_sorted[:,0], result_sorted[:,1], 'o-')\n", + "result_sorted = result[result[:, 0].argsort(), :]\n", + "fig, ax = plt.subplots(figsize=(10, 6))\n", + "ax.plot(result_sorted[:, 0], result_sorted[:, 1], \"o-\")\n", "plt.show()" ] }, diff --git a/tutorials/scripts/Preprocess.py b/tutorials/scripts/Preprocess.py index 49a315d8..3d702b95 100644 --- a/tutorials/scripts/Preprocess.py +++ b/tutorials/scripts/Preprocess.py @@ -1,17 +1,14 @@ -from pathlib import Path -from pprint import pprint import json +from pathlib import Path + from ase.io import read, write -from ase.geometry import crystal_structure_from_cell -import numpy as np # import numpy.linalg as la - import matplotlib.pyplot as plt -from scipy.interpolate import interp1d +import numpy as np -class Calculation(object): +class Calculation: def __init__(self, *args, **kwargs): self.filepath = kwargs.pop("filepath", None) self.parameters = kwargs @@ -64,7 +61,6 @@ def from_path(cls, path: Path): if __name__ == "__main__": - # Read grain boundary database dirpath = Path("../GB_alphaFe_001") @@ -88,7 +84,6 @@ def from_path(cls, path: Path): angles, energies = [], [] for calc in sorted(calculations["tilt"], key=lambda item: item.parameters["angle"]): - # E_gb = calc.parameters.get('E_gb', None) # # if E_gb is None: diff --git a/tutorials/scripts/Reader.py b/tutorials/scripts/Reader.py index 83d6be52..64b32e65 100644 --- a/tutorials/scripts/Reader.py +++ b/tutorials/scripts/Reader.py @@ -1,17 +1,14 @@ -from pathlib import Path -from pprint import pprint import json -from ase.io import read, write -from ase.geometry import crystal_structure_from_cell -import numpy as np +from pathlib import Path -# import numpy.linalg as la +from ase.io import read +# import numpy.linalg as la import matplotlib.pyplot as plt -from scipy.interpolate import interp1d +import numpy as np -class Calculation(object): +class Calculation: def __init__(self, *args, **kwargs): self.filepath = kwargs.pop("filepath", None) self.parameters = kwargs @@ -64,7 +61,6 @@ def from_path(cls, path: Path, index=-1): if __name__ == "__main__": - # Read grain boundary database dirpath = Path("../GB_alphaFe_001") diff --git a/tutorials/scripts/Visualise.py b/tutorials/scripts/Visualise.py index 308e76fb..6a4b3ac7 100644 --- a/tutorials/scripts/Visualise.py +++ b/tutorials/scripts/Visualise.py @@ -1,10 +1,8 @@ -import io import uuid -from nglview import register_backend, Structure -from ipywidgets import Dropdown, FloatSlider, IntSlider, HBox, VBox, Output - +from ipywidgets import Dropdown, FloatSlider, Output, VBox import matplotlib.pyplot as plt +from nglview import Structure, register_backend import numpy as np @@ -69,7 +67,7 @@ def ViewStructure(atoms): return view -class AtomViewer(object): +class AtomViewer: def __init__(self, atoms, data=[], xsize=1000, ysize=500): self.view = self._init_nglview(atoms, data, xsize, ysize) @@ -114,7 +112,7 @@ def _init_nglview(atoms, data, xsize, ysize): view._remote_call( "setSize", target="Widget", - args=["{:d}px".format(xsize), "{:d}px".format(ysize)], + args=[f"{xsize:d}px", f"{ysize:d}px"], ) data = np.max(data) - data diff --git a/tutorials/scripts/Visualise_quip.py b/tutorials/scripts/Visualise_quip.py index 308e76fb..6a4b3ac7 100644 --- a/tutorials/scripts/Visualise_quip.py +++ b/tutorials/scripts/Visualise_quip.py @@ -1,10 +1,8 @@ -import io import uuid -from nglview import register_backend, Structure -from ipywidgets import Dropdown, FloatSlider, IntSlider, HBox, VBox, Output - +from ipywidgets import Dropdown, FloatSlider, Output, VBox import matplotlib.pyplot as plt +from nglview import Structure, register_backend import numpy as np @@ -69,7 +67,7 @@ def ViewStructure(atoms): return view -class AtomViewer(object): +class AtomViewer: def __init__(self, atoms, data=[], xsize=1000, ysize=500): self.view = self._init_nglview(atoms, data, xsize, ysize) @@ -114,7 +112,7 @@ def _init_nglview(atoms, data, xsize, ysize): view._remote_call( "setSize", target="Widget", - args=["{:d}px".format(xsize), "{:d}px".format(ysize)], + args=[f"{xsize:d}px", f"{ysize:d}px"], ) data = np.max(data) - data diff --git a/tutorials/test_db.py b/tutorials/test_db.py index 8a8f34f8..87d6b27a 100644 --- a/tutorials/test_db.py +++ b/tutorials/test_db.py @@ -1,5 +1,6 @@ from pathlib import Path -from ase.io import iread, read + +from ase.io import read from abcd import ABCD diff --git a/tutorials/test_upload.py b/tutorials/test_upload.py index c5fe8e91..1717a591 100644 --- a/tutorials/test_upload.py +++ b/tutorials/test_upload.py @@ -1,6 +1,3 @@ -import numpy as np -from collections import Counter - from abcd import ABCD if __name__ == "__main__":