Skip to content

Commit

Permalink
notebooks: use ExperimentDataset
Browse files Browse the repository at this point in the history
Also execute via `juq papermill run`
  • Loading branch information
ryan-williams committed Jan 27, 2025
1 parent be141d5 commit df3d2e2
Show file tree
Hide file tree
Showing 4 changed files with 657 additions and 277 deletions.
160 changes: 97 additions & 63 deletions notebooks/tutorial_lightning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,43 +28,74 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize SOMA Experiment query as training data"
"[Papermill] parameters:\n",
"\n",
"[Papermill]: https://papermill.readthedocs.io/"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [
"parameters"
]
},
"outputs": [],
"source": [
"import os\n",
"\n",
"tissue = \"tongue\"\n",
"n_epochs = 20\n",
"census_version = \"2024-07-01\"\n",
"batch_size = 128\n",
"learning_rate = 1e-5\n",
"progress_bar = not bool(os.environ.get('PAPERMILL')) # Defaults to True, unless env var $PAPERMILL is set"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize SOMA Experiment query as training data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"import tiledbsoma as soma\n",
"import torch\n",
"from tiledbsoma import AxisQuery, Experiment, SOMATileDBContext\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"import tiledbsoma_ml as soma_ml\n",
"from tiledbsoma_ml import ExperimentDataset\n",
"\n",
"CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n",
"CZI_Census_Homo_Sapiens_URL = f\"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/\"\n",
"\n",
"experiment = soma.open(\n",
"experiment = Experiment.open(\n",
" CZI_Census_Homo_Sapiens_URL,\n",
" context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n",
" context=SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n",
")\n",
"obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n",
"obs_value_filter = f\"tissue_general == '{tissue}' and is_primary_data == True\"\n",
"\n",
"with experiment.axis_query(\n",
" measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n",
" measurement_name=\"RNA\", obs_query=AxisQuery(value_filter=obs_value_filter)\n",
") as query:\n",
" obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n",
" cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n",
"\n",
" experiment_dataset = soma_ml.ExperimentDataset(\n",
" query,\n",
" layer_name=\"raw\",\n",
" obs_column_names=[\"cell_type\"],\n",
" batch_size=128,\n",
" shuffle=True,\n",
" )"
"experiment_dataset = ExperimentDataset.create(\n",
" query,\n",
" layer_name=\"raw\",\n",
" obs_column_names=[\"cell_type\"],\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
")"
]
},
{
Expand All @@ -76,12 +107,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import pytorch_lightning as pl\n",
"\n",
"class LogisticRegressionLightning(pl.LightningModule):\n",
" def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=1e-5):\n",
" def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=learning_rate):\n",
" super(LogisticRegressionLightning, self).__init__()\n",
" self.linear = torch.nn.Linear(input_dim, output_dim)\n",
" self.cell_type_encoder = cell_type_encoder\n",
Expand Down Expand Up @@ -134,17 +168,50 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n",
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
"TPU available: False, using: 0 TPU cores\n"
]
}
],
"source": [
"from tiledbsoma_ml import experiment_dataloader\n",
"\n",
"dataloader = experiment_dataloader(experiment_dataset)\n",
"\n",
"# The size of the input dimension is the number of genes\n",
"input_dim = experiment_dataset.shape[1]\n",
"\n",
"# The size of the output dimension is the number of distinct cell_type values\n",
"output_dim = len(cell_type_encoder.classes_)\n",
"\n",
"# Initialize the PyTorch Lightning model\n",
"model = LogisticRegressionLightning(\n",
" input_dim, output_dim, cell_type_encoder=cell_type_encoder\n",
")\n",
"\n",
"# Define the PyTorch Lightning Trainer\n",
"trainer = pl.Trainer(max_epochs=n_epochs, enable_progress_bar=progress_bar)\n",
"\n",
"# set precision\n",
"torch.set_float32_matmul_precision(\"high\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params | Mode \n",
Expand All @@ -157,61 +224,28 @@
"726 K Total params\n",
"2.905 Total estimated model params size (MB)\n",
"2 Modules in train mode\n",
"0 Modules in eval mode\n",
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.31it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=20` reached.\n"
"0 Modules in eval mode\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.28it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]\n"
"CPU times: user 3min 30s, sys: 1min 25s, total: 4min 55s\n",
"Wall time: 2min 14s\n"
]
}
],
"source": [
"dataloader = soma_ml.experiment_dataloader(experiment_dataset)\n",
"\n",
"# The size of the input dimension is the number of genes\n",
"input_dim = experiment_dataset.shape[1]\n",
"\n",
"# The size of the output dimension is the number of distinct cell_type values\n",
"output_dim = len(cell_type_encoder.classes_)\n",
"\n",
"# Initialize the PyTorch Lightning model\n",
"model = LogisticRegressionLightning(\n",
" input_dim, output_dim, cell_type_encoder=cell_type_encoder\n",
")\n",
"\n",
"# Define the PyTorch Lightning Trainer\n",
"trainer = pl.Trainer(max_epochs=20)\n",
"\n",
"# set precision\n",
"torch.set_float32_matmul_precision(\"high\")\n",
"\n",
"%%time\n",
"# Train the model\n",
"trainer.fit(model, train_dataloaders=dataloader)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "toymodel",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -225,9 +259,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 5
}
Loading

0 comments on commit df3d2e2

Please sign in to comment.