Skip to content

Commit

Permalink
Add low-level Python-C support for arbirary derived state
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Nov 5, 2024
1 parent 57594f4 commit feabb65
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
35 changes: 30 additions & 5 deletions _tsinfermodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ uint64_PyArray_converter(PyObject *in, PyObject **out)
return NPY_SUCCEED;
}

static int
int8_PyArray_converter(PyObject *in, PyObject **out)
{
PyObject *ret = PyArray_FROMANY(in, NPY_INT8, 1, 1, NPY_ARRAY_IN_ARRAY);
if (ret == NULL) {
return NPY_FAIL;
}
*out = ret;
return NPY_SUCCEED;
}

/*===================================================================
* AncestorBuilder
*===================================================================
Expand Down Expand Up @@ -429,30 +440,43 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw
{
int ret = -1;
int err;
static char *kwlist[] = {"num_alleles", "max_nodes", "max_edges", NULL};
static char *kwlist[] = {"num_alleles", "max_nodes", "max_edges", "derived_state",
NULL};
PyArrayObject *num_alleles = NULL;
PyArrayObject *derived_state = NULL;
int8_t *derived_state_data = NULL;
unsigned long max_nodes = 1024;
unsigned long max_edges = 1024;
unsigned long num_sites;
npy_intp *shape;
int flags = 0;

self->tree_sequence_builder = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|kk", kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|kkO&", kwlist,
uint64_PyArray_converter, &num_alleles,
&max_nodes, &max_edges)) {
&max_nodes, &max_edges,
int8_PyArray_converter, &derived_state)) {
goto out;
}
shape = PyArray_DIMS(num_alleles);
num_sites = shape[0];

if (derived_state != NULL) {
shape = PyArray_DIMS(derived_state);
if (shape[0] != (npy_intp) num_sites) {
PyErr_SetString(PyExc_ValueError, "derived state array wrong size");
goto out;
}
derived_state_data = PyArray_DATA(derived_state);
}
self->tree_sequence_builder = PyMem_Malloc(sizeof(tree_sequence_builder_t));
if (self->tree_sequence_builder == NULL) {
PyErr_NoMemory();
goto out;
}
err = tree_sequence_builder_alloc(self->tree_sequence_builder,
num_sites, PyArray_DATA(num_alleles),
num_sites,
PyArray_DATA(num_alleles),
derived_state_data,
max_nodes, max_edges, flags);
if (err != 0) {
handle_library_error(err);
Expand All @@ -461,6 +485,7 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw
ret = 0;
out:
Py_XDECREF(num_alleles);
Py_XDECREF(derived_state);
return ret;
}

Expand Down
21 changes: 18 additions & 3 deletions tests/test_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,31 @@ class TestTreeSequenceBuilder:
def test_init(self):
with pytest.raises(TypeError):
_tsinfer.TreeSequenceBuilder()
for bad_array in [None, "serf", [[], []], ["asdf"], {}]:
with pytest.raises(ValueError):
_tsinfer.TreeSequenceBuilder(bad_array)

for bad_type in [None, "sdf", {}]:
with pytest.raises(TypeError):
_tsinfer.TreeSequenceBuilder([2], max_nodes=bad_type)
with pytest.raises(TypeError):
_tsinfer.TreeSequenceBuilder([2], max_edges=bad_type)

def test_bad_num_alleles(self):
for bad_array in [None, "serf", [[], []], ["asdf"], {}]:
with pytest.raises(ValueError):
_tsinfer.TreeSequenceBuilder(bad_array)
with pytest.raises(_tsinfer.LibraryError, match="number of alleles"):
_tsinfer.TreeSequenceBuilder([1000])

def test_bad_derived_state(self):
for bad_array in [None, "serf", [[], []], ["asdf"], {}]:
with pytest.raises(ValueError):
_tsinfer.TreeSequenceBuilder([2], derived_state=bad_array)
with pytest.raises(_tsinfer.LibraryError, match="Bad derived state"):
for bad_derived_state in [-1, 2, 100]:
_tsinfer.TreeSequenceBuilder([2], derived_state=[bad_derived_state])

with pytest.raises(ValueError, match="derived state array wrong size"):
_tsinfer.TreeSequenceBuilder([2, 2, 2], derived_state=[])


class TestAncestorBuilder:
"""
Expand Down

0 comments on commit feabb65

Please sign in to comment.