Skip to content

Commit

Permalink
Merge pull request #96 from gibsramen/new-docs
Browse files Browse the repository at this point in the history
New documentation
  • Loading branch information
gibsramen authored Oct 29, 2023
2 parents 58f763d + 1af6ba5 commit 7f0aa64
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 21 deletions.
55 changes: 51 additions & 4 deletions birdman/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,36 @@ def fit_to_inference(
dims: dict,
posterior_predictive: str = None,
log_likelihood: str = None,
):
) -> az.InferenceData:
"""Convert a fitted model to an arviz InferenceData object.
:param fit: Fitted CmdStan model
:type fit: Either CmdStanMCMC or CmdStanVB
:param chains: Number of chains
:type chains: int
:param draws: Number of draws
:type draws: int
:param params: Parameters to include in inference
:type params: Sequence[str]
:param coords: Coordinates for InferenceData
:type coords: dict
:param dims: Dimensions for InferenceData
:type dims: dict
:param posterior_predictive: Name of posterior predictive var in model
:type posterior_predictive: str
:param log_likelihood: Name of log likelihood var in model
:type log_likelihood: str
:returns: Model converted to InferenceData
:rtype: az.InferenceData
"""
if log_likelihood is not None and log_likelihood not in dims:
raise KeyError("Must include dimensions for log-likelihood!")
if posterior_predictive is not None and posterior_predictive not in dims:
Expand Down Expand Up @@ -120,15 +149,33 @@ def concatenate_inferences(
return az.concat(*all_group_inferences)


# TODO: Fix docstring
def stan_var_to_da(
data: np.ndarray,
coords: dict,
dims: dict,
chains: int,
draws: int
):
"""Convert Stan variable draws to xr.DataArray."""
) -> xr.DataArray:
"""Convert results of stan_var to DataArray.
:params data: Result of stan_var
:type data: np.ndarray
:params coords: Coordinates of variables
:type coords: dict
:params dims: Dimensions of variables
:type dims: dict
:params chains: Number of chains
:type chains: int
:params draws: Number of draws
:type draws: int
:returns: DataArray representation of stan variables
:rtype: xr.DataArray
"""
data = np.stack(np.split(data, chains))

coords["draw"] = np.arange(draws)
Expand Down
27 changes: 18 additions & 9 deletions docs/custom_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@ We then import the data into Python so we can use BIRDMAn.
import biom
import pandas as pd
import glob
fpath = glob.glob("templates/*.txt")[0]
table = biom.load_table("BIOM/94270/reference-hit.biom")
metadata = pd.read_csv(
"templates/11913_20191016-112545.txt",
fpath,
sep="\t",
index_col=0
)
metadata.head()
Processing metadata
-------------------

Expand Down Expand Up @@ -224,9 +229,6 @@ We will now pass this file along with our table, metadata, and formula into BIRD
nb_lme = birdman.TableModel(
table=filt_tbl,
model_path="negative_binomial_re.stan",
num_iter=500,
chains=4,
seed=42
)
nb_lme.create_regression(
metadata=metadata_model.loc[samps_to_keep],
Expand Down Expand Up @@ -272,6 +274,7 @@ Now we can add all the necessary parameters to BIRDMAn with the ``add_parameters
"depth": np.log(filt_tbl.sum(axis="sample")),
"B_p": 3.0,
"inv_disp_sd": 3.0,
"A": np.log(1 / filt_tbl.shape[0]),
"u_p": 1.0
}
nb_lme.add_parameters(param_dict)
Expand All @@ -294,8 +297,8 @@ We pass all these arguments into the ``specify_model`` method of the ``Model`` o
dims={
"beta_var": ["covariate", "feature_alr"],
"inv_disp": ["feature"],
"subj_int": ["subject"],
"log_lik": ["tbl_sample", "feature"],
"subj_int": ["subject", "feature_alr"],
"log_lhood": ["tbl_sample", "feature"],
"y_predict": ["tbl_sample", "feature"]
},
coords={
Expand All @@ -306,7 +309,7 @@ We pass all these arguments into the ``specify_model`` method of the ``Model`` o
"tbl_sample": nb_lme.sample_names
},
posterior_predictive="y_predict",
log_likelihood="log_lik",
log_likelihood="log_lhood",
include_observed_data=True
)
Expand All @@ -319,7 +322,7 @@ Finally, we compile and fit the model.
.. code-block:: python
nb_lme.compile_model()
nb_lme.fit_model()
nb_lme.fit_model(method="vi", num_draws=500)
Converting to ``InferenceData``
-------------------------------
Expand All @@ -329,7 +332,13 @@ When the model has finished fitting, you can convert to an inference data assumi
.. code-block:: python
from birdman.transform import posterior_alr_to_clr
inference = nb_lme.to_inference()
inference.posterior = posterior_alr_to_clr(inference.posterior)
inference.posterior = posterior_alr_to_clr(
inference.posterior,
alr_params=["subj_int", "beta_var"],
dim_replacement={"feature_alr": "feature"},
new_labels=filt_tbl.ids("observation")
)
With this you can use the rest of the BIRDMAn suite as usual or directly interact with the ``arviz`` library!
13 changes: 9 additions & 4 deletions docs/default_model_example.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ Next, we want to import the data into Python so we can run BIRDMAn.
from birdman import NegativeBinomial
table = biom.load_table("BIOM/44773/otu_table.biom")
fpath = glob.glob("templates/*.txt")[0]
table = biom.load_table("BIOM/44773/reference-hit.biom")
metadata = pd.read_csv(
"templates/107_20180101-113755.txt",
fpath,
sep="\t",
index_col=0
)
metadata.head()
This table has nearly 2000 features, many of which are likely lowly prevalent. We are going to filter to only features that are present in at least 5 samples.

.. code-block:: python
Expand All @@ -44,11 +47,12 @@ For this example we're going to use a simple formula that only takes ``diet`` in

.. code-block:: python
from birdman import NegativeBinomial
nb = NegativeBinomial(
table=table_filt,
formula="diet",
metadata=metadata,
num_iter=1000,
)
We then have to compile and fit our model. This is very straightforward in BIRDMAn.
Expand All @@ -68,6 +72,7 @@ Now we have our parameter estimates which we can use in downstream analyses. Man
.. code-block:: python
from birdman.transform import posterior_alr_to_clr
inference = nb.to_inference()
inference.posterior = posterior_alr_to_clr(
inference.posterior,
Expand All @@ -85,7 +90,7 @@ Finally, we'll plot the feature differentials and their standard deviations. We
ax = viz.plot_parameter_estimates(
inference,
parameter="beta_var",
coord={"covariate": "diet[T.DIO]"},
coords={"covariate": "diet[T.DIO]"},
)
.. image:: imgs/example_differentials.png
5 changes: 3 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ We provide several default models but also allow users to create their own stati
Installation
------------

.. note:: BIRDMAn requires Python >= 3.7
.. note:: BIRDMAn requires Python >= 3.8

There are several dependencies you must install to use BIRDMAn.
The easiest way to install the required dependencies is through ``conda`` or ``mamba``.

.. code:: bash
conda install -c conda-forge biom-format patsy xarray arviz cmdstanpy
mamba install -c conda-forge biom-format patsy xarray arviz cmdstanpy
pip install birdman
If you are planning on contributing to BIRDMAn you must also install the following packages:
Expand Down Expand Up @@ -53,6 +53,7 @@ If you are planning on contributing to BIRDMAn you must also install the followi
:caption: API

models
inference
diagnostics
summary
transform
Expand Down
6 changes: 6 additions & 0 deletions docs/inference.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Inference Functions
===================

.. automodule:: birdman.inference
:members:

2 changes: 2 additions & 0 deletions docs/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ These are the default models that are included in BIRDMAn. They should be usable
:members:
.. autoclass:: birdman.default_models.NegativeBinomialLME
:members:
.. autoclass:: birdman.default_models.NegativeBinomialLMESingle
:members:

Table Model
-----------
Expand Down
3 changes: 1 addition & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ channels:
dependencies:
- pandas
- numpy
- python=3.7
- python=3.8
- python-language-server
- xarray
- patsy
Expand All @@ -16,4 +16,3 @@ dependencies:
- docutils==0.16
- pip:
- sphinx-rtd-theme==0.5.1
prefix: /Users/gibs/miniconda3/envs/birdman

0 comments on commit 7f0aa64

Please sign in to comment.