Skip to content

Commit

Permalink
Merge pull request #4 from openproblems-bio/add_jsd
Browse files Browse the repository at this point in the history
Add JSD metric
  • Loading branch information
sainirmayi authored Aug 26, 2024
2 parents 3f5442c + e1bb94a commit 6b25bf9
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
33 changes: 33 additions & 0 deletions src/metrics/jsd/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
__merge__: ../../api/comp_metric.yaml

name: jsd
info:
metrics:
- name: jensen_shannon_distance
label: Jensen-Shannon Distance
summary: "Jensen-Shannon Distance measure the similarity between to probability distributions."
description: |
The Jensen-Shannon Distance, which is the square root of Jensen-Shannon Divergence is a symmetric method for measuring the similarity between two probability distributions. The similarity between the distributions is greater when the Jensen-Shannon distance is closer to zero.
reference: 10.1109/18.61115
documentation_url: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.jensenshannon.html
repository_url: https://github.com/scipy/scipy/
min: 0
max: 1
maximize: false

resources:
- type: python_script
path: script.py

engines:
- type: docker
image: ghcr.io/openproblems-bio/base_images/python:1.1.0
setup:
- type: python
packages: [numpy, scipy]

runners:
- type: executable
- type: nextflow
directives:
label: [midtime, midmem, midcpu]
34 changes: 34 additions & 0 deletions src/metrics/jsd/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import anndata as ad
import numpy as np
from scipy.spatial.distance import jensenshannon

## VIASH START
par = {
'input_method': 'resources_test/spatial_decomposition/cxg_mouse_pancreas_atlas/output.h5ad',
'input_solution': 'resources_test/spatial_decomposition/cxg_mouse_pancreas_atlas/solution.h5ad',
'output': 'score.h5ad'
}
meta = {
'name': 'r2'
}
## VIASH END

print('Reading input files', flush=True)
input_method = ad.read_h5ad(par['input_method'])
input_solution = ad.read_h5ad(par['input_solution'])

print('Compute metrics', flush=True)
jsd = jensenshannon(input_solution.obsm['proportions_true'], input_method.obsm['proportions_pred'], axis=0)
uns_metric_ids = [ 'jsd' ]
uns_metric_values = [ np.mean(jsd) ]

print("Write output AnnData to file", flush=True)
output = ad.AnnData(
uns={
'dataset_id': input_method.uns['dataset_id'],
'method_id': input_method.uns['method_id'],
'metric_ids': uns_metric_ids,
'metric_values': uns_metric_values
}
)
output.write_h5ad(par['output'], compression='gzip')

0 comments on commit 6b25bf9

Please sign in to comment.