Commit

Refactor backend code
ElliottKasoar committed Jan 24, 2024
1 parent edb4ba8 commit 6d37e03
Showing 3 changed files with 177 additions and 324 deletions.
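For the file shown below, the refactor renames the `db` constructor argument of `OpenSearchDatabase` to `db_name`, switches logging calls from `str.format` to lazy %-style arguments, replaces `type(x) == list` comparisons with `isinstance`, splits derived labels with `maxsplit=1`, and moves the histogram helpers out of this module into `abcd.backends.utils`. As a rough sketch of how the backend is constructed after the rename (the connection details are illustrative assumptions and a running OpenSearch instance is required; only the `db_name` keyword and the `info`/`count` methods are taken from the diff):

```python
# Illustrative sketch only: host, port and credentials are assumed values,
# not part of the commit. It exercises the constructor argument renamed
# from `db` to `db_name` in this refactor.
from abcd.backends.atoms_opensearch import OpenSearchDatabase

db = OpenSearchDatabase(
    host="localhost",    # assumed local OpenSearch instance
    port=9200,
    db_name="abcd",      # renamed from `db` in this commit
    index_name="atoms",
    username="admin",
    password="admin",
)

print(db.info())   # connection and index summary
print(db.count())  # document count using the parser's default query
```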
212 changes: 35 additions & 177 deletions abcd/backends/atoms_opensearch.py
@@ -1,30 +1,25 @@
from __future__ import annotations

import logging

from collections.abc import Generator
from datetime import datetime
from typing import Union, Iterable
import logging
from os import linesep
from datetime import datetime
from collections import Counter
from operator import itemgetter
from pathlib import Path

import numpy as np

from ase import Atoms
from ase.io import iread
from luqum.parser import parser
from luqum.elasticsearch import SchemaAnalyzer, ElasticsearchQueryBuilder
from opensearchpy import OpenSearch, helpers, AuthenticationException, ConnectionTimeout

from abcd.backends import utils
from abcd.database import AbstractABCD
import abcd.errors
from abcd.model import AbstractModel
from abcd.database import AbstractABCD
from abcd.queryset import AbstractQuerySet
from abcd.parsers import extras
from abcd.queryset import AbstractQuerySet

from opensearchpy import OpenSearch, helpers, AuthenticationException, ConnectionTimeout

from luqum.parser import parser
from luqum.elasticsearch import SchemaAnalyzer, ElasticsearchQueryBuilder

logger = logging.getLogger(__name__)

@@ -94,20 +89,19 @@ def __call__(self, query: Union[dict, str, list, None]) -> Union[dict, None]:
Union[dict, None]
The parsed query for OpenSearch.
"""
logger.info("parsed query: {}".format(query))

if not query:
query = self.get_default_query()
logger.info("parsed query: %s", query)

if isinstance(query, dict):
return query
elif isinstance(query, str):
if isinstance(query, str):
tree = parser.parse(query)
return self.query_builder(tree)
elif isinstance(query, list):
if isinstance(query, list):
if len(query) == 0:
return None
elif query[0] is None:
if query[0] is None:
return None
separator = " AND "
joined_query = separator.join(query)
@@ -244,7 +238,7 @@ class OpenSearchDatabase(AbstractABCD):
----------
client: OpenSearch
OpenSearch client.
db: str
db_name: str
Database name.
index_name: str
OpenSearch index name.
@@ -256,7 +250,7 @@ def __init__(
self,
host: str = "localhost",
port: int = 9200,
db: str = "abcd",
db_name: str = "abcd",
index_name: str = "atoms",
username: str = "admin",
password: str = "admin",
@@ -272,7 +266,7 @@ def __init__(
Name of OpenSearch host. Default is `localhost`.
port: int, optional
OpenSearch port. Default is `9200`.
db: str, optional
db_name: str, optional
Label for OpenSearch database. Used only when printing information.
Default is `abcd`.
index_name: str, optional
@@ -309,15 +303,15 @@ def __init__(

try:
info = self.client.info()
logger.info("DB info: {}".format(info))
logger.info("DB info: %s", info)

except AuthenticationException:
raise abcd.errors.AuthenticationError()

except ConnectionTimeout:
raise abcd.errors.TimeoutError()

self.db = db
self.db = db_name
self.index_name = index_name
self.create()
self.parser = OpenSearchQuery(self.client, self.index_name, analyse_schema)
@@ -558,8 +552,8 @@ def count(self, query: Union[dict, str, None] = None) -> int:
-------
Count of number of documents.
"""
logger.info("query; {}".format(query))
query = self.parser(query)
logger.info("parsed query: %s", query)
body = {
"query": query,
}
@@ -686,7 +680,7 @@ def properties(self, query: Union[dict, str, None] = None) -> dict:
for label in derived:
count = res["aggregations"][label]["doc_count"]
if count > 0:
key = label.split("_")[0]
key = label.split("_", maxsplit=1)[0]
if key in properties:
properties[key].append(prop)
else:
@@ -710,8 +704,6 @@ def get_type_of_property(self, prop: str, category: str) -> str:
-------
Type of the property.
"""
# TODO: Probably it would be nicer to store the type info in the database
# from the beginning.
atoms = self.client.search(
index=self.index_name,
body={"size": 1, "query": {"exists": {"field": prop}}},
@@ -720,23 +712,19 @@
data = atoms["hits"]["hits"][0]["_source"][prop]

if category == "arrays":
if type(data[0]) == list:
if isinstance(data[0], list):
return "array({}, N x {})".format(
map_types[type(data[0][0])], len(data[0])
)
else:
return "vector({}, N)".format(map_types[type(data[0])])
return "vector({}, N)".format(map_types[type(data[0])])

if type(data) == list:
if type(data[0]) == list:
if type(data[0][0]) == list:
if isinstance(data, list):
if isinstance(data[0], list):
if isinstance(data[0][0], list):
return "list(list(...)"
else:
return "array({})".format(map_types[type(data[0][0])])
else:
return "vector({})".format(map_types[type(data[0])])
else:
return "scalar({})".format(map_types[type(data)])
return "array({})".format(map_types[type(data[0][0])])
return "vector({})".format(map_types[type(data[0])])
return "scalar({})".format(map_types[type(data)])

def count_properties(self, query: Union[dict, str, None] = None) -> dict:
"""
@@ -790,8 +778,10 @@ def count_properties(self, query: Union[dict, str, None] = None) -> dict:
if count > 0:
properties[key] = {
"count": count,
"category": label.split("_")[0],
"dtype": self.get_type_of_property(key, label.split("_")[0]),
"category": label.split("_", maxsplit=1)[0],
"dtype": self.get_type_of_property(
key, label.split("_", maxsplit=1)[0]
),
}

return properties
@@ -807,8 +797,8 @@ def add_property(self, data: dict, query: Union[dict, str, None] = None):
query: Union[dict, str, None]
Query to filter documents to add properties to. Default is `None`.
"""
logger.info("add: data={}, query={}".format(data, query))
query = self.parser(query)
logger.info("add: data=%s, query=%s", data, query)

script_txt = "ctx._source.derived.info_keys.addAll(params.keys);"
for key, val in data.items():
@@ -843,8 +833,8 @@ def rename_property(
query: Union[dict, str, None]
Query to filter documents to rename property. Default is `None`.
"""
logger.info("rename: query={}, old={}, new={}".format(query, name, new_name))
query = self.parser(query)
logger.info("rename: query=%s, old=%s, new=%s", query, name, new_name)

script_txt = f"if (!ctx._source.containsKey('{new_name}')) {{ "
script_txt += (
@@ -877,8 +867,8 @@ def delete_property(self, name: str, query: Union[dict, str, None] = None):
query: Union[dict, str, None]
Query to filter documents to have property deleted. Default is `None`.
"""
logger.info("delete: query={}, porperty={}".format(name, query))
query = self.parser(query)
logger.info("delete: query=%s, porperty=%s", name, query)

script_txt = f"if (ctx._source.containsKey('{name}')) {{ "
script_txt += "ctx._source.remove(params.name);"
@@ -920,7 +910,7 @@ def hist(
query = self.parser(query)

data = self.property(name, query)
return histogram(name, data, **kwargs)
return utils.histogram(name, data, **kwargs)

def __repr__(self):
"""
@@ -974,138 +964,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
pass


def histogram(name, data, **kwargs):
if not data:
return None

elif data and isinstance(data, list):
ptype = type(data[0])

if not all(isinstance(x, ptype) for x in data):
print("Mixed type error of the {} property!".format(name))
return None

if ptype == float:
bins = kwargs.get("bins", 10)
return _hist_float(name, data, bins)

elif ptype == int:
bins = kwargs.get("bins", 10)
return _hist_int(name, data, bins)

elif ptype == str:
return _hist_str(name, data, **kwargs)

elif ptype == datetime:
bins = kwargs.get("bins", 10)
return _hist_date(name, data, bins)

else:
print(
"{}: Histogram for list of {} types are not supported!".format(
name, type(data[0])
)
)
logger.info(
"{}: Histogram for list of {} types are not supported!".format(
name, type(data[0])
)
)

else:
logger.info(
"{}: Histogram for {} types are not supported!".format(name, type(data))
)
return None


def _hist_float(name, data, bins=10):
data = np.array(data)
hist, bin_edges = np.histogram(data, bins=bins)

return {
"type": "hist_float",
"name": name,
"bins": bins,
"edges": bin_edges,
"counts": hist,
"min": data.min(),
"max": data.max(),
"median": data.mean(),
"std": data.std(),
"var": data.var(),
}


def _hist_date(name, data, bins=10):
hist_data = np.array([t.timestamp() for t in data])
hist, bin_edges = np.histogram(hist_data, bins=bins)

fromtimestamp = datetime.fromtimestamp

return {
"type": "hist_date",
"name": name,
"bins": bins,
"edges": [fromtimestamp(d) for d in bin_edges],
"counts": hist,
"min": fromtimestamp(hist_data.min()),
"max": fromtimestamp(hist_data.max()),
"median": fromtimestamp(hist_data.mean()),
"std": fromtimestamp(hist_data.std()),
"var": fromtimestamp(hist_data.var()),
}


def _hist_int(name, data, bins=10):
data = np.array(data)
delta = max(data) - min(data) + 1

if bins > delta:
bins = delta

hist, bin_edges = np.histogram(data, bins=bins)

return {
"type": "hist_int",
"name": name,
"bins": bins,
"edges": bin_edges,
"counts": hist,
"min": data.min(),
"max": data.max(),
"median": data.mean(),
"std": data.std(),
"var": data.var(),
}


def _hist_str(name, data, bins=10, truncate=20):
n_unique = len(set(data))

if truncate:
# data = (item[:truncate] for item in data)
data = (
item[:truncate] + "..." if len(item) > truncate else item for item in data
)

data = Counter(data)

if bins:
labels, counts = zip(*sorted(data.items(), key=itemgetter(1, 0), reverse=True))
else:
labels, counts = zip(*data.items())

return {
"type": "hist_str",
"name": name,
"total": sum(data.values()),
"unique": n_unique,
"labels": labels[:bins],
"counts": counts[:bins],
}


if __name__ == "__main__":
db = OpenSearchDatabase(username="admin", password="admin")
print(db.info())
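
The `histogram` and `_hist_*` helpers deleted at the end of this file now sit behind the new `from abcd.backends import utils` import, with `hist` delegating to `utils.histogram(name, data, **kwargs)`. A minimal sketch of calling it, assuming the relocated helper keeps the return shape of the removed `_hist_float` above and reusing the `db` object from the `__main__` block (the property name `energy` is illustrative only):

```python
# Sketch under assumptions: `energy` is a hypothetical property name, and the
# relocated utils.histogram is assumed to return the same dict layout as the
# _hist_float helper removed in this commit.
result = db.hist("energy", bins=20)

if result is not None and result.get("type") == "hist_float":
    # Pair each count with its left bin edge (edges has one extra entry).
    for edge, count in zip(result["edges"], result["counts"]):
        print(f"{edge:12.4f}  {count}")
    print("min:", result["min"], "max:", result["max"], "std:", result["std"])
```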