diff --git a/proteinshake/adapter.py b/proteinshake/adapter.py index e69de29b..0a076592 100644 --- a/proteinshake/adapter.py +++ b/proteinshake/adapter.py @@ -0,0 +1,6 @@ +class Adapter: + """ + Downloads raw pdb files and/or meta data from a source and formats it to the shake database schema. + """ + + pass diff --git a/proteinshake/database.py b/proteinshake/database.py index 0021b8e9..4ff4ea5d 100644 --- a/proteinshake/database.py +++ b/proteinshake/database.py @@ -1,13 +1,16 @@ from pathlib import Path -from .collection import Collection class Database: + """ + Spins up a redis database + """ + def __init__(self, storage: Path) -> None: pass def update(self) -> None: pass - def query(self, query: str) -> Collection: + def query(self, query: str): pass diff --git a/proteinshake/framework.py b/proteinshake/framework.py index 75149a3a..4e583dc9 100644 --- a/proteinshake/framework.py +++ b/proteinshake/framework.py @@ -1,3 +1,10 @@ class Framework: + """ + Abstract class for a framework. Used as Mixin with a Transform. + """ + def create_loader(self, iterator): - pass + """ + Creates a framework-specific dataloader from an iterator. + """ + raise NotImplementedError diff --git a/proteinshake/metric.py b/proteinshake/metric.py index c5409bbe..c6898f9d 100644 --- a/proteinshake/metric.py +++ b/proteinshake/metric.py @@ -1,4 +1,6 @@ class Metric: - """For a collection of predictions and target values, return set of performance metrics.,""" + """ + Computes a set of relevant metrics for a task. + """ pass diff --git a/proteinshake/metrics/__init__.py b/proteinshake/metrics/__init__.py index a58168bf..2f3dc291 100644 --- a/proteinshake/metrics/__init__.py +++ b/proteinshake/metrics/__init__.py @@ -1 +1 @@ -from .evaluator import * +from .dummy_metric import * diff --git a/proteinshake/metrics/classification.py b/proteinshake/metrics/classification.py deleted file mode 100644 index e6bcb918..00000000 --- a/proteinshake/metrics/classification.py +++ /dev/null @@ -1,4 +0,0 @@ -class ClassificationEvaluator(Evaluator): - def __call__(self, pred : list, truth: list): - return {'accuracy': sklearn.accuracy(pred, truth)} - pass diff --git a/proteinshake/metrics/dummy_metric.py b/proteinshake/metrics/dummy_metric.py new file mode 100644 index 00000000..cab21a9a --- /dev/null +++ b/proteinshake/metrics/dummy_metric.py @@ -0,0 +1,7 @@ +from proteinshake.metric import Metric +import numpy as np + + +class DummyMetric(Metric): + def __call__(self, y_true, y_pred): + return {"Accuracy": np.random.random()} diff --git a/proteinshake/metrics/regression.py b/proteinshake/metrics/regression.py deleted file mode 100644 index e69de29b..00000000 diff --git a/proteinshake/metrics/retrieval.py b/proteinshake/metrics/retrieval.py deleted file mode 100644 index e69de29b..00000000 diff --git a/proteinshake/representation.py b/proteinshake/representation.py index fa00571c..88bebe69 100644 --- a/proteinshake/representation.py +++ b/proteinshake/representation.py @@ -1,2 +1,6 @@ class Representation: + """ + Abstract class for a representation. Used as Mixin with a Transform. + """ + pass diff --git a/proteinshake/split.py b/proteinshake/split.py index faf6151d..8716c5af 100644 --- a/proteinshake/split.py +++ b/proteinshake/split.py @@ -1,8 +1,17 @@ +from typing import Dict, Iterator + + class Split: """ - Abstract class for selecting train/val/test indices given a dataset. + Abstract class to create data splits from a dataset. """ + def __call__(self, dataset: Iterator) -> Dict[str, Iterator]: + """ + Takes an Xy iterator and returns a dictionary of Xy iterators, where each key denotes the split name (usually 'train', 'test', and 'val'). + """ + raise NotImplementedError + @property def hash(self): return self.__class__.__name__ diff --git a/proteinshake/splits/__init__.py b/proteinshake/splits/__init__.py index f711965d..aef80159 100644 --- a/proteinshake/splits/__init__.py +++ b/proteinshake/splits/__init__.py @@ -1 +1 @@ -from .splitter import * +from .dummy_split import * diff --git a/proteinshake/splits/attribute.py b/proteinshake/splits/attribute.py deleted file mode 100644 index 7e729b58..00000000 --- a/proteinshake/splits/attribute.py +++ /dev/null @@ -1,14 +0,0 @@ -class AttributeSplitter(Splitter): - """ - Compute splits based on an attribute that already exists in the dataset - """ - - def __init__( - self, train_attribute: str, val_attribute: str, test_attribute: str - ) -> None: - self.train_attribute = train_attribute - self.val_attribute = val_attribute - self.test_attribute = test_attribute - - def __call__(self, dataset) -> tuple[list, list, list]: - pass diff --git a/proteinshake/splits/dummy_split.py b/proteinshake/splits/dummy_split.py new file mode 100644 index 00000000..d3a5d5ab --- /dev/null +++ b/proteinshake/splits/dummy_split.py @@ -0,0 +1,13 @@ +from proteinshake.split import Split +import itertools + + +class DummySplit(Split): + def __call__(self, Xy): + train, testval = itertools.tee(Xy) + test, val = itertools.tee(testval) + return { + "train": filter(lambda Xy: Xy[0][0]["split"] == "train", train), + "test": filter(lambda Xy: Xy[0][0]["split"] == "test", test), + "val": filter(lambda Xy: Xy[0][0]["split"] == "val", val), + } diff --git a/proteinshake/splits/pairwise_attribute.py b/proteinshake/splits/pairwise_attribute.py deleted file mode 100644 index f90b9258..00000000 --- a/proteinshake/splits/pairwise_attribute.py +++ /dev/null @@ -1,17 +0,0 @@ -class PairwiseAttributeSplitter(Splitter): - """Compute pairwise splits based on an attribute that already exists in the dataset. - Takes all pairs of train/val/test in the single attribute splitting setting.""" - - def __init__( - self, train_attribute: str, val_attribute: str, test_attribute: str - ) -> None: - self.train_attribute = train_attribute - self.val_attribute = val_attribute - self.test_attribute = test_attribute - - def __call__(self, dataset) -> tuple[list, list, list]: - tmp_splitter = AttributeSplitter( - self.train_attribute, self.val_attribute, self.test_attribute - ) - # compute pairs of indices on the non-paired splits - pass diff --git a/proteinshake/splits/random.py b/proteinshake/splits/random.py deleted file mode 100644 index 6f72bdc2..00000000 --- a/proteinshake/splits/random.py +++ /dev/null @@ -1,20 +0,0 @@ -class MySplitter: - def __init__(self, seed=None) -> None: - self.rng = np.random.rng(seed) - - def fit(self, dataset): - n = len(dataset) - train, test_val = train_test_split( - np.arange(n), test_size=0.2, random_state=self.rng.random() - ) - test, val = train_test_split( - test_val, test_size=0.5, random_state=self.rng.random() - ) - self.lookup = { - **{index: "train" for index in train}, - **{index: "test" for index in test}, - **{index: "val" for index in val}, - } - - def assign(self, index, protein): - return self.lookup[index] diff --git a/proteinshake/target.py b/proteinshake/target.py index dcd5fe53..922cdb41 100644 --- a/proteinshake/target.py +++ b/proteinshake/target.py @@ -1,6 +1,13 @@ +from typing import Dict, Iterator + + class Target: - """Returns the attribute to predict for a single instance, given arbitrary inputs. - Different tasks will have target computations on different types and numbers of entitites. + """ + Abstract class for reshaping a dataset into the correct data-target structure for a task. """ - pass + def __call__(self, dataset: Iterator[dict]) -> Dict[str, Iterator]: + """ + Takes a dataset iterator and returns an Xy iterator, whose elements are ((X1,X2,...), y) pairs of data tuples and targets. + """ + raise NotImplementedError diff --git a/proteinshake/targets/__init__.py b/proteinshake/targets/__init__.py index ba202df3..d343d6c3 100644 --- a/proteinshake/targets/__init__.py +++ b/proteinshake/targets/__init__.py @@ -1 +1 @@ -from .target import Target +from .attribute_target import AttributeTarget diff --git a/proteinshake/targets/attribute_target.py b/proteinshake/targets/attribute_target.py new file mode 100644 index 00000000..2acaf842 --- /dev/null +++ b/proteinshake/targets/attribute_target.py @@ -0,0 +1,10 @@ +from proteinshake.target import Target + + +class AttributeTarget(Target): + def __init__(self, attribute) -> None: + super().__init__() + self.attribute = attribute + + def __call__(self, dataset): + return (((p,), p[self.attribute]) for p in dataset) diff --git a/proteinshake/targets/pairwise_property_target.py b/proteinshake/targets/pairwise_property_target.py deleted file mode 100644 index f9bea254..00000000 --- a/proteinshake/targets/pairwise_property_target.py +++ /dev/null @@ -1,9 +0,0 @@ -class PairwisePropertyTarget(Target): - def __init__(self, attribute: str, resolution: str, metadata: Any) -> None: - self.attribute = attribute - self.resolution = resolution - self.metadata = metadata - - def __call__(self, entity_1, entity_2) -> int: - return self.metadata[entity_1.ID][entity_2.ID] - pass diff --git a/proteinshake/targets/property_target.py b/proteinshake/targets/property_target.py deleted file mode 100644 index 1ab23afd..00000000 --- a/proteinshake/targets/property_target.py +++ /dev/null @@ -1,8 +0,0 @@ -class PropertyTarget(Target): - def __init__(self, attribute: str, resolution: str) -> None: - self.attribute = attribute - self.resolution = resolution - - def __call__(self, entity): - return entity[self.resolution][self.attribute] - pass diff --git a/proteinshake/task.py b/proteinshake/task.py index 6305d86b..78dd5786 100644 --- a/proteinshake/task.py +++ b/proteinshake/task.py @@ -6,7 +6,7 @@ from proteinshake.split import Split from proteinshake.target import Target from proteinshake.metric import Metric -from proteinshake.transform import Transform, Compose +from proteinshake.transform import Transform, Compose, IdentityTransform from proteinshake.util import amino_acid_alphabet, sharded, save_shards, load, warn @@ -15,7 +15,7 @@ class Task: split: Split = None target: Target = None metrics: Metric = None - augmentation: Transform = None + augmentation: Transform = IdentityTransform def __init__( self, @@ -46,6 +46,7 @@ def __init__( @property def proteins(self): # return dataset iterator + # this is a dummy for now. It will load a dataset from file in the future. rng = np.random.default_rng(42) return ( { @@ -62,7 +63,7 @@ def proteins(self): def transform(self, *transforms) -> None: Xy = self.target(self.proteins) - partitions = self.split(Xy) # returns dict of generators[(X,...),y] + partitions = self.split(Xy) self.transform = Compose(*[self.augmentation, *transforms]) # cache from here self.transform.fit(partitions["train"]) @@ -73,7 +74,7 @@ def transform(self, *transforms) -> None: ) save_shards( data_transformed, - self.root / self.split.hash / self.transform.hash / "shards", + self.root / self.split.hash / name / self.transform.hash / "shards", ) setattr(self, f"{name}_loader", partial(self.loader, split=name)) return self @@ -87,7 +88,7 @@ def loader( **kwargs, ): rng = np.random.default_rng(random_seed) - path = self.root / self.split.hash / self.transform.hash / "shards" + path = self.root / self.split.hash / split / self.transform.hash / "shards" shard_index = load(path / "index.npy") if self.shard_size % batch_size != 0 and batch_size % self.shard_size != 0: warn( diff --git a/proteinshake/tasks/__init__.py b/proteinshake/tasks/__init__.py new file mode 100644 index 00000000..99f6d4b4 --- /dev/null +++ b/proteinshake/tasks/__init__.py @@ -0,0 +1 @@ +from .dummy_task import DummyTask diff --git a/proteinshake/tasks/dummy_task.py b/proteinshake/tasks/dummy_task.py new file mode 100644 index 00000000..a3066791 --- /dev/null +++ b/proteinshake/tasks/dummy_task.py @@ -0,0 +1,11 @@ +from proteinshake.task import Task +from proteinshake.metrics import DummyMetric +from proteinshake.targets import AttributeTarget +from proteinshake.splits import DummySplit + + +class DummyTask(Task): + dataset = "test" + split = DummySplit + target = AttributeTarget + metrics = DummyMetric diff --git a/proteinshake/tasks/gene_ontology_classification.py b/proteinshake/tasks/gene_ontology_classification.py deleted file mode 100644 index 8b137891..00000000 --- a/proteinshake/tasks/gene_ontology_classification.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/proteinshake/transform.py b/proteinshake/transform.py index 5eb8cb55..8faa7865 100644 --- a/proteinshake/transform.py +++ b/proteinshake/transform.py @@ -3,6 +3,10 @@ class BaseTransform: + """ + Abstract class for transforms. A transform can be stochastic or deterministic, which decides whether the transformed result can be precomputed and saved to disk (deterministic), or if it needs to be computed when retrieving a data item (stochastic). Transforms generally take a batch of Xy tuples, some subclasses exist that facilitate reshaping (see below). Transforms can be fit beforehand (on the 'train' partition). + """ + stochastic = False def __call__(self, Xy): @@ -55,7 +59,16 @@ def inverse_transform(self, y): return y +class IdentityTransform(BaseTransform): + def __call__(self, Xy): + return Xy + + class Compose: + """ + Composes multiple transforms into one object. Takes care of splitting the deterministic and stochastic part, as well as storing the framework create_dataloader method. + """ + def __init__(self, *transforms): self.transforms = transforms self.deterministic_transforms = [] diff --git a/tests/task.py b/tests/task.py index 403a2363..e180f7f1 100644 --- a/tests/task.py +++ b/tests/task.py @@ -1,48 +1,11 @@ import unittest -import numpy as np -import itertools -from proteinshake.metric import Metric -from proteinshake.target import Target -from proteinshake.split import Split -from proteinshake.task import Task +from proteinshake.tasks import DummyTask from proteinshake.transform import * from proteinshake.transforms import * class TestTask(unittest.TestCase): def test_task(self): - # CONTRIBUTOR - class MyTarget(Target): - def __call__(self, dataset): - return (((p,), p["label"]) for p in dataset) - - class MyMetric(Metric): - def __call__(self, y_true, y_pred): - return {"Accuracy": np.random.random()} - - class MySplit(Split): - def __call__(self, Xy): - # this implementation looks a bit inefficient - train, testval = itertools.tee(Xy) - test, val = itertools.tee(testval) - return { - "train": filter(lambda Xy: Xy[0][0]["split"] == "train", train), - "test": filter(lambda Xy: Xy[0][0]["split"] == "test", test), - "val": filter(lambda Xy: Xy[0][0]["split"] == "val", val), - } - - class MyAugmentation(Transform): - def transform(self, X): - return X - - class MyTask(Task): - dataset = "test" - split = MySplit - target = MyTarget - metrics = MyMetric - augmentation = MyAugmentation - - # END USER class MyLabelTransform(LabelTransform): def transform(self, y): return -y @@ -50,14 +13,14 @@ def transform(self, y): def inverse_transform(self, y): return -y - task = MyTask(shard_size=8).transform( + task = DummyTask(target_kwargs={"attribute": "label"}, shard_size=8).transform( MyLabelTransform(), PointRepresentationTransform(), TorchFrameworkTransform(), ) for epoch in range(5): - for X, y in task.train_loader(batch_size=64, random_seed=0): + for X, y in task.train_loader(batch_size=16, shuffle=True, random_seed=0): print("X", X.shape) print("y", y.shape) break