diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb index ac8a55b..01551be 100644 --- a/notebooks/tutorial_lightning.ipynb +++ b/notebooks/tutorial_lightning.ipynb @@ -28,43 +28,74 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Initialize SOMA Experiment query as training data" + "[Papermill] parameters:\n", + "\n", + "[Papermill]: https://papermill.readthedocs.io/" ] }, { "cell_type": "code", "execution_count": 1, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "tissue = \"tongue\"\n", + "n_epochs = 20\n", + "census_version = \"2024-07-01\"\n", + "batch_size = 128\n", + "learning_rate = 1e-5\n", + "progress_bar = not bool(os.environ.get('PAPERMILL'))  # Defaults to True, unless env var $PAPERMILL is set" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize SOMA Experiment query as training data" ] }, + { + "cell_type": "code", + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "import pytorch_lightning as pl\n", - "import tiledbsoma as soma\n", - "import torch\n", + "from tiledbsoma import AxisQuery, Experiment, SOMATileDBContext\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", - "import tiledbsoma_ml as soma_ml\n", + "from tiledbsoma_ml import ExperimentDataset\n", "\n", - "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "CZI_Census_Homo_Sapiens_URL = f\"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/\"\n", "\n", - "experiment = soma.open(\n", + "experiment = Experiment.open(\n", "    CZI_Census_Homo_Sapiens_URL,\n", - "    context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n", + "    context=SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n", ")\n", - "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", + "obs_value_filter = f\"tissue_general == '{tissue}' and is_primary_data == True\"\n", "\n", "with experiment.axis_query(\n", - "    measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + "    measurement_name=\"RNA\", obs_query=AxisQuery(value_filter=obs_value_filter)\n", ") as query:\n", "    obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", "    cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", "\n", - "    experiment_dataset = soma_ml.ExperimentDataset(\n", - "        query,\n", - "        layer_name=\"raw\",\n", - "        obs_column_names=[\"cell_type\"],\n", - "        batch_size=128,\n", - "        shuffle=True,\n", - "    )" + "experiment_dataset = ExperimentDataset.create(\n", + "    query,\n", + "    layer_name=\"raw\",\n", + "    obs_column_names=[\"cell_type\"],\n", + "    batch_size=batch_size,\n", + "    shuffle=True,\n", + ")" ] }, { @@ -76,12 +107,15 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ + "import torch\n", + "import pytorch_lightning as pl\n", + "\n", "class LogisticRegressionLightning(pl.LightningModule):\n", - "    def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=1e-5):\n", + "    def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=learning_rate):\n", "        super(LogisticRegressionLightning, self).__init__()\n", " 
self.linear = torch.nn.Linear(input_dim, output_dim)\n", " self.cell_type_encoder = cell_type_encoder\n", @@ -134,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -142,9 +176,42 @@ "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + "TPU available: False, using: 0 TPU cores\n" + ] + } + ], + "source": [ + "from tiledbsoma_ml import experiment_dataloader\n", + "\n", + "dataloader = experiment_dataloader(experiment_dataset)\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "# Initialize the PyTorch Lightning model\n", + "model = LogisticRegressionLightning(\n", + " input_dim, output_dim, cell_type_encoder=cell_type_encoder\n", + ")\n", + "\n", + "# Define the PyTorch Lightning Trainer\n", + "trainer = pl.Trainer(max_epochs=n_epochs, enable_progress_bar=progress_bar)\n", + "\n", + "# set precision\n", + "torch.set_float32_matmul_precision(\"high\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params | Mode \n", @@ -157,53 +224,20 @@ "726 K Total params\n", "2.905 Total estimated model params size (MB)\n", "2 Modules in train mode\n", - "0 Modules in eval mode\n", - "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n", - "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. 
In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.31it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "`Trainer.fit` stopped: `max_epochs=20` reached.\n" + "0 Modules in eval mode\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.28it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]\n" + "CPU times: user 3min 30s, sys: 1min 25s, total: 4min 55s\n", + "Wall time: 2min 14s\n" ] } ], "source": [ - "dataloader = soma_ml.experiment_dataloader(experiment_dataset)\n", - "\n", - "# The size of the input dimension is the number of genes\n", - "input_dim = experiment_dataset.shape[1]\n", - "\n", - "# The size of the output dimension is the number of distinct cell_type values\n", - "output_dim = len(cell_type_encoder.classes_)\n", - "\n", - "# Initialize the PyTorch Lightning model\n", - "model = LogisticRegressionLightning(\n", - " input_dim, output_dim, cell_type_encoder=cell_type_encoder\n", - ")\n", - "\n", - "# Define the PyTorch Lightning Trainer\n", - "trainer = pl.Trainer(max_epochs=20)\n", - "\n", - "# set precision\n", - "torch.set_float32_matmul_precision(\"high\")\n", - "\n", + "%%time\n", "# Train the model\n", "trainer.fit(model, train_dataloaders=dataloader)" ] @@ -211,7 +245,7 @@ ], "metadata": { "kernelspec": { - "display_name": "toymodel", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -225,9 +259,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.7" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 5 } diff --git a/notebooks/tutorial_multiworker.ipynb b/notebooks/tutorial_multiworker.ipynb index 7493407..8010c12 100644 --- a/notebooks/tutorial_multiworker.ipynb +++ b/notebooks/tutorial_multiworker.ipynb @@ -2,15 +2,21 @@ "cells": [ { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, "source": [ "# Multi-process training\n", "\n", - "Multi-process usage of `tiledbsoma_ml.ExperimentAxisQueryIterDataset` includes both:\n", + "Multi-process usage of `tiledbsoma_ml.ExperimentDataset` includes both:\n", "* using the [`torch.utils.data.DataLoader`] with 1 or more workers (i.e., with an argument of `n_workers=1` or greater)\n", "* using a multi-process training configuration, such as [`DistributedDataParallel`]\n", "\n", - "In these configurations, `ExperimentAxisQueryIterDataset` will automatically partition data across workers. However, when using `shuffle=True`, there are several things to keep in mind:\n", + "In these configurations, `ExperimentDataset` will automatically partition data across workers. However, when using `shuffle=True`, there are several things to keep in mind:\n", "\n", "1. All worker processes must share the same random number generator `seed`, ensuring that all workers shuffle and partition the data in the same way.\n", "2. To ensure that each epoch returns a _different_ shuffle, the caller must set the epoch, using the `set_epoch` API. 
This is identical to the behavior of [`torch.utils.data.distributed.DistributedSampler`].\n", @@ -22,50 +28,83 @@ ] }, { + "cell_type": "markdown", "metadata": {}, + "source": [ + "[Papermill] parameters:\n", + "\n", + "[Papermill]: https://papermill.readthedocs.io/" + ] + }, + { "cell_type": "code", + "execution_count": 1, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "parameters" + ] + }, "outputs": [], - "execution_count": null, "source": [ - "import tiledbsoma as soma\n", - "import torch\n", + "import os\n", + "\n", + "tissue = \"tongue\"\n", + "n_epochs = 20\n", + "census_version = \"2024-07-01\"\n", + "batch_size = 128\n", + "learning_rate = 1e-5\n", + "num_workers = 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from tiledbsoma import AxisQuery, Experiment, SOMATileDBContext\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", - "import tiledbsoma_ml as soma_ml\n", + "from tiledbsoma_ml import ExperimentDataset\n", "\n", - "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "CZI_Census_Homo_Sapiens_URL = f\"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/\"\n", "\n", - "experiment = soma.open(\n", + "experiment = Experiment.open(\n", " CZI_Census_Homo_Sapiens_URL,\n", - " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n", + " context=SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n", ")\n", - "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", + "obs_value_filter = f\"tissue_general == '{tissue}' and is_primary_data == True\"\n", "\n", "with experiment.axis_query(\n", - " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + " measurement_name=\"RNA\", obs_query=AxisQuery(value_filter=obs_value_filter)\n", ") as query:\n", " obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", " cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", "\n", - " experiment_dataset = soma_ml.ExperimentDataset(\n", - " query,\n", - " layer_name=\"raw\",\n", - " obs_column_names=[\"cell_type\"],\n", - " batch_size=128,\n", - " shuffle=True,\n", - " )\n", - " " + "experiment_dataset = ExperimentDataset.create(\n", + " query,\n", + " layer_name=\"raw\",\n", + " obs_column_names=[\"cell_type\"],\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + ")" ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": 3, + "metadata": {}, "outputs": [], - "execution_count": null, "source": [ + "import torch\n", + "\n", "class LogisticRegression(torch.nn.Module):\n", " def __init__(self, input_dim, output_dim):\n", - " super(LogisticRegression, self).__init__() # noqa: UP008\n", + " super(LogisticRegression, self).__init__()\n", " self.linear = torch.nn.Linear(input_dim, output_dim)\n", "\n", " def forward(self, x):\n", @@ -79,7 +118,7 @@ " train_correct = 0\n", " train_total = 0\n", "\n", - " for X_batch, y_batch in train_dataloader:\n", + " for X_batch, obs_batch in train_dataloader:\n", " optimizer.zero_grad()\n", "\n", " X_batch = torch.from_numpy(X_batch).float().to(device)\n", @@ -92,11 +131,11 @@ " predictions = torch.argmax(probabilities, axis=1)\n", "\n", " # Compute the loss and perform back propagation\n", - " 
y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n", - " train_correct += (predictions == y_batch).sum().item()\n", + " obs_batch = torch.from_numpy(cell_type_encoder.transform(obs_batch['cell_type'])).to(device)\n", + " train_correct += (predictions == obs_batch).sum().item()\n", " train_total += len(predictions)\n", "\n", - " loss = loss_fn(outputs, y_batch.long())\n", + " loss = loss_fn(outputs, obs_batch.long())\n", " train_loss += loss.item()\n", " loss.backward()\n", " optimizer.step()\n", @@ -116,42 +155,25 @@ "\n", "The same approach should be taken for parallel training, e.g., when using DDP or DP.\n", "\n", - "*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, the `ExperimentAxisQueryIterDataset` will automatically increment the epoch count each time the iterator completes." + "*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, the `ExperimentDataset` will automatically increment the epoch count each time the iterator completes." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "Epoch 1: Train Loss: 0.0169229 Accuracy 0.3124\n", - "Epoch 2: Train Loss: 0.0148674 Accuracy 0.4272\n", - "Epoch 3: Train Loss: 0.0144468 Accuracy 0.4509\n", - "Epoch 4: Train Loss: 0.0141778 Accuracy 0.4999\n", - "Epoch 5: Train Loss: 0.0139660 Accuracy 0.5619\n", - "Epoch 6: Train Loss: 0.0137670 Accuracy 0.6971\n", - "Epoch 7: Train Loss: 0.0136089 Accuracy 0.8670\n", - "Epoch 8: Train Loss: 0.0135203 Accuracy 0.9099\n", - "Epoch 9: Train Loss: 0.0134427 Accuracy 0.9262\n", - "Epoch 10: Train Loss: 0.0133607 Accuracy 0.9300\n", - "Epoch 11: Train Loss: 0.0133110 Accuracy 0.9348\n", - "Epoch 12: Train Loss: 0.0132749 Accuracy 0.9378\n", - "Epoch 13: Train Loss: 0.0132431 Accuracy 0.9413\n", - "Epoch 14: Train Loss: 0.0132194 Accuracy 0.9444\n", - "Epoch 15: Train Loss: 0.0131942 Accuracy 0.9465\n", - "Epoch 16: Train Loss: 0.0131739 Accuracy 0.9499\n", - "Epoch 17: Train Loss: 0.0131527 Accuracy 0.9526\n", - "Epoch 18: Train Loss: 0.0131369 Accuracy 0.9551\n", - "Epoch 19: Train Loss: 0.0131214 Accuracy 0.9563\n", - "Epoch 20: Train Loss: 0.0131061 Accuracy 0.9578\n" + "switching torch multiprocessing start method from \"fork\" to \"spawn\"\n" ] } ], "source": [ + "from tiledbsoma_ml import experiment_dataloader\n", + "\n", "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "# The size of the input dimension is the number of genes\n", @@ -162,19 +184,55 @@ "\n", "model = LogisticRegression(input_dim, output_dim).to(device)\n", "loss_fn = torch.nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n", - "\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "\n", - "# define a two-worker data loader. The dataset is shuffled, so call `set_epoch` to ensure\n", + "# Define a two-worker data loader. 
The dataset is shuffled, so call `set_epoch` to ensure\n", "# that a different shuffle is applied on each epoch.\n", - "experiment_dataloader = soma_ml.experiment_dataloader(\n", - "    experiment_dataset, num_workers=2, persistent_workers=True\n", - ")\n", - "\n", - "for epoch in range(20):\n", + "dataloader = experiment_dataloader(\n", + "    experiment_dataset, num_workers=num_workers, persistent_workers=True\n", + ")" ] }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1: Train Loss: 0.0165012 Accuracy 0.3866\n", + "Epoch 2: Train Loss: 0.0148111 Accuracy 0.4217\n", + "Epoch 3: Train Loss: 0.0144168 Accuracy 0.6109\n", + "Epoch 4: Train Loss: 0.0141248 Accuracy 0.8374\n", + "Epoch 5: Train Loss: 0.0138151 Accuracy 0.9001\n", + "Epoch 6: Train Loss: 0.0136300 Accuracy 0.9123\n", + "Epoch 7: Train Loss: 0.0135218 Accuracy 0.9234\n", + "Epoch 8: Train Loss: 0.0134472 Accuracy 0.9324\n", + "Epoch 9: Train Loss: 0.0133907 Accuracy 0.9375\n", + "Epoch 10: Train Loss: 0.0133443 Accuracy 0.9419\n", + "Epoch 11: Train Loss: 0.0132998 Accuracy 0.9456\n", + "Epoch 12: Train Loss: 0.0132594 Accuracy 0.9489\n", + "Epoch 13: Train Loss: 0.0132298 Accuracy 0.9524\n", + "Epoch 14: Train Loss: 0.0132037 Accuracy 0.9549\n", + "Epoch 15: Train Loss: 0.0131809 Accuracy 0.9568\n", + "Epoch 16: Train Loss: 0.0131603 Accuracy 0.9585\n", + "Epoch 17: Train Loss: 0.0131425 Accuracy 0.9601\n", + "Epoch 18: Train Loss: 0.0131270 Accuracy 0.9613\n", + "Epoch 19: Train Loss: 0.0131112 Accuracy 0.9630\n", + "Epoch 20: Train Loss: 0.0130966 Accuracy 0.9639\n", + "CPU times: user 1min 6s, sys: 1min 58s, total: 3min 4s\n", + "Wall time: 4min 48s\n" ] } ], "source": [ + "%%time\n", "for epoch in range(n_epochs):\n", "    experiment_dataset.set_epoch(epoch)\n", "    train_loss, train_accuracy = train_epoch(\n", - "        model, experiment_dataloader, loss_fn, optimizer, device\n", + "        model, dataloader, loss_fn, optimizer, device\n", "    )\n", "    print(\n", "        f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\"\n", @@ -184,7 +242,7 @@ ], "metadata": { "kernelspec": { - "display_name": "toymodel", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -198,9 +256,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.7" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 5 } diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb index f19fafa..d84e8de 100644 --- a/notebooks/tutorial_pytorch.ipynb +++ b/notebooks/tutorial_pytorch.ipynb @@ -6,7 +6,7 @@ "source": [ "# Training a PyTorch Model\n", "\n", - "This tutorial shows how to train a Logistic Regression model in PyTorch using the `tiledbsoma.ml.ExperimentAxisQueryIterDataPipe` class, and the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. This is intended only to demonstrate the use of the `ExperimentAxisQueryIterDataPipe`, and not as an example of how to train a biologically useful model.\n", + "This tutorial trains a Logistic Regression model in PyTorch, using `tiledbsoma_ml.ExperimentDataset` and the [CZI CELLxGENE Census] dataset. 
This is intended only to demonstrate the use of `ExperimentDataset`, not as an example of how to train a biologically useful model.\n", "\n", "This tutorial assumes a basic familiarity with PyTorch and the Census API.\n", "\n", @@ -20,75 +20,118 @@ "\n", "**Contents**\n", "\n", - "* [Create an ExperimentAxisQueryIterDataPipe](#data-pipe)\n", + "* [Create an ExperimentDataset](#data-pipe)\n", "* [Split the dataset](#split)\n", "* [Create the DataLoader](#data-loader)\n", "* [Define the model](#model)\n", "* [Train the model](#train)\n", - "* [Make predictions with the model](#predict)" + "* [Make predictions with the model](#predict)\n", + "\n", + "[CZI CELLxGENE Census]: https://chanzuckerberg.github.io/cellxgene-census/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Create an ExperimentAxisQueryIterDataPipe \n", + "[Papermill] parameters:\n", + "\n", + "[Papermill]: https://papermill.readthedocs.io/" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "tissue = \"tongue\"\n", + "n_epochs = 20\n", + "census_version = \"2024-07-01\"\n", + "batch_size = 128\n", + "train_split = .8\n", + "seed = 111\n", + "learning_rate = 1e-5" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create an ExperimentDataset \n", "\n", "To train a PyTorch model on a SOMA [Experiment]:\n", "1. Open the Experiment.\n", "2. Select the desired `obs` rows and `var` columns with an [ExperimentAxisQuery].\n", - "3. Create an `ExperimentAxisQueryIterDataPipe`.\n", + "3. Create an `ExperimentDataset`.\n", "\n", "The example below utilizes a recent CZI Census release, accessed directly from S3. 
We also encode the `obs` `cell_type` labels, using a `scikit-learn` [LabelEncoder].\n", "\n", - "[Experiment]: https://tiledbsoma.readthedocs.io/en/stable/_autosummary/tiledbsoma.Experiment.html#tiledbsoma.Experiment\n", - "[ExperimentAxisQuery]: https://tiledbsoma.readthedocs.io/en/stable/_autosummary/tiledbsoma.ExperimentAxisQuery.html\n", + "[Experiment]: https://tiledbsoma.readthedocs.io/en/stable/python-tiledbsoma-experiment.html\n", + "[ExperimentAxisQuery]: https://tiledbsoma.readthedocs.io/en/stable/python-tiledbsoma-experimentaxisquery.html\n", "[LabelEncoder]: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html" ] }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, + "execution_count": 2, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, "outputs": [], "source": [ - "import tiledbsoma as soma\n", + "from tiledbsoma import AxisQuery, Experiment, SOMATileDBContext\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", - "import tiledbsoma_ml as soma_ml\n", + "from tiledbsoma_ml import ExperimentDataset\n", "\n", - "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "CZI_Census_Homo_Sapiens_URL = f\"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/\"\n", "\n", - "experiment = soma.open(\n", + "experiment = Experiment.open(\n", "    CZI_Census_Homo_Sapiens_URL,\n", - "    context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n", + "    context=SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n", ")\n", - "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", + "obs_value_filter = f\"tissue_general == '{tissue}' and is_primary_data == True\"\n", "\n", "with experiment.axis_query(\n", - "    measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + "    measurement_name=\"RNA\", obs_query=AxisQuery(value_filter=obs_value_filter)\n", ") as query:\n", "    obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", "    cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", "\n", - "    experiment_dataset = soma_ml.ExperimentAxisQueryIterDataPipe(\n", - "        query,\n", - "        layer_name=\"raw\",\n", - "        obs_column_names=[\"cell_type\"],\n", - "        batch_size=128,\n", - "        shuffle=True,\n", - "    )" + "experiment_dataset = ExperimentDataset.create(\n", + "    query,\n", + "    layer_name=\"raw\",\n", + "    obs_column_names=[\"cell_type\"],\n", + "    batch_size=batch_size,\n", + "    shuffle=True,\n", + "    seed=seed,\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### `ExperimentAxisQueryIterDataPipe` class explained\n", + "### `ExperimentDataset` class explained\n", "\n", - "This class provides an implementation of PyTorch's [`torchdata` IterDataPipe interface][IterDataPipe], which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentAxisQueryIterDataPipe` class encapsulates the details of querying a SOMA `Experiment` and returning a series of \"batches,\" each consisting of a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves data lazily, avoiding loading the entire training dataset into memory at once.\n", + "This class provides an implementation of PyTorch's [`torchdata` IterDataPipe interface][IterDataPipe], which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentDataset` class encapsulates the details of querying a SOMA `Experiment` and returning a series of \"batches,\" each consisting of a NumPy `ndarray` and a Pandas `DataFrame`. 
Most importantly, it retrieves data lazily, avoiding loading the entire training dataset into memory at once.\n", + "This class provides an implementation of PyTorch's [`torchdata` IterDataPipe interface][IterDataPipe], which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentDataset` class encapsulates the details of querying a SOMA `Experiment` and returning a series of \"batches,\" each consisting of a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves data lazily, avoiding loading the entire training dataset into memory at once.\n", "\n", - "### `ExperimentAxisQueryIterDataPipe` parameters explained\n", + "### `ExperimentDataset` parameters explained\n", "\n", "The constructor only requires a single parameter, `query`, which is an [`ExperimentAxisQuery`] containing the data to be used for training. This is obtained by querying an [`Experiment`], along the `obs` and/or `var` axes (see above, or [the TileDB-SOMA docs][tdbs docs], for examples).\n", "\n", @@ -119,7 +162,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -128,7 +171,7 @@ "(118, 60530)" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -148,11 +191,70 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "((94, 60530), (24, 60530))" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset, test_dataset = experiment_dataset.split(\n", + " train_split, 1 - train_split,\n", + " seed=1,\n", + ")\n", + "train_dataset.shape, test_dataset.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((12016,), (3004,))" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t0 = train_dataset.query_ids.obs_joinids\n", + "t1 = test_dataset.query_ids.obs_joinids\n", + "t0.shape, t1.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Quick check that train and test sets contain distinct `obs_joinids`:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "train_dataset, test_dataset = experiment_dataset.random_split(weights={\"train\": 0.8, \"test\": 0.2}, seed=1)" + "assert not set(train_dataset.obs_joinids) & set(test_dataset.obs_joinids)" ] }, { @@ -166,18 +268,20 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "train_dataloader = soma_ml.experiment_dataloader(train_dataset)" + "from tiledbsoma_ml import experiment_dataloader\n", + "\n", + "train_dataloader = experiment_dataloader(train_dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Instantiating a `DataLoader` object directly is not recommended, as several of its parameters interfere with iterable-style DataPipes like `ExperimentAxisQueryIterDataPipe`. Using `experiment_dataloader` helps enforce correct usage." + "Instantiating a `DataLoader` object directly is not recommended, as several of its parameters interfere with iterable-style DataPipes like `ExperimentDataset`. Using `experiment_dataloader` helps enforce correct usage." 
] }, { @@ -191,13 +295,26 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": {}, + "execution_count": 8, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, "outputs": [], "source": [ "import torch\n", "\n", - "\n", + "# For demo purposes only, seed Torch's RNG, so the model weights (and training result) are deterministic.\n", + "# Along with ExperimentDataset.{create,split}, this allows running this notebook and getting the same exact result.\n", + "if seed is not None:\n", + "    torch.manual_seed(seed)\n", + "    if torch.cuda.is_available():\n", + "        torch.cuda.manual_seed(seed)\n", + "        torch.backends.cudnn.deterministic = True\n", + "    \n", "class LogisticRegression(torch.nn.Module):\n", "    def __init__(self, input_dim, output_dim):\n", "        super(LogisticRegression, self).__init__()\n", @@ -217,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -227,7 +344,7 @@ "    train_correct = 0\n", "    train_total = 0\n", "\n", - "    for X_batch, y_batch in train_dataloader:\n", + "    for X_batch, obs_batch in train_dataloader:\n", "        optimizer.zero_grad()\n", "\n", "        X_batch = torch.from_numpy(X_batch).float().to(device)\n", @@ -240,11 +357,11 @@ "        predictions = torch.argmax(probabilities, axis=1)\n", "\n", "        # Compute the loss and perform back propagation\n", - "        y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n", - "        train_correct += (predictions == y_batch).sum().item()\n", + "        obs_batch = torch.from_numpy(cell_type_encoder.transform(obs_batch['cell_type'])).to(device)\n", + "        train_correct += (predictions == obs_batch).sum().item()\n", "        train_total += len(predictions)\n", "\n", - "        loss = loss_fn(outputs, y_batch.long())\n", + "        loss = loss_fn(outputs, obs_batch.long())\n", "        train_loss += loss.item()\n", "        loss.backward()\n", "        optimizer.step()\n", @@ -258,7 +375,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note the line, `X_batch, y_batch = batch`. Since the `train_dataloader` was configured with `batch_size=16`, these variables will hold tensors of rank 2. The `X_batch` tensor will appear, for example, as:\n", + "Note the loop header, `for X_batch, obs_batch in train_dataloader`. Since the dataset was created with `batch_size=128`, these variables will hold batches of 128 rows; `X_batch` is of rank 2 and will appear, for example, as:\n", "\n", "```\n", "tensor([[0., 0., 0., ..., 1., 0., 0.],\n", @@ -277,7 +394,7 @@ "tensor([0., 0., 0., ..., 1., 0., 0.])\n", "```\n", "    \n", - "For `y_batch`, this will contain the user-specified `obs` `cell_type` training labels. By default, these are encoded using a LabelEncoder and it will be a matrix where each column represents the encoded values of each column specified in `obs_column_names` when creating the datapipe (in this case, only the cell type). It will look like this:\n", + "For `obs_batch`, this will be a Pandas `DataFrame` holding the raw values of the `obs` columns requested via `obs_column_names` when creating the dataset (in this case, only the cell type). The train function encodes these labels with the fitted `LabelEncoder`, yielding a rank-1 tensor of integer labels. 
It will look like this:\n", "\n", "```\n", "tensor([1, 1, 3, ..., 2, 1, 4])\n", @@ -297,33 +414,39 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, + "execution_count": 10, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1: Train Loss: 0.0171090 Accuracy 0.1798\n", - "Epoch 2: Train Loss: 0.0151506 Accuracy 0.3480\n", - "Epoch 3: Train Loss: 0.0146299 Accuracy 0.4174\n", - "Epoch 4: Train Loss: 0.0142093 Accuracy 0.4765\n", - "Epoch 5: Train Loss: 0.0140261 Accuracy 0.5111\n", - "Epoch 6: Train Loss: 0.0138939 Accuracy 0.5634\n", - "Epoch 7: Train Loss: 0.0137783 Accuracy 0.6182\n", - "Epoch 8: Train Loss: 0.0136766 Accuracy 0.7050\n", - "Epoch 9: Train Loss: 0.0135647 Accuracy 0.8293\n", - "Epoch 10: Train Loss: 0.0134729 Accuracy 0.8793\n", - "Epoch 11: Train Loss: 0.0133968 Accuracy 0.8938\n", - "Epoch 12: Train Loss: 0.0133453 Accuracy 0.9013\n", - "Epoch 13: Train Loss: 0.0133143 Accuracy 0.9047\n", - "Epoch 14: Train Loss: 0.0132873 Accuracy 0.9102\n", - "Epoch 15: Train Loss: 0.0132666 Accuracy 0.9176\n", - "Epoch 16: Train Loss: 0.0132246 Accuracy 0.9219\n", - "Epoch 17: Train Loss: 0.0132161 Accuracy 0.9230\n", - "Epoch 18: Train Loss: 0.0131877 Accuracy 0.9295\n", - "Epoch 19: Train Loss: 0.0131658 Accuracy 0.9344\n", - "Epoch 20: Train Loss: 0.0131338 Accuracy 0.9382\n" + "Epoch 1: Train Loss: 0.0176823 Accuracy 0.1494\n", + "Epoch 2: Train Loss: 0.0151293 Accuracy 0.2636\n", + "Epoch 3: Train Loss: 0.0147051 Accuracy 0.3770\n", + "Epoch 4: Train Loss: 0.0143555 Accuracy 0.4779\n", + "Epoch 5: Train Loss: 0.0140985 Accuracy 0.5173\n", + "Epoch 6: Train Loss: 0.0139185 Accuracy 0.5474\n", + "Epoch 7: Train Loss: 0.0137876 Accuracy 0.5905\n", + "Epoch 8: Train Loss: 0.0136877 Accuracy 0.6322\n", + "Epoch 9: Train Loss: 0.0136219 Accuracy 0.6462\n", + "Epoch 10: Train Loss: 0.0135693 Accuracy 0.6522\n", + "Epoch 11: Train Loss: 0.0135283 Accuracy 0.6532\n", + "Epoch 12: Train Loss: 0.0134948 Accuracy 0.6547\n", + "Epoch 13: Train Loss: 0.0134677 Accuracy 0.6563\n", + "Epoch 14: Train Loss: 0.0134442 Accuracy 0.6570\n", + "Epoch 15: Train Loss: 0.0134219 Accuracy 0.6614\n", + "Epoch 16: Train Loss: 0.0134028 Accuracy 0.6660\n", + "Epoch 17: Train Loss: 0.0133850 Accuracy 0.6734\n", + "Epoch 18: Train Loss: 0.0133693 Accuracy 0.6922\n", + "Epoch 19: Train Loss: 0.0133531 Accuracy 0.7233\n", + "Epoch 20: Train Loss: 0.0133380 Accuracy 0.7388\n" ] } ], @@ -338,9 +461,9 @@ "\n", "model = LogisticRegression(input_dim, output_dim).to(device)\n", "loss_fn = torch.nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "\n", - "for epoch in range(20):\n", + "for epoch in range(n_epochs):\n", " train_loss, train_accuracy = train_epoch(model, train_dataloader, loss_fn, optimizer, device)\n", " print(f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\")" ] @@ -356,14 +479,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - "test_dataloader = soma_ml.experiment_dataloader(test_dataset)\n", - "X_batch, y_batch = next(iter(test_dataloader))\n", + "test_dataloader = experiment_dataloader(test_dataset)\n", + "X_batch, obs_batch = next(iter(test_dataloader))\n", "X_batch = torch.from_numpy(X_batch)\n", - "y_batch = 
torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))" + "true_cell_types = torch.from_numpy(cell_type_encoder.transform(obs_batch['cell_type']))" ] }, { @@ -375,24 +498,25 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([ 8, 1, 1, 1, 1, 1, 1, 8, 8, 5, 1, 7, 8, 1, 1, 1, 1, 7,\n", - " 7, 8, 1, 1, 5, 5, 1, 8, 1, 1, 1, 7, 8, 7, 7, 7, 8, 7,\n", - " 5, 1, 1, 8, 1, 5, 8, 5, 1, 11, 1, 7, 1, 1, 5, 5, 1, 11,\n", - " 1, 6, 8, 5, 1, 8, 11, 8, 1, 8, 1, 8, 1, 5, 1, 1, 1, 8,\n", - " 8, 7, 5, 1, 1, 8, 1, 7, 2, 1, 7, 1, 5, 1, 1, 7, 1, 8,\n", - " 1, 1, 1, 7, 7, 1, 1, 1, 7, 1, 1, 7, 7, 5, 7, 8, 5, 1,\n", - " 5, 1, 5, 5, 5, 1, 1, 1, 8, 5, 1, 1, 7, 8, 1, 1, 1, 1,\n", - " 8, 1], device='cuda:0')" + "tensor([ 1, 1, 1, 1, 7, 1, 11, 1, 6, 7, 1, 1, 1, 8, 1, 1, 1, 11,\n", + " 1, 1, 8, 1, 1, 1, 7, 5, 1, 1, 1, 1, 8, 1, 8, 8, 1, 1,\n", + " 1, 8, 1, 1, 1, 1, 1, 11, 1, 1, 7, 1, 1, 1, 7, 5, 8, 5,\n", + " 1, 1, 1, 1, 1, 9, 1, 1, 1, 1, 8, 5, 1, 1, 9, 7, 1, 1,\n", + " 7, 8, 1, 1, 1, 1, 1, 7, 11, 1, 9, 1, 8, 8, 1, 7, 1, 5,\n", + " 1, 7, 7, 1, 1, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 11, 1, 1, 1, 1, 1, 11, 1, 8, 1, 1, 8, 1, 1, 7, 1,\n", + " 5, 1], device='cuda:0')" ] }, + "execution_count": 12, "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], "source": [ @@ -403,8 +527,7 @@ "\n", "probabilities = torch.nn.functional.softmax(outputs, 1)\n", "predictions = torch.argmax(probabilities, axis=1)\n", - "\n", - "display(predictions)" + "predictions" ] }, { @@ -413,60 +536,60 @@ "source": [ "The predictions are returned as the encoded values of `cell_type` label. To recover the original cell type labels as strings, we decode using the same `LabelEncoder` used for training.\n", "\n", - "At inference time, if the model inputs are not obtained via an `ExperimentAxisQueryIterDataPipe`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below." + "At inference time, if the model inputs are not obtained via an `ExperimentDataset`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below." 
] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array(['leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", - " 'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n", - " 'epithelial cell', 'basal cell', 'keratinocyte', 'leukocyte',\n", - " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'keratinocyte', 'leukocyte', 'basal cell',\n", - " 'basal cell', 'epithelial cell', 'epithelial cell', 'basal cell',\n", - " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'leukocyte', 'keratinocyte', 'keratinocyte',\n", - " 'keratinocyte', 'leukocyte', 'keratinocyte', 'epithelial cell',\n", + "array(['basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'basal cell', 'vein endothelial cell',\n", + " 'basal cell', 'fibroblast', 'keratinocyte', 'basal cell',\n", " 'basal cell', 'basal cell', 'leukocyte', 'basal cell',\n", - " 'epithelial cell', 'leukocyte', 'epithelial cell', 'basal cell',\n", - " 'vein endothelial cell', 'basal cell', 'keratinocyte',\n", - " 'basal cell', 'basal cell', 'epithelial cell', 'epithelial cell',\n", - " 'basal cell', 'vein endothelial cell', 'basal cell', 'fibroblast',\n", - " 'leukocyte', 'epithelial cell', 'basal cell', 'leukocyte',\n", - " 'vein endothelial cell', 'leukocyte', 'basal cell', 'leukocyte',\n", - " 'basal cell', 'leukocyte', 'basal cell', 'epithelial cell',\n", - " 'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n", - " 'keratinocyte', 'epithelial cell', 'basal cell', 'basal cell',\n", - " 'leukocyte', 'basal cell', 'keratinocyte',\n", - " 'capillary endothelial cell', 'basal cell', 'keratinocyte',\n", - " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'basal cell', 'leukocyte', 'basal cell',\n", - " 'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n", - " 'basal cell', 'basal cell', 'basal cell', 'keratinocyte',\n", - " 'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n", - " 'epithelial cell', 'keratinocyte', 'leukocyte', 'epithelial cell',\n", - " 'basal cell', 'epithelial cell', 'basal cell', 'epithelial cell',\n", - " 'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n", - " 'basal cell', 'leukocyte', 'epithelial cell', 'basal cell',\n", - " 'basal cell', 'keratinocyte', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'vein endothelial cell', 'basal cell',\n", + " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", + " 'basal cell', 'keratinocyte', 'epithelial cell', 'basal cell',\n", " 'basal cell', 'basal cell', 'basal cell', 'leukocyte',\n", - " 'basal cell'], dtype=object)" + " 'basal cell', 'leukocyte', 'leukocyte', 'basal cell', 'basal cell',\n", + " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'vein endothelial cell',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'epithelial cell',\n", + " 'leukocyte', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'pericyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte',\n", + " 'epithelial cell', 'basal cell', 'basal cell', 'pericyte',\n", + " 'keratinocyte', 'basal cell', 'basal cell', 'keratinocyte',\n", + " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal 
cell', 'keratinocyte',\n", + " 'vein endothelial cell', 'basal cell', 'pericyte', 'basal cell',\n", + " 'leukocyte', 'leukocyte', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'keratinocyte',\n", + " 'keratinocyte', 'basal cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'vein endothelial cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'vein endothelial cell', 'basal cell', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'basal cell', 'epithelial cell', 'basal cell'],\n", + " dtype=object)" ] }, + "execution_count": 13, "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], "source": [ "predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())\n", - "\n", - "display(predicted_cell_types)" + "predicted_cell_types" ] }, { @@ -478,7 +601,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -502,15 +625,15 @@ " \n", " \n", " \n", - " actual cell type\n", + " true cell type\n", " predicted cell type\n", " \n", " \n", " \n", " \n", " 0\n", - " leukocyte\n", - " leukocyte\n", + " basal cell\n", + " basal cell\n", " \n", " \n", " 1\n", @@ -529,8 +652,8 @@ " \n", " \n", " 4\n", - " basal cell\n", - " basal cell\n", + " keratinocyte\n", + " keratinocyte\n", " \n", " \n", " ...\n", @@ -539,27 +662,27 @@ " \n", " \n", " 123\n", - " fibroblast\n", + " keratinocyte\n", " basal cell\n", " \n", " \n", " 124\n", - " basal cell\n", - " basal cell\n", + " keratinocyte\n", + " keratinocyte\n", " \n", " \n", " 125\n", - " keratinocyte\n", + " epithelial cell\n", " basal cell\n", " \n", " \n", " 126\n", - " leukocyte\n", - " leukocyte\n", + " epithelial cell\n", + " epithelial cell\n", " \n", " \n", " 127\n", - " basal cell\n", + " keratinocyte\n", " basal cell\n", " \n", " \n", @@ -568,43 +691,207 @@ "" ], "text/plain": [ - " actual cell type predicted cell type\n", - "0 leukocyte leukocyte\n", + " true cell type predicted cell type\n", + "0 basal cell basal cell\n", "1 basal cell basal cell\n", "2 basal cell basal cell\n", "3 basal cell basal cell\n", - "4 basal cell basal cell\n", + "4 keratinocyte keratinocyte\n", ".. ... 
...\n", - "123 fibroblast basal cell\n", - "124 basal cell basal cell\n", - "125 keratinocyte basal cell\n", - "126 leukocyte leukocyte\n", - "127 basal cell basal cell\n", + "123 keratinocyte basal cell\n", + "124 keratinocyte keratinocyte\n", + "125 epithelial cell basal cell\n", + "126 epithelial cell epithelial cell\n", + "127 keratinocyte basal cell\n", "\n", "[128 rows x 2 columns]" ] }, + "execution_count": 14, "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", - "display(\n", - " pd.DataFrame(\n", - " {\n", - " \"actual cell type\": cell_type_encoder.inverse_transform(y_batch.ravel().numpy()),\n", - " \"predicted cell type\": predicted_cell_types,\n", - " }\n", - " )\n", - ")" + "batch_cmp_df = pd.DataFrame({\n", + " \"true cell type\": cell_type_encoder.inverse_transform(true_cell_types.ravel().numpy()),\n", + " \"predicted cell type\": predicted_cell_types,\n", + "})\n", + "batch_cmp_df" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
predicted cell typebasal cellepithelial cellfibroblastkeratinocyteleukocytepericytevein endothelial cell
true cell type
basal cell59
capillary endothelial cell1
epithelial cell116
fibroblast1
keratinocyte1513
leukocyte113
pericyte3
vein endothelial cell5
\n", + "
" + ], + "text/plain": [ + "predicted cell type basal cell epithelial cell fibroblast keratinocyte \\\n", + "true cell type \n", + "basal cell 59 \n", + "capillary endothelial cell \n", + "epithelial cell 11 6 \n", + "fibroblast 1 \n", + "keratinocyte 15 13 \n", + "leukocyte 1 \n", + "pericyte \n", + "vein endothelial cell \n", + "\n", + "predicted cell type leukocyte pericyte vein endothelial cell \n", + "true cell type \n", + "basal cell \n", + "capillary endothelial cell 1 \n", + "epithelial cell \n", + "fibroblast \n", + "keratinocyte \n", + "leukocyte 13 \n", + "pericyte 3 \n", + "vein endothelial cell 5 " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.crosstab(\n", + " batch_cmp_df[\"true cell type\"],\n", + " batch_cmp_df[\"predicted cell type\"],\n", + ").replace(0, '')" ] } ], "metadata": { "kernelspec": { - "display_name": "tiledbsoma-dev", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -618,9 +905,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.7" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 5 } diff --git a/pyproject.toml b/pyproject.toml index 246af44..183195b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,3 +70,4 @@ lint.extend-select = ["I001"] # unsorted-imports fix = true target-version = "py311" line-length = 120 +exclude = ["*.ipynb"] # Changes cell IDs unnecessarily