Skip to content

Commit

Permalink
Check whether types that are passed are iterables (#4)
Browse files Browse the repository at this point in the history
* Fixes #2 and generates empty uns if asked for

* Add extra tests

* Add clarifying documentation
  • Loading branch information
LouiseDck authored Nov 7, 2024
1 parent ee148fe commit 3c1f276
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 22 deletions.
62 changes: 58 additions & 4 deletions src/dummy_anndata/generate_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import anndata as ad

from collections.abc import Iterable

from .generate_dataframe import generate_dataframe
from .generate_dict import generate_dict, scalar_generators
from .generate_matrix import matrix_generators
from .generate_vector import vector_generators
from .generate_dataframe import generate_dataframe
from .generate_dict import scalar_generators, generate_dict


def generate_dataset(
Expand All @@ -19,8 +21,56 @@ def generate_dataset(
varp_types=None,
uns_types=None,
):
"""
Generate a synthetic AnnData dataset with specified dimensions and data types.
Parameters:
-----------
n_obs : int, optional (default=10)
Number of observations (cells).
n_vars : int, optional (default=20)
Number of variables (genes).
x_type : str, optional (default="generate_integer_matrix")
Type of matrix to generate for the main data matrix `X`. Must be a key in `matrix_generators`.
layer_types : list of str, optional
Types of matrices to generate for layers. Each type must be a key in `matrix_generators`.
obs_types : list of str, optional
Types of vectors to generate for `obs`. Each type must be a key in `vector_generators`.
var_types : list of str, optional
Types of vectors to generate for `var`. Each type must be a key in `vector_generators`.
obsm_types : list of str, optional
Types of matrices or vectors to generate for `obsm`. Each type must be a key in `matrix_generators` or `vector_generators`.
varm_types : list of str, optional
Types of matrices or vectors to generate for `varm`. Each type must be a key in `matrix_generators` or `vector_generators`.
obsp_types : list of str, optional
Types of matrices to generate for `obsp`. Each type must be a key in `matrix_generators`.
varp_types : list of str, optional
Types of matrices to generate for `varp`. Each type must be a key in `matrix_generators`.
uns_types : list of str, optional
Types of data to generate for `uns`. Each type must be a key in `vector_generators`, `matrix_generators`, or `scalar_generators`.
Returns:
--------
ad.AnnData
An AnnData object containing the generated dataset with the specified dimensions and data types.
Raises:
-------
AssertionError
If any of the specified types are not recognized by the corresponding generator dictionaries.
"""

assert x_type in matrix_generators, f"Unknown matrix type: {x_type}"

check_iterable_types(layer_types, "layer_types")
check_iterable_types(obs_types, "obs_types")
check_iterable_types(var_types, "var_types")
check_iterable_types(obsm_types, "obsm_types")
check_iterable_types(varm_types, "varm_types")
check_iterable_types(obsp_types, "obsp_types")
check_iterable_types(varp_types, "varp_types")
check_iterable_types(uns_types, "uns_types")

assert layer_types is None or all(
t in matrix_generators.keys() for t in layer_types
), "Unknown layer type"
Expand Down Expand Up @@ -55,11 +105,11 @@ def generate_dataset(
if obsm_types is None: # obsm_types are all matrices or vectors, except for categoricals and nullables
vector_not_allowed = set(["categorical", "categorical_ordered", "categorical_missing_values", "categorical_ordered_missing_values", \
"nullable_integer_array", "nullable_boolean_array"])
obsm_types = set(matrix_generators.keys()) - vector_not_allowed
obsm_types = set(matrix_generators.keys()) - vector_not_allowed
if varm_types is None: # varm_types are all matrices or vectors, except for categoricals and nullables
vector_not_allowed = set(["categorical", "categorical_ordered", "categorical_missing_values", "categorical_ordered_missing_values", \
"nullable_integer_array", "nullable_boolean_array"])
varm_types = set(matrix_generators.keys()) - vector_not_allowed
varm_types = set(matrix_generators.keys()) - vector_not_allowed
if obsp_types is None: # obsp_types are all matrices
obsp_types = list(matrix_generators.keys())
if varp_types is None: # varp_types are all matrices
Expand Down Expand Up @@ -112,3 +162,7 @@ def generate_dataset(
varp=varp,
uns=uns,
)


def check_iterable_types(iterable_types, name):
assert iterable_types is None or (isinstance(iterable_types, Iterable) and not isinstance(iterable_types, str)), f"{name} should be a non-string iterable type"
27 changes: 11 additions & 16 deletions src/dummy_anndata/generate_dict.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .generate_vector import vector_generators
from .generate_matrix import matrix_generators

import pandas as pd
import numpy as np

from .generate_matrix import matrix_generators
from .generate_vector import vector_generators

scalar_generators = {
"string": "version",
"char": "a",
Expand Down Expand Up @@ -34,17 +33,13 @@ def generate_type(type, n_rows, n_cols):

def generate_dict(n_rows, n_cols, types=None, nested=True):
if types is None: # types are all vectors and all matrices
scalar_types = list(scalar_generators.keys()) + [
f"scalar_{t}" for t in vector_generators.keys()
]
types = (
scalar_types
+ list(vector_generators.keys())
+ list(matrix_generators.keys())
)

data = {t: generate_type(t, n_rows, n_cols) for t in types}
if nested:
data["nested"] = generate_dict(n_rows, n_cols, types, False)
scalar_types = list(scalar_generators.keys()) + [f"scalar_{t}" for t in vector_generators.keys()]
types = scalar_types + list(vector_generators.keys()) + list(matrix_generators.keys())

data = {}
if types: # types is not empty
data = {t: generate_type(t, n_rows, n_cols) for t in types}
if nested:
data["nested"] = generate_dict(n_rows, n_cols, types, False)

return data
8 changes: 6 additions & 2 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

import dummy_anndata


Expand All @@ -13,3 +11,9 @@ def test_generating_dataset(tmp_path):
dummy = dummy_anndata.generate_dataset()
filename = tmp_path / "dummy.h5ad"
dummy.write_h5ad(filename)


def test_empty_uns():
dummy = dummy_anndata.generate_dataset(uns_types=[])

assert dummy.uns == {}

0 comments on commit 3c1f276

Please sign in to comment.