Skip to content
This repository has been archived by the owner on Feb 26, 2025. It is now read-only.

Commit

Permalink
Provide divergence and convergence stats for EdgePopulation (#242)
Browse files Browse the repository at this point in the history
* provide divergence and convergence stats for EdgePopulation
  • Loading branch information
joni-herttuainen authored Feb 28, 2024
1 parent 270e597 commit d2d61c4
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Improvements

- Both now return ``self.ids(query)`` if ``properties=None``
- ``properties`` is now a keyword argument in ``EdgePopulation.get``
- Added ``EdgePopulation.stats`` with two methods: ``divergence``, ``convergence``


Version v3.0.1
Expand Down
6 changes: 6 additions & 0 deletions bluepysnap/edges/edge_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from bluepysnap import query, utils
from bluepysnap.circuit_ids import CircuitEdgeIds, CircuitNodeId
from bluepysnap.circuit_ids_types import IDS_DTYPE, CircuitEdgeId
from bluepysnap.edges.edge_population_stats import StatsHelper
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.sonata_constants import DYNAMICS_PREFIX, ConstContainer, Edge

Expand Down Expand Up @@ -147,6 +148,11 @@ def property_dtypes(self):
"""
return self.get([0], list(self.property_names)).dtypes.sort_index()

@cached_property
def stats(self):
"""Access edge population stats methods."""
return StatsHelper(self)

def container_property_names(self, container):
"""Lists the ConstContainer properties shared with the EdgePopulation.
Expand Down
91 changes: 91 additions & 0 deletions bluepysnap/edges/edge_population_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""EdgePopulation stats helper."""

import numpy as np

from bluepysnap.exceptions import BluepySnapError


class StatsHelper:
"""EdgePopulation stats helper."""

def __init__(self, edge_population):
"""Initialize StatsHelper with an EdgePopulation instance."""
self._edge_population = edge_population

def divergence(self, source, target, by, sample=None):
"""`source` -> `target` divergence.
Calculate the divergence based on number of `"connections"` or `"synapses"` each `source`
cell shares with the cells specified in `target`.
* `connections`: number of unique target cells each source cell shares a connection with
* `synapses`: number of unique synapses between a source cell and its target cells
Args:
source (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): source nodes
target (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): target nodes
by (str): 'synapses' or 'connections'
sample (int): if specified, sample size for source group
Returns:
Array with synapse / connection count per each cell from `source` sample
(taking into account only connections to cells in `target`).
"""
by_alternatives = {"synapses", "connections"}
if by not in by_alternatives:
raise BluepySnapError(f"`by` should be one of {by_alternatives}; got: {by}")

source_sample = self._edge_population.source.ids(source, sample=sample)

result = {id_: 0 for id_ in source_sample}
if by == "synapses":
connections = self._edge_population.iter_connections(
source_sample, target, return_synapse_count=True
)
for pre_gid, _, synapse_count in connections:
result[pre_gid] += synapse_count
else:
connections = self._edge_population.iter_connections(source_sample, target)
for pre_gid, _ in connections:
result[pre_gid] += 1

return np.array(list(result.values()))

def convergence(self, source, target, by=None, sample=None):
"""`source` -> `target` convergence.
Calculate the convergence based on number of `"connections"` or `"synapses"` each `target`
cell shares with the cells specified in `source`.
* `connections`: number of unique source cells each target cell shares a connection with
* `synapses`: number of unique synapses between a target cell and its source cells
Args:
source (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): source nodes
target (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): target nodes
by (str): 'synapses' or 'connections'
sample (int): if specified, sample size for target group
Returns:
Array with synapse / connection count per each cell from `target` sample
(taking into account only connections from cells in `source`).
"""
by_alternatives = {"synapses", "connections"}
if by not in by_alternatives:
raise BluepySnapError(f"`by` should be one of {by_alternatives}; got: {by}")

target_sample = self._edge_population.target.ids(target, sample=sample)

result = {id_: 0 for id_ in target_sample}
if by == "synapses":
connections = self._edge_population.iter_connections(
source, target_sample, return_synapse_count=True
)
for _, post_gid, synapse_count in connections:
result[post_gid] += synapse_count
else:
connections = self._edge_population.iter_connections(source, target_sample)
for _, post_gid in connections:
result[post_gid] += 1

return np.array(list(result.values()))
2 changes: 2 additions & 0 deletions tests/test_edge_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from bluepysnap.circuit import Circuit
from bluepysnap.circuit_ids import CircuitEdgeIds, CircuitNodeIds
from bluepysnap.circuit_ids_types import IDS_DTYPE, CircuitEdgeId, CircuitNodeId
from bluepysnap.edges.edge_population_stats import StatsHelper
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.sonata_constants import DEFAULT_EDGE_TYPE, Edge

Expand Down Expand Up @@ -41,6 +42,7 @@ def test_basic(self):
assert self.test_obj.source.name == "default"
assert self.test_obj.target.name == "default"
assert self.test_obj.size, 4
assert isinstance(self.test_obj.stats, StatsHelper)
assert sorted(self.test_obj.property_names) == sorted(
[
Synapse.SOURCE_NODE_ID,
Expand Down
43 changes: 43 additions & 0 deletions tests/test_edge_population_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from unittest.mock import Mock

import numpy.testing as npt
import pytest

import bluepysnap.edges.edge_population_stats as test_module
from bluepysnap.exceptions import BluepySnapError


class TestStatsHelper:
def setup_method(self):
self.edge_pop = Mock()
self.stats = test_module.StatsHelper(self.edge_pop)

def test_divergence_by_synapses(self):
self.edge_pop.source.ids.return_value = [1, 2]
self.edge_pop.iter_connections.return_value = [(1, None, 42), (1, None, 43)]
actual = self.stats.divergence("pre", "post", by="synapses")
npt.assert_equal(actual, [85, 0])

def test_divergence_by_connections(self):
self.edge_pop.source.ids.return_value = [1, 2]
self.edge_pop.iter_connections.return_value = [(1, None), (1, None)]
actual = self.stats.divergence("pre", "post", by="connections")
npt.assert_equal(actual, [2, 0])

def test_divergence_error(self):
pytest.raises(BluepySnapError, self.stats.divergence, "pre", "post", by="err")

def test_convergence_by_synapses(self):
self.edge_pop.target.ids.return_value = [1, 2]
self.edge_pop.iter_connections.return_value = [(None, 2, 42), (None, 2, 43)]
actual = self.stats.convergence("pre", "post", by="synapses")
npt.assert_equal(actual, [0, 85])

def test_convergence_by_connections(self):
self.edge_pop.target.ids.return_value = [1, 2]
self.edge_pop.iter_connections.return_value = [(None, 2), (None, 2)]
actual = self.stats.convergence("pre", "post", by="connections")
npt.assert_equal(actual, [0, 2])

def test_convergence_error(self):
pytest.raises(BluepySnapError, self.stats.convergence, "pre", "post", by="err")

0 comments on commit d2d61c4

Please sign in to comment.