This repository has been archived by the owner on Feb 26, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Provide divergence and convergence stats for EdgePopulation (#242)
* provide divergence and convergence stats for EdgePopulation
- Loading branch information
1 parent
270e597
commit d2d61c4
Showing
5 changed files
with
143 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
"""EdgePopulation stats helper.""" | ||
|
||
import numpy as np | ||
|
||
from bluepysnap.exceptions import BluepySnapError | ||
|
||
|
||
class StatsHelper: | ||
"""EdgePopulation stats helper.""" | ||
|
||
def __init__(self, edge_population): | ||
"""Initialize StatsHelper with an EdgePopulation instance.""" | ||
self._edge_population = edge_population | ||
|
||
def divergence(self, source, target, by, sample=None): | ||
"""`source` -> `target` divergence. | ||
Calculate the divergence based on number of `"connections"` or `"synapses"` each `source` | ||
cell shares with the cells specified in `target`. | ||
* `connections`: number of unique target cells each source cell shares a connection with | ||
* `synapses`: number of unique synapses between a source cell and its target cells | ||
Args: | ||
source (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): source nodes | ||
target (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): target nodes | ||
by (str): 'synapses' or 'connections' | ||
sample (int): if specified, sample size for source group | ||
Returns: | ||
Array with synapse / connection count per each cell from `source` sample | ||
(taking into account only connections to cells in `target`). | ||
""" | ||
by_alternatives = {"synapses", "connections"} | ||
if by not in by_alternatives: | ||
raise BluepySnapError(f"`by` should be one of {by_alternatives}; got: {by}") | ||
|
||
source_sample = self._edge_population.source.ids(source, sample=sample) | ||
|
||
result = {id_: 0 for id_ in source_sample} | ||
if by == "synapses": | ||
connections = self._edge_population.iter_connections( | ||
source_sample, target, return_synapse_count=True | ||
) | ||
for pre_gid, _, synapse_count in connections: | ||
result[pre_gid] += synapse_count | ||
else: | ||
connections = self._edge_population.iter_connections(source_sample, target) | ||
for pre_gid, _ in connections: | ||
result[pre_gid] += 1 | ||
|
||
return np.array(list(result.values())) | ||
|
||
def convergence(self, source, target, by=None, sample=None): | ||
"""`source` -> `target` convergence. | ||
Calculate the convergence based on number of `"connections"` or `"synapses"` each `target` | ||
cell shares with the cells specified in `source`. | ||
* `connections`: number of unique source cells each target cell shares a connection with | ||
* `synapses`: number of unique synapses between a target cell and its source cells | ||
Args: | ||
source (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): source nodes | ||
target (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None): target nodes | ||
by (str): 'synapses' or 'connections' | ||
sample (int): if specified, sample size for target group | ||
Returns: | ||
Array with synapse / connection count per each cell from `target` sample | ||
(taking into account only connections from cells in `source`). | ||
""" | ||
by_alternatives = {"synapses", "connections"} | ||
if by not in by_alternatives: | ||
raise BluepySnapError(f"`by` should be one of {by_alternatives}; got: {by}") | ||
|
||
target_sample = self._edge_population.target.ids(target, sample=sample) | ||
|
||
result = {id_: 0 for id_ in target_sample} | ||
if by == "synapses": | ||
connections = self._edge_population.iter_connections( | ||
source, target_sample, return_synapse_count=True | ||
) | ||
for _, post_gid, synapse_count in connections: | ||
result[post_gid] += synapse_count | ||
else: | ||
connections = self._edge_population.iter_connections(source, target_sample) | ||
for _, post_gid in connections: | ||
result[post_gid] += 1 | ||
|
||
return np.array(list(result.values())) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from unittest.mock import Mock | ||
|
||
import numpy.testing as npt | ||
import pytest | ||
|
||
import bluepysnap.edges.edge_population_stats as test_module | ||
from bluepysnap.exceptions import BluepySnapError | ||
|
||
|
||
class TestStatsHelper: | ||
def setup_method(self): | ||
self.edge_pop = Mock() | ||
self.stats = test_module.StatsHelper(self.edge_pop) | ||
|
||
def test_divergence_by_synapses(self): | ||
self.edge_pop.source.ids.return_value = [1, 2] | ||
self.edge_pop.iter_connections.return_value = [(1, None, 42), (1, None, 43)] | ||
actual = self.stats.divergence("pre", "post", by="synapses") | ||
npt.assert_equal(actual, [85, 0]) | ||
|
||
def test_divergence_by_connections(self): | ||
self.edge_pop.source.ids.return_value = [1, 2] | ||
self.edge_pop.iter_connections.return_value = [(1, None), (1, None)] | ||
actual = self.stats.divergence("pre", "post", by="connections") | ||
npt.assert_equal(actual, [2, 0]) | ||
|
||
def test_divergence_error(self): | ||
pytest.raises(BluepySnapError, self.stats.divergence, "pre", "post", by="err") | ||
|
||
def test_convergence_by_synapses(self): | ||
self.edge_pop.target.ids.return_value = [1, 2] | ||
self.edge_pop.iter_connections.return_value = [(None, 2, 42), (None, 2, 43)] | ||
actual = self.stats.convergence("pre", "post", by="synapses") | ||
npt.assert_equal(actual, [0, 85]) | ||
|
||
def test_convergence_by_connections(self): | ||
self.edge_pop.target.ids.return_value = [1, 2] | ||
self.edge_pop.iter_connections.return_value = [(None, 2), (None, 2)] | ||
actual = self.stats.convergence("pre", "post", by="connections") | ||
npt.assert_equal(actual, [0, 2]) | ||
|
||
def test_convergence_error(self): | ||
pytest.raises(BluepySnapError, self.stats.convergence, "pre", "post", by="err") |