From ba9ade405797995ae9c8aae8413eefe8bef9e1b0 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Mon, 1 May 2023 15:09:42 +1200 Subject: [PATCH 1/3] Add skipna option to genomic_relationship #1076 --- sgkit/stats/grm.py | 33 +++++++++++++++--- sgkit/tests/test_grm.py | 74 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 5 deletions(-) diff --git a/sgkit/stats/grm.py b/sgkit/stats/grm.py index 43d4d935f..9aab6f4fd 100644 --- a/sgkit/stats/grm.py +++ b/sgkit/stats/grm.py @@ -10,6 +10,28 @@ from sgkit.utils import conditional_merge_datasets, create_dataset +def _grm_VanRaden( + call_dosage: ArrayLike, + ancestral_frequency: ArrayLike, + ploidy: int, + skipna: bool = False, +): + ancestral_dosage = ancestral_frequency * ploidy + M = call_dosage - ancestral_dosage[:, None] + if skipna: + nans = da.isnan(M) + M0 = da.where(nans, 0, M) + numerator = M0.T @ M0 + AD = ~nans * ancestral_dosage[:, None] + AFC = ~nans * (1 - ancestral_frequency[:, None]) + denominator = AD.T @ AFC + else: + numerator = M.T @ M + denominator = (ancestral_dosage * (1 - ancestral_frequency)).sum() + G = numerator / denominator + return G + + def genomic_relationship( ds: Dataset, *, @@ -17,6 +39,7 @@ def genomic_relationship( estimator: Optional[Literal["VanRaden"]] = None, ancestral_frequency: Optional[Hashable] = None, ploidy: Optional[int] = None, + skipna: bool = False, merge: bool = True, ) -> Dataset: """Compute a genomic relationship matrix (AKA the GRM or G-matrix). @@ -44,6 +67,10 @@ def genomic_relationship( Ploidy level of all samples within the dataset. By default this is inferred from the size of the "ploidy" dimension of the dataset. + skipna + If True, missing (nan) values of 'call_dosage' will be skipped so + that the relationship between each pair of individuals is estimated + using only variants where both samples have non-nan values. 
merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -134,11 +161,7 @@ def genomic_relationship( raise ValueError( "The ancestral_frequency variable must have one value per variant" ) - ad = af * ploidy - M = cd - ad[:, None] - num = M.T @ M - denom = (ad * (1 - af)).sum() - G = num / denom + G = _grm_VanRaden(cd, af, ploidy=ploidy, skipna=skipna) new_ds = create_dataset( { diff --git a/sgkit/tests/test_grm.py b/sgkit/tests/test_grm.py index 06626262b..a7dff13eb 100644 --- a/sgkit/tests/test_grm.py +++ b/sgkit/tests/test_grm.py @@ -118,6 +118,80 @@ def test_genomic_relationship__VanRaden_AGHmatrix_tetraploid(chunks): np.testing.assert_array_almost_equal(actual, expect) +def test_genomic_relationship__VanRaden_skipna(): + # Test that skipna option skips values in call_dosage + # such that the relationship between each pair of individuals + # is calculated using only the variants where neither sample + # has missing data. + # This should be equivalent to calculating the GRM using + # multiple subsets of the variants and using pairwise + # values from the largest subset of variants that doesn't + # result in a nan value. 
+ nan = np.nan + dosage = np.array( + [ + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 2.0, 0.0], + [1.0, 1.0, 1.0, 2.0, nan, 1.0, 1.0, 0.0, 1.0, 2.0], + [2.0, 2.0, 0.0, 0.0, nan, 1.0, 1.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, nan, 1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 1.0, 1.0, nan, 2.0, 0.0, 1.0, 0.0, 2.0], + [2.0, 1.0, 1.0, 1.0, nan, 1.0, 2.0, nan, 0.0, 1.0], + [2.0, 0.0, 1.0, 1.0, nan, 2.0, 1.0, nan, 1.0, 1.0], + [1.0, 1.0, 1.0, 2.0, nan, 1.0, 2.0, nan, 1.0, 0.0], + [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, nan, 1.0, 1.0], + [2.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, nan, 2.0, 1.0], + [1.0, 2.0, 2.0, 1.0, 2.0, 0.0, 1.0, nan, 1.0, 2.0], + [0.0, 0.0, 1.0, 2.0, 0.0, 1.0, 0.0, nan, 1.0, 2.0], + [1.0, 2.0, 1.0, 2.0, 2.0, 0.0, 1.0, nan, 1.0, 0.0], + [0.0, 2.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 2.0], + [2.0, 0.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 0.0, 2.0], + [1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 1.0], + [2.0, 1.0, 2.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 0.0, 2.0, 1.0, 2.0], + [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ] + ) + ds = xr.Dataset() + ds["call_dosage"] = ["variants", "samples"], dosage + ds["ancestral_frequency"] = "variants", np.ones(len(dosage)) / 2 + # calculating without skipna will result in nans in the GRM + expect = sg.genomic_relationship( + ds, + call_dosage="call_dosage", + ancestral_frequency="ancestral_frequency", + estimator="VanRaden", + ploidy=2, + skipna=False, + ).stat_genomic_relationship.values + assert np.isnan(expect).sum() > 0 + # fill nan values using maximum subsets without missing data + idx_0 = ~np.isnan(dosage[:, 4]) + idx_1 = ~np.isnan(dosage[:, 7]) + idx_2 = np.logical_and(idx_0, idx_1) + for idx in [idx_0, idx_1, idx_2]: + sub = ds.sel(dict(variants=idx)) + sub_expect = sg.genomic_relationship( + sub, + call_dosage="call_dosage", + ancestral_frequency="ancestral_frequency", + estimator="VanRaden", + ploidy=2, + skipna=False, + 
).stat_genomic_relationship.values + expect = np.where(np.isnan(expect), sub_expect, expect) + # calculate actual value using skipna=True + actual = sg.genomic_relationship( + ds, + call_dosage="call_dosage", + ancestral_frequency="ancestral_frequency", + estimator="VanRaden", + ploidy=2, + skipna=True, + ).stat_genomic_relationship.values + np.testing.assert_array_equal(actual, expect) + + @pytest.mark.parametrize("ploidy", [2, 4]) def test_genomic_relationship__detect_ploidy(ploidy): ds = xr.Dataset() From 449fb9ff82c7ce9de0c8932a0adb7466bcf88359 Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Mon, 1 May 2023 15:13:56 +1200 Subject: [PATCH 2/3] Update changelog --- docs/changelog.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 4f2cb9a70..c318af331 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -42,6 +42,9 @@ New Features - Add :func:`sgkit.io.vcf.zarr_array_sizes` for determining array sizes for storage in Zarr. (:user:`tomwhite`, :pr:`1073`, :issue:`734`) +- Add ``skipna`` option to :func:`genomic_relationship` function. 
+ (:user:`timothymillar`, :pr:`1078`, :issue:`1076`) + Bug fixes ~~~~~~~~~ From 56dec882fe60d25bbac2ad72da3b7d1b0e3c609a Mon Sep 17 00:00:00 2001 From: Tim Millar Date: Mon, 1 May 2023 20:33:58 +1200 Subject: [PATCH 3/3] Add skipna and mean imputation examples for GRM #1025 --- sgkit/stats/grm.py | 94 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 7 deletions(-) diff --git a/sgkit/stats/grm.py b/sgkit/stats/grm.py index 9aab6f4fd..342deaaee 100644 --- a/sgkit/stats/grm.py +++ b/sgkit/stats/grm.py @@ -101,6 +101,8 @@ def genomic_relationship( Examples -------- + Diploid dataset without missing data: + >>> import sgkit as sg >>> ds = sg.simulate_genotype_call_dataset(n_variant=6, n_sample=3, seed=0) >>> ds = sg.count_call_alleles(ds) @@ -108,18 +110,96 @@ def genomic_relationship( >>> ds["call_dosage"] = ds.call_allele_count[:,:,0] >>> ds.call_dosage.values # doctest: +NORMALIZE_WHITESPACE array([[2, 1, 1], - [1, 1, 1], - [2, 1, 0], - [2, 1, 1], - [1, 0, 0], - [1, 1, 2]], dtype=uint8) + [1, 1, 1], + [2, 1, 0], + [2, 1, 1], + [1, 0, 0], + [1, 1, 2]], dtype=uint8) >>> # use sample population frequency as ancestral frequency >>> ds["sample_frequency"] = ds.call_dosage.mean(dim="samples") / ds.dims["ploidy"] >>> ds = sg.genomic_relationship(ds, ancestral_frequency="sample_frequency") >>> ds.stat_genomic_relationship.values # doctest: +NORMALIZE_WHITESPACE array([[ 0.93617021, -0.21276596, -0.72340426], - [-0.21276596, 0.17021277, 0.04255319], - [-0.72340426, 0.04255319, 0.68085106]]) + [-0.21276596, 0.17021277, 0.04255319], + [-0.72340426, 0.04255319, 0.68085106]]) + + Skipping partial or missing genotype calls: + + >>> import sgkit as sg + >>> import xarray as xr + >>> ds = sg.simulate_genotype_call_dataset( + ... n_variant=6, + ... n_sample=4, + ... missing_pct=0.05, + ... seed=0, + ... ) + >>> ds = sg.count_call_alleles(ds) + >>> ds["call_dosage"] = xr.where( + ... ds.call_genotype_mask.any(dim="ploidy"), + ... np.nan, + ... 
ds.call_allele_count[:,:,1], # alternate allele + ... ) + >>> ds.call_dosage.values # doctest: +NORMALIZE_WHITESPACE + array([[ 0., 1., 1., 1.], + [ 1., nan, 0., 1.], + [ 2., 0., 1., 1.], + [ 1., 2., nan, 1.], + [ 1., 0., 1., 2.], + [ 2., 2., 0., 0.]]) + >>> ds["sample_frequency"] = ds.call_dosage.mean( + ... dim="samples", skipna=True + ... ) / ds.dims["ploidy"] + >>> ds = sg.genomic_relationship( + ... ds, ancestral_frequency="sample_frequency", skipna=True + ... ) + >>> ds.stat_genomic_relationship.values # doctest: +NORMALIZE_WHITESPACE + array([[ 0.9744836 , -0.16978417, -0.58417266, -0.33778858], + [-0.16978417, 1.45323741, -0.47619048, -0.89496403], + [-0.58417266, -0.47619048, 0.62446043, 0.34820144], + [-0.33778858, -0.89496403, 0.34820144, 0.79951397]]) + + Using mean imputation to replace missing genotype calls: + + >>> import sgkit as sg + >>> import xarray as xr + >>> ds = sg.simulate_genotype_call_dataset( + ... n_variant=6, + ... n_sample=4, + ... missing_pct=0.05, + ... seed=0, + ... ) + >>> ds = sg.count_call_alleles(ds) + >>> ds["call_dosage"] = xr.where( + ... ds.call_genotype_mask.any(dim="ploidy"), + ... np.nan, + ... ds.call_allele_count[:,:,1], # alternate allele + ... ) + >>> # use mean imputation to replace missing dosage + >>> ds["call_dosage_imputed"] = xr.where( + ... ds.call_genotype_mask.any(dim="ploidy"), + ... ds.call_dosage.mean(dim="samples", skipna=True), + ... ds.call_dosage, + ... ) + >>> ds.call_dosage_imputed.values # doctest: +NORMALIZE_WHITESPACE + array([[0. , 1. , 1. , 1. ], + [1. , 0.66666667, 0. , 1. ], + [2. , 0. , 1. , 1. ], + [1. , 2. , 1.33333333, 1. ], + [1. , 0. , 1. , 2. ], + [2. , 2. , 0. , 0. ]]) + >>> ds["sample_frequency"] = ds.call_dosage.mean( + ... dim="samples", skipna=True + ... ) / ds.dims["ploidy"] + >>> ds = sg.genomic_relationship( + ... ds, + ... call_dosage="call_dosage_imputed", + ... ancestral_frequency="sample_frequency", + ... 
) + >>> ds.stat_genomic_relationship.values # doctest: +NORMALIZE_WHITESPACE + array([[ 0.9744836 , -0.14337789, -0.49331713, -0.33778858], + [-0.14337789, 1.2272175 , -0.32806804, -0.75577157], + [-0.49331713, -0.32806804, 0.527339 , 0.29404617], + [-0.33778858, -0.75577157, 0.29404617, 0.79951397]]) References ----------