diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb
index ac8a55b..01551be 100644
--- a/notebooks/tutorial_lightning.ipynb
+++ b/notebooks/tutorial_lightning.ipynb
@@ -28,43 +28,74 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Initialize SOMA Experiment query as training data"
+ "[Papermill] parameters:\n",
+ "\n",
+ "[Papermill]: https://papermill.readthedocs.io/"
]
},
{
"cell_type": "code",
"execution_count": 1,
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "parameters"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "tissue = \"tongue\"\n",
+ "n_epochs = 20\n",
+ "census_version = \"2024-07-01\"\n",
+ "batch_size = 128\n",
+ "learning_rate = 1e-5\n",
+ "progress_bar = not bool(os.environ['PAPERMILL']) # Defaults to True, unless env var $PAPERMILL is set"
+ ]
+ },
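+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For example, a Papermill run can override any of the parameters above. A minimal sketch (the output path and parameter values are illustrative; Papermill does not set `$PAPERMILL` itself, the cell above simply treats it as a flag for the runner to export):\n",
+ "\n",
+ "```python\n",
+ "import os\n",
+ "import papermill as pm\n",
+ "\n",
+ "os.environ[\"PAPERMILL\"] = \"1\"  # Inherited by the notebook's kernel subprocess; disables the progress bar\n",
+ "pm.execute_notebook(\n",
+ "    \"tutorial_lightning.ipynb\",\n",
+ "    \"/tmp/tutorial_lightning.out.ipynb\",\n",
+ "    parameters=dict(tissue=\"liver\", n_epochs=5),\n",
+ ")\n",
+ "```"
+ ]
+ },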
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Initialize SOMA Experiment query as training data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
- "import pytorch_lightning as pl\n",
- "import tiledbsoma as soma\n",
- "import torch\n",
+ "from tiledbsoma import AxisQuery, Experiment, SOMATileDBContext\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
- "import tiledbsoma_ml as soma_ml\n",
+ "from tiledbsoma_ml import ExperimentDataset\n",
"\n",
- "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n",
+ "CZI_Census_Homo_Sapiens_URL = f\"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/\"\n",
"\n",
- "experiment = soma.open(\n",
+ "experiment = Experiment.open(\n",
" CZI_Census_Homo_Sapiens_URL,\n",
- " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n",
+ " context=SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n",
")\n",
- "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n",
+ "obs_value_filter = f\"tissue_general == '{tissue}' and is_primary_data == True\"\n",
"\n",
"with experiment.axis_query(\n",
- " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n",
+ " measurement_name=\"RNA\", obs_query=AxisQuery(value_filter=obs_value_filter)\n",
") as query:\n",
" obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n",
" cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n",
"\n",
- " experiment_dataset = soma_ml.ExperimentDataset(\n",
- " query,\n",
- " layer_name=\"raw\",\n",
- " obs_column_names=[\"cell_type\"],\n",
- " batch_size=128,\n",
- " shuffle=True,\n",
- " )"
+ "experiment_dataset = ExperimentDataset.create(\n",
+ " query,\n",
+ " layer_name=\"raw\",\n",
+ " obs_column_names=[\"cell_type\"],\n",
+ " batch_size=batch_size,\n",
+ " shuffle=True,\n",
+ ")"
]
},
{
@@ -76,12 +107,15 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
+ "import torch\n",
+ "import pytorch_lightning as pl\n",
+ "\n",
"class LogisticRegressionLightning(pl.LightningModule):\n",
- " def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=1e-5):\n",
+ " def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=learning_rate):\n",
" super(LogisticRegressionLightning, self).__init__()\n",
" self.linear = torch.nn.Linear(input_dim, output_dim)\n",
" self.cell_type_encoder = cell_type_encoder\n",
@@ -134,7 +168,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -142,9 +176,42 @@
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
- "TPU available: False, using: 0 TPU cores\n",
- "HPU available: False, using: 0 HPUs\n",
- "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
+ "TPU available: False, using: 0 TPU cores\n"
+ ]
+ }
+ ],
+ "source": [
+ "from tiledbsoma_ml import experiment_dataloader\n",
+ "\n",
+ "dataloader = experiment_dataloader(experiment_dataset)\n",
+ "\n",
+ "# The size of the input dimension is the number of genes\n",
+ "input_dim = experiment_dataset.shape[1]\n",
+ "\n",
+ "# The size of the output dimension is the number of distinct cell_type values\n",
+ "output_dim = len(cell_type_encoder.classes_)\n",
+ "\n",
+ "# Initialize the PyTorch Lightning model\n",
+ "model = LogisticRegressionLightning(\n",
+ " input_dim, output_dim, cell_type_encoder=cell_type_encoder\n",
+ ")\n",
+ "\n",
+ "# Define the PyTorch Lightning Trainer\n",
+ "trainer = pl.Trainer(max_epochs=n_epochs, enable_progress_bar=progress_bar)\n",
+ "\n",
+ "# set precision\n",
+ "torch.set_float32_matmul_precision(\"high\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params | Mode \n",
@@ -157,53 +224,20 @@
"726 K Total params\n",
"2.905 Total estimated model params size (MB)\n",
"2 Modules in train mode\n",
- "0 Modules in eval mode\n",
- "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
- "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.31it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
+ "0 Modules in eval mode\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.28it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]\n"
+ "CPU times: user 3min 30s, sys: 1min 25s, total: 4min 55s\n",
+ "Wall time: 2min 14s\n"
]
}
],
"source": [
- "dataloader = soma_ml.experiment_dataloader(experiment_dataset)\n",
- "\n",
- "# The size of the input dimension is the number of genes\n",
- "input_dim = experiment_dataset.shape[1]\n",
- "\n",
- "# The size of the output dimension is the number of distinct cell_type values\n",
- "output_dim = len(cell_type_encoder.classes_)\n",
- "\n",
- "# Initialize the PyTorch Lightning model\n",
- "model = LogisticRegressionLightning(\n",
- " input_dim, output_dim, cell_type_encoder=cell_type_encoder\n",
- ")\n",
- "\n",
- "# Define the PyTorch Lightning Trainer\n",
- "trainer = pl.Trainer(max_epochs=20)\n",
- "\n",
- "# set precision\n",
- "torch.set_float32_matmul_precision(\"high\")\n",
- "\n",
+ "%%time\n",
"# Train the model\n",
"trainer.fit(model, train_dataloaders=dataloader)"
]
@@ -211,7 +245,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "toymodel",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -225,9 +259,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.9"
+ "version": "3.12.7"
}
},
"nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 5
}
diff --git a/notebooks/tutorial_multiworker.ipynb b/notebooks/tutorial_multiworker.ipynb
index 7493407..8010c12 100644
--- a/notebooks/tutorial_multiworker.ipynb
+++ b/notebooks/tutorial_multiworker.ipynb
@@ -2,15 +2,21 @@
"cells": [
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
"source": [
"# Multi-process training\n",
"\n",
- "Multi-process usage of `tiledbsoma_ml.ExperimentAxisQueryIterDataset` includes both:\n",
+ "Multi-process usage of `tiledbsoma_ml.ExperimentDataset` includes both:\n",
"* using the [`torch.utils.data.DataLoader`] with 1 or more workers (i.e., with an argument of `n_workers=1` or greater)\n",
"* using a multi-process training configuration, such as [`DistributedDataParallel`]\n",
"\n",
- "In these configurations, `ExperimentAxisQueryIterDataset` will automatically partition data across workers. However, when using `shuffle=True`, there are several things to keep in mind:\n",
+ "In these configurations, `ExperimentDataset` will automatically partition data across workers. However, when using `shuffle=True`, there are several things to keep in mind:\n",
"\n",
"1. All worker processes must share the same random number generator `seed`, ensuring that all workers shuffle and partition the data in the same way.\n",
"2. To ensure that each epoch returns a _different_ shuffle, the caller must set the epoch, using the `set_epoch` API. This is identical to the behavior of [`torch.utils.data.distributed.DistributedSampler`].\n",
@@ -22,50 +28,83 @@
]
},
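+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For example, the per-epoch pattern described above looks like the following under `DistributedDataParallel` (a sketch only; it assumes a process group has already been initialized, e.g. via `torchrun`, and that every rank constructed its dataset with the same `seed`):\n",
+ "\n",
+ "```python\n",
+ "# Dataset, dataloader, and model setup as in the cells below\n",
+ "for epoch in range(n_epochs):\n",
+ "    # Re-shuffle consistently across all ranks and workers for this epoch;\n",
+ "    # each rank then iterates over its own partition of the global shuffle.\n",
+ "    experiment_dataset.set_epoch(epoch)\n",
+ "    for X_batch, obs_batch in dataloader:\n",
+ "        ...  # forward / backward / optimizer step\n",
+ "```"
+ ]
+ },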
{
+ "cell_type": "markdown",
"metadata": {},
+ "source": [
+ "[Papermill] parameters:\n",
+ "\n",
+ "[Papermill]: https://papermill.readthedocs.io/"
+ ]
+ },
+ {
"cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "parameters"
+ ]
+ },
"outputs": [],
- "execution_count": null,
"source": [
- "import tiledbsoma as soma\n",
- "import torch\n",
+ "import os\n",
+ "\n",
+ "tissue = \"tongue\"\n",
+ "n_epochs = 20\n",
+ "census_version = \"2024-07-01\"\n",
+ "batch_size = 128\n",
+ "learning_rate = 1e-5\n",
+ "num_workers = 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tiledbsoma import AxisQuery, Experiment, SOMATileDBContext\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
- "import tiledbsoma_ml as soma_ml\n",
+ "from tiledbsoma_ml import ExperimentDataset\n",
"\n",
- "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n",
+ "CZI_Census_Homo_Sapiens_URL = f\"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/\"\n",
"\n",
- "experiment = soma.open(\n",
+ "experiment = Experiment.open(\n",
" CZI_Census_Homo_Sapiens_URL,\n",
- " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n",
+ " context=SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n",
")\n",
- "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n",
+ "obs_value_filter = f\"tissue_general == '{tissue}' and is_primary_data == True\"\n",
"\n",
"with experiment.axis_query(\n",
- " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n",
+ " measurement_name=\"RNA\", obs_query=AxisQuery(value_filter=obs_value_filter)\n",
") as query:\n",
" obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n",
" cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n",
"\n",
- " experiment_dataset = soma_ml.ExperimentDataset(\n",
- " query,\n",
- " layer_name=\"raw\",\n",
- " obs_column_names=[\"cell_type\"],\n",
- " batch_size=128,\n",
- " shuffle=True,\n",
- " )\n",
- " "
+ "experiment_dataset = ExperimentDataset.create(\n",
+ " query,\n",
+ " layer_name=\"raw\",\n",
+ " obs_column_names=[\"cell_type\"],\n",
+ " batch_size=batch_size,\n",
+ " shuffle=True,\n",
+ ")"
]
},
{
- "metadata": {},
"cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
"outputs": [],
- "execution_count": null,
"source": [
+ "import torch\n",
+ "\n",
"class LogisticRegression(torch.nn.Module):\n",
" def __init__(self, input_dim, output_dim):\n",
- " super(LogisticRegression, self).__init__() # noqa: UP008\n",
+ " super(LogisticRegression, self).__init__()\n",
" self.linear = torch.nn.Linear(input_dim, output_dim)\n",
"\n",
" def forward(self, x):\n",
@@ -79,7 +118,7 @@
" train_correct = 0\n",
" train_total = 0\n",
"\n",
- " for X_batch, y_batch in train_dataloader:\n",
+ " for X_batch, obs_batch in train_dataloader:\n",
" optimizer.zero_grad()\n",
"\n",
" X_batch = torch.from_numpy(X_batch).float().to(device)\n",
@@ -92,11 +131,11 @@
" predictions = torch.argmax(probabilities, axis=1)\n",
"\n",
" # Compute the loss and perform back propagation\n",
- " y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n",
- " train_correct += (predictions == y_batch).sum().item()\n",
+ " obs_batch = torch.from_numpy(cell_type_encoder.transform(obs_batch['cell_type'])).to(device)\n",
+ " train_correct += (predictions == obs_batch).sum().item()\n",
" train_total += len(predictions)\n",
"\n",
- " loss = loss_fn(outputs, y_batch.long())\n",
+ " loss = loss_fn(outputs, obs_batch.long())\n",
" train_loss += loss.item()\n",
" loss.backward()\n",
" optimizer.step()\n",
@@ -116,42 +155,25 @@
"\n",
"The same approach should be taken for parallel training, e.g., when using DDP or DP.\n",
"\n",
- "*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, the `ExperimentAxisQueryIterDataset` will automatically increment the epoch count each time the iterator completes."
+ "*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, the `ExperimentDataset` will automatically increment the epoch count each time the iterator completes."
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
- "name": "stdout",
+ "name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 1: Train Loss: 0.0169229 Accuracy 0.3124\n",
- "Epoch 2: Train Loss: 0.0148674 Accuracy 0.4272\n",
- "Epoch 3: Train Loss: 0.0144468 Accuracy 0.4509\n",
- "Epoch 4: Train Loss: 0.0141778 Accuracy 0.4999\n",
- "Epoch 5: Train Loss: 0.0139660 Accuracy 0.5619\n",
- "Epoch 6: Train Loss: 0.0137670 Accuracy 0.6971\n",
- "Epoch 7: Train Loss: 0.0136089 Accuracy 0.8670\n",
- "Epoch 8: Train Loss: 0.0135203 Accuracy 0.9099\n",
- "Epoch 9: Train Loss: 0.0134427 Accuracy 0.9262\n",
- "Epoch 10: Train Loss: 0.0133607 Accuracy 0.9300\n",
- "Epoch 11: Train Loss: 0.0133110 Accuracy 0.9348\n",
- "Epoch 12: Train Loss: 0.0132749 Accuracy 0.9378\n",
- "Epoch 13: Train Loss: 0.0132431 Accuracy 0.9413\n",
- "Epoch 14: Train Loss: 0.0132194 Accuracy 0.9444\n",
- "Epoch 15: Train Loss: 0.0131942 Accuracy 0.9465\n",
- "Epoch 16: Train Loss: 0.0131739 Accuracy 0.9499\n",
- "Epoch 17: Train Loss: 0.0131527 Accuracy 0.9526\n",
- "Epoch 18: Train Loss: 0.0131369 Accuracy 0.9551\n",
- "Epoch 19: Train Loss: 0.0131214 Accuracy 0.9563\n",
- "Epoch 20: Train Loss: 0.0131061 Accuracy 0.9578\n"
+ "switching torch multiprocessing start method from \"fork\" to \"spawn\"\n"
]
}
],
"source": [
+ "from tiledbsoma_ml import experiment_dataloader\n",
+ "\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"# The size of the input dimension is the number of genes\n",
@@ -162,19 +184,55 @@
"\n",
"model = LogisticRegression(input_dim, output_dim).to(device)\n",
"loss_fn = torch.nn.CrossEntropyLoss()\n",
- "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n",
- "\n",
+ "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"\n",
- "# define a two-worker data loader. The dataset is shuffled, so call `set_epoch` to ensure\n",
+ "# Define a two-worker data loader. The dataset is shuffled, so call `set_epoch` to ensure\n",
"# that a different shuffle is applied on each epoch.\n",
- "experiment_dataloader = soma_ml.experiment_dataloader(\n",
- " experiment_dataset, num_workers=2, persistent_workers=True\n",
- ")\n",
- "\n",
- "for epoch in range(20):\n",
+ "dataloader = experiment_dataloader(\n",
+ " experiment_dataset, num_workers=num_workers, persistent_workers=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1: Train Loss: 0.0165012 Accuracy 0.3866\n",
+ "Epoch 2: Train Loss: 0.0148111 Accuracy 0.4217\n",
+ "Epoch 3: Train Loss: 0.0144168 Accuracy 0.6109\n",
+ "Epoch 4: Train Loss: 0.0141248 Accuracy 0.8374\n",
+ "Epoch 5: Train Loss: 0.0138151 Accuracy 0.9001\n",
+ "Epoch 6: Train Loss: 0.0136300 Accuracy 0.9123\n",
+ "Epoch 7: Train Loss: 0.0135218 Accuracy 0.9234\n",
+ "Epoch 8: Train Loss: 0.0134472 Accuracy 0.9324\n",
+ "Epoch 9: Train Loss: 0.0133907 Accuracy 0.9375\n",
+ "Epoch 10: Train Loss: 0.0133443 Accuracy 0.9419\n",
+ "Epoch 11: Train Loss: 0.0132998 Accuracy 0.9456\n",
+ "Epoch 12: Train Loss: 0.0132594 Accuracy 0.9489\n",
+ "Epoch 13: Train Loss: 0.0132298 Accuracy 0.9524\n",
+ "Epoch 14: Train Loss: 0.0132037 Accuracy 0.9549\n",
+ "Epoch 15: Train Loss: 0.0131809 Accuracy 0.9568\n",
+ "Epoch 16: Train Loss: 0.0131603 Accuracy 0.9585\n",
+ "Epoch 17: Train Loss: 0.0131425 Accuracy 0.9601\n",
+ "Epoch 18: Train Loss: 0.0131270 Accuracy 0.9613\n",
+ "Epoch 19: Train Loss: 0.0131112 Accuracy 0.9630\n",
+ "Epoch 20: Train Loss: 0.0130966 Accuracy 0.9639\n",
+ "CPU times: user 1min 6s, sys: 1min 58s, total: 3min 4s\n",
+ "Wall time: 4min 48s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "for epoch in range(n_epochs):\n",
" experiment_dataset.set_epoch(epoch)\n",
" train_loss, train_accuracy = train_epoch(\n",
- " model, experiment_dataloader, loss_fn, optimizer, device\n",
+ " model, dataloader, loss_fn, optimizer, device\n",
" )\n",
" print(\n",
" f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\"\n",
@@ -184,7 +242,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "toymodel",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -198,9 +256,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.9"
+ "version": "3.12.7"
}
},
"nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 5
}
diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb
index f19fafa..d84e8de 100644
--- a/notebooks/tutorial_pytorch.ipynb
+++ b/notebooks/tutorial_pytorch.ipynb
@@ -6,7 +6,7 @@
"source": [
"# Training a PyTorch Model\n",
"\n",
- "This tutorial shows how to train a Logistic Regression model in PyTorch using the `tiledbsoma.ml.ExperimentAxisQueryIterDataPipe` class, and the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. This is intended only to demonstrate the use of the `ExperimentAxisQueryIterDataPipe`, and not as an example of how to train a biologically useful model.\n",
+ "This tutorial trains a Logistic Regression model in PyTorch, using `tiledbsoma.ml.ExperimentDataset` and the [CZI CELLxGENE Census] dataset. This is intended only to demonstrate the use of `ExperimentDataset`, not as an example of how to train a biologically useful model.\n",
"\n",
"This tutorial assumes a basic familiarity with PyTorch and the Census API.\n",
"\n",
@@ -20,75 +20,118 @@
"\n",
"**Contents**\n",
"\n",
- "* [Create an ExperimentAxisQueryIterDataPipe](#data-pipe)\n",
+ "* [Create an ExperimentDataset](#data-pipe)\n",
"* [Split the dataset](#split)\n",
"* [Create the DataLoader](#data-loader)\n",
"* [Define the model](#model)\n",
"* [Train the model](#train)\n",
- "* [Make predictions with the model](#predict)"
+ "* [Make predictions with the model](#predict)\n",
+ "\n",
+ "[CZI CELLxGENE Census]: https://chanzuckerberg.github.io/cellxgene-census/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Create an ExperimentAxisQueryIterDataPipe \n",
+ "[Papermill] parameters:\n",
+ "\n",
+ "[Papermill]: https://papermill.readthedocs.io/"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "parameters"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "tissue = \"tongue\"\n",
+ "n_epochs = 20\n",
+ "census_version = \"2024-07-01\"\n",
+ "batch_size = 128\n",
+ "train_split = .8\n",
+ "seed = 111\n",
+ "learning_rate = 1e-5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create an ExperimentDataset \n",
"\n",
"To train a PyTorch model on a SOMA [Experiment]:\n",
"1. Open the Experiment.\n",
"2. Select the desired `obs` rows and `var` columns with an [ExperimentAxisQuery].\n",
- "3. Create an `ExperimentAxisQueryIterDataPipe`.\n",
+ "3. Create an `ExperimentDataset`.\n",
"\n",
"The example below utilizes a recent CZI Census release, accessed directly from S3. We also encode the `obs` `cell_type` labels, using a `scikit-learn` [LabelEncoder].\n",
"\n",
- "[Experiment]: https://tiledbsoma.readthedocs.io/en/stable/_autosummary/tiledbsoma.Experiment.html#tiledbsoma.Experiment\n",
- "[ExperimentAxisQuery]: https://tiledbsoma.readthedocs.io/en/stable/_autosummary/tiledbsoma.ExperimentAxisQuery.html\n",
+ "[Experiment]: https://tiledbsoma.readthedocs.io/en/stable/python-tiledbsoma-experiment.html\n",
+ "[ExperimentAxisQuery]: https://tiledbsoma.readthedocs.io/en/stable/python-tiledbsoma-experimentaxisquery.html\n",
"[LabelEncoder]: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html"
]
},
{
"cell_type": "code",
- "execution_count": 1,
- "metadata": {},
+ "execution_count": 2,
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
"outputs": [],
"source": [
- "import tiledbsoma as soma\n",
+ "from tiledbsoma import AxisQuery, Experiment, SOMATileDBContext\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
- "import tiledbsoma_ml as soma_ml\n",
+ "from tiledbsoma_ml import ExperimentDataset\n",
"\n",
- "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n",
+ "CZI_Census_Homo_Sapiens_URL = f\"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/\"\n",
"\n",
- "experiment = soma.open(\n",
+ "experiment = Experiment.open(\n",
" CZI_Census_Homo_Sapiens_URL,\n",
- " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n",
+ " context=SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\", \"vfs.s3.no_sign_request\": \"true\"}),\n",
")\n",
- "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n",
+ "obs_value_filter = f\"tissue_general == '{tissue}' and is_primary_data == True\"\n",
"\n",
"with experiment.axis_query(\n",
- " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n",
+ " measurement_name=\"RNA\", obs_query=AxisQuery(value_filter=obs_value_filter)\n",
") as query:\n",
" obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n",
" cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n",
"\n",
- " experiment_dataset = soma_ml.ExperimentAxisQueryIterDataPipe(\n",
- " query,\n",
- " layer_name=\"raw\",\n",
- " obs_column_names=[\"cell_type\"],\n",
- " batch_size=128,\n",
- " shuffle=True,\n",
- " )"
+ "experiment_dataset = ExperimentDataset.create(\n",
+ " query,\n",
+ " layer_name=\"raw\",\n",
+ " obs_column_names=[\"cell_type\"],\n",
+ " batch_size=batch_size,\n",
+ " shuffle=True,\n",
+ " seed=111,\n",
+ ")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### `ExperimentAxisQueryIterDataPipe` class explained\n",
+ "### `ExperimentDataset` class explained\n",
"\n",
- "This class provides an implementation of PyTorch's [`torchdata` IterDataPipe interface][IterDataPipe], which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentAxisQueryIterDataPipe` class encapsulates the details of querying a SOMA `Experiment` and returning a series of \"batches,\" each consisting of a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves data lazily, avoiding loading the entire training dataset into memory at once.\n",
+ "This class provides an implementation of PyTorch's [`torchdata` IterDataPipe interface][IterDataPipe], which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentDataset` class encapsulates the details of querying a SOMA `Experiment` and returning a series of \"batches,\" each consisting of a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves data lazily, avoiding loading the entire training dataset into memory at once.\n",
"\n",
- "### `ExperimentAxisQueryIterDataPipe` parameters explained\n",
+ "### `ExperimentDataset` parameters explained\n",
"\n",
"The constructor only requires a single parameter, `query`, which is an [`ExperimentAxisQuery`] containing the data to be used for training. This is obtained by querying an [`Experiment`], along the `obs` and/or `var` axes (see above, or [the TileDB-SOMA docs][tdbs docs], for examples).\n",
"\n",
@@ -119,7 +162,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -128,7 +171,7 @@
"(118, 60530)"
]
},
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -148,11 +191,70 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((94, 60530), (24, 60530))"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_dataset, test_dataset = experiment_dataset.split(\n",
+ " train_split, 1 - train_split,\n",
+ " seed=1,\n",
+ ")\n",
+ "train_dataset.shape, test_dataset.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((12016,), (3004,))"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t0 = train_dataset.query_ids.obs_joinids\n",
+ "t1 = test_dataset.query_ids.obs_joinids\n",
+ "t0.shape, t1.shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Quick check that train and test sets contain distinct `obs_joinids`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
- "train_dataset, test_dataset = experiment_dataset.random_split(weights={\"train\": 0.8, \"test\": 0.2}, seed=1)"
+ "assert not set(train_dataset.obs_joinids) & set(test_dataset.obs_joinids)"
]
},
{
@@ -166,18 +268,20 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
- "train_dataloader = soma_ml.experiment_dataloader(train_dataset)"
+ "from tiledbsoma_ml import experiment_dataloader\n",
+ "\n",
+ "train_dataloader = experiment_dataloader(train_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Instantiating a `DataLoader` object directly is not recommended, as several of its parameters interfere with iterable-style DataPipes like `ExperimentAxisQueryIterDataPipe`. Using `experiment_dataloader` helps enforce correct usage."
+ "Instantiating a `DataLoader` object directly is not recommended, as several of its parameters interfere with iterable-style DataPipes like `ExperimentDataset`. Using `experiment_dataloader` helps enforce correct usage."
]
},
{
@@ -191,13 +295,26 @@
},
{
"cell_type": "code",
- "execution_count": 5,
- "metadata": {},
+ "execution_count": 8,
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
"outputs": [],
"source": [
"import torch\n",
"\n",
- "\n",
+ "# For demo purposes only, seed Torch's RNG, so the model weights (and training result) is deterministic.\n",
+ "# Along with ExperimentDataset.{create,split}, this allows running this notebook and getting the same exact result.\n",
+ "if seed is not None:\n",
+ " torch.manual_seed(seed)\n",
+ " if torch.cuda.is_available():\n",
+ " torch.cuda.manual_seed(seed)\n",
+ " torch.backends.cudnn.deterministic = True\n",
+ " \n",
"class LogisticRegression(torch.nn.Module):\n",
" def __init__(self, input_dim, output_dim):\n",
" super(LogisticRegression, self).__init__()\n",
@@ -217,7 +334,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -227,7 +344,7 @@
" train_correct = 0\n",
" train_total = 0\n",
"\n",
- " for X_batch, y_batch in train_dataloader:\n",
+ " for X_batch, obs_batch in train_dataloader:\n",
" optimizer.zero_grad()\n",
"\n",
" X_batch = torch.from_numpy(X_batch).float().to(device)\n",
@@ -240,11 +357,11 @@
" predictions = torch.argmax(probabilities, axis=1)\n",
"\n",
" # Compute the loss and perform back propagation\n",
- " y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n",
- " train_correct += (predictions == y_batch).sum().item()\n",
+ " obs_batch = torch.from_numpy(cell_type_encoder.transform(obs_batch['cell_type'])).to(device)\n",
+ " train_correct += (predictions == obs_batch).sum().item()\n",
" train_total += len(predictions)\n",
"\n",
- " loss = loss_fn(outputs, y_batch.long())\n",
+ " loss = loss_fn(outputs, obs_batch.long())\n",
" train_loss += loss.item()\n",
" loss.backward()\n",
" optimizer.step()\n",
@@ -258,7 +375,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Note the line, `X_batch, y_batch = batch`. Since the `train_dataloader` was configured with `batch_size=16`, these variables will hold tensors of rank 2. The `X_batch` tensor will appear, for example, as:\n",
+ "Note the line, `X_batch, obs_batch = batch`. Since the `train_dataloader` was configured with `batch_size=16`, these variables will hold tensors of rank 2. The `X_batch` tensor will appear, for example, as:\n",
"\n",
"```\n",
"tensor([[0., 0., 0., ..., 1., 0., 0.],\n",
@@ -277,7 +394,7 @@
"tensor([0., 0., 0., ..., 1., 0., 0.])\n",
"```\n",
" \n",
- "For `y_batch`, this will contain the user-specified `obs` `cell_type` training labels. By default, these are encoded using a LabelEncoder and it will be a matrix where each column represents the encoded values of each column specified in `obs_column_names` when creating the datapipe (in this case, only the cell type). It will look like this:\n",
+ "For `obs_batch`, this will contain the user-specified `obs` `cell_type` training labels. By default, these are encoded using a LabelEncoder and it will be a matrix where each column represents the encoded values of each column specified in `obs_column_names` when creating the datapipe (in this case, only the cell type). It will look like this:\n",
"\n",
"```\n",
"tensor([1, 1, 3, ..., 2, 1, 4])\n",
@@ -297,33 +414,39 @@
},
{
"cell_type": "code",
- "execution_count": 7,
- "metadata": {},
+ "execution_count": 10,
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 1: Train Loss: 0.0171090 Accuracy 0.1798\n",
- "Epoch 2: Train Loss: 0.0151506 Accuracy 0.3480\n",
- "Epoch 3: Train Loss: 0.0146299 Accuracy 0.4174\n",
- "Epoch 4: Train Loss: 0.0142093 Accuracy 0.4765\n",
- "Epoch 5: Train Loss: 0.0140261 Accuracy 0.5111\n",
- "Epoch 6: Train Loss: 0.0138939 Accuracy 0.5634\n",
- "Epoch 7: Train Loss: 0.0137783 Accuracy 0.6182\n",
- "Epoch 8: Train Loss: 0.0136766 Accuracy 0.7050\n",
- "Epoch 9: Train Loss: 0.0135647 Accuracy 0.8293\n",
- "Epoch 10: Train Loss: 0.0134729 Accuracy 0.8793\n",
- "Epoch 11: Train Loss: 0.0133968 Accuracy 0.8938\n",
- "Epoch 12: Train Loss: 0.0133453 Accuracy 0.9013\n",
- "Epoch 13: Train Loss: 0.0133143 Accuracy 0.9047\n",
- "Epoch 14: Train Loss: 0.0132873 Accuracy 0.9102\n",
- "Epoch 15: Train Loss: 0.0132666 Accuracy 0.9176\n",
- "Epoch 16: Train Loss: 0.0132246 Accuracy 0.9219\n",
- "Epoch 17: Train Loss: 0.0132161 Accuracy 0.9230\n",
- "Epoch 18: Train Loss: 0.0131877 Accuracy 0.9295\n",
- "Epoch 19: Train Loss: 0.0131658 Accuracy 0.9344\n",
- "Epoch 20: Train Loss: 0.0131338 Accuracy 0.9382\n"
+ "Epoch 1: Train Loss: 0.0176823 Accuracy 0.1494\n",
+ "Epoch 2: Train Loss: 0.0151293 Accuracy 0.2636\n",
+ "Epoch 3: Train Loss: 0.0147051 Accuracy 0.3770\n",
+ "Epoch 4: Train Loss: 0.0143555 Accuracy 0.4779\n",
+ "Epoch 5: Train Loss: 0.0140985 Accuracy 0.5173\n",
+ "Epoch 6: Train Loss: 0.0139185 Accuracy 0.5474\n",
+ "Epoch 7: Train Loss: 0.0137876 Accuracy 0.5905\n",
+ "Epoch 8: Train Loss: 0.0136877 Accuracy 0.6322\n",
+ "Epoch 9: Train Loss: 0.0136219 Accuracy 0.6462\n",
+ "Epoch 10: Train Loss: 0.0135693 Accuracy 0.6522\n",
+ "Epoch 11: Train Loss: 0.0135283 Accuracy 0.6532\n",
+ "Epoch 12: Train Loss: 0.0134948 Accuracy 0.6547\n",
+ "Epoch 13: Train Loss: 0.0134677 Accuracy 0.6563\n",
+ "Epoch 14: Train Loss: 0.0134442 Accuracy 0.6570\n",
+ "Epoch 15: Train Loss: 0.0134219 Accuracy 0.6614\n",
+ "Epoch 16: Train Loss: 0.0134028 Accuracy 0.6660\n",
+ "Epoch 17: Train Loss: 0.0133850 Accuracy 0.6734\n",
+ "Epoch 18: Train Loss: 0.0133693 Accuracy 0.6922\n",
+ "Epoch 19: Train Loss: 0.0133531 Accuracy 0.7233\n",
+ "Epoch 20: Train Loss: 0.0133380 Accuracy 0.7388\n"
]
}
],
@@ -338,9 +461,9 @@
"\n",
"model = LogisticRegression(input_dim, output_dim).to(device)\n",
"loss_fn = torch.nn.CrossEntropyLoss()\n",
- "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n",
+ "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"\n",
- "for epoch in range(20):\n",
+ "for epoch in range(n_epochs):\n",
" train_loss, train_accuracy = train_epoch(model, train_dataloader, loss_fn, optimizer, device)\n",
" print(f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\")"
]
@@ -356,14 +479,14 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
- "test_dataloader = soma_ml.experiment_dataloader(test_dataset)\n",
- "X_batch, y_batch = next(iter(test_dataloader))\n",
+ "test_dataloader = experiment_dataloader(test_dataset)\n",
+ "X_batch, obs_batch = next(iter(test_dataloader))\n",
"X_batch = torch.from_numpy(X_batch)\n",
- "y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))"
+ "true_cell_types = torch.from_numpy(cell_type_encoder.transform(obs_batch['cell_type']))"
]
},
{
@@ -375,24 +498,25 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "tensor([ 8, 1, 1, 1, 1, 1, 1, 8, 8, 5, 1, 7, 8, 1, 1, 1, 1, 7,\n",
- " 7, 8, 1, 1, 5, 5, 1, 8, 1, 1, 1, 7, 8, 7, 7, 7, 8, 7,\n",
- " 5, 1, 1, 8, 1, 5, 8, 5, 1, 11, 1, 7, 1, 1, 5, 5, 1, 11,\n",
- " 1, 6, 8, 5, 1, 8, 11, 8, 1, 8, 1, 8, 1, 5, 1, 1, 1, 8,\n",
- " 8, 7, 5, 1, 1, 8, 1, 7, 2, 1, 7, 1, 5, 1, 1, 7, 1, 8,\n",
- " 1, 1, 1, 7, 7, 1, 1, 1, 7, 1, 1, 7, 7, 5, 7, 8, 5, 1,\n",
- " 5, 1, 5, 5, 5, 1, 1, 1, 8, 5, 1, 1, 7, 8, 1, 1, 1, 1,\n",
- " 8, 1], device='cuda:0')"
+ "tensor([ 1, 1, 1, 1, 7, 1, 11, 1, 6, 7, 1, 1, 1, 8, 1, 1, 1, 11,\n",
+ " 1, 1, 8, 1, 1, 1, 7, 5, 1, 1, 1, 1, 8, 1, 8, 8, 1, 1,\n",
+ " 1, 8, 1, 1, 1, 1, 1, 11, 1, 1, 7, 1, 1, 1, 7, 5, 8, 5,\n",
+ " 1, 1, 1, 1, 1, 9, 1, 1, 1, 1, 8, 5, 1, 1, 9, 7, 1, 1,\n",
+ " 7, 8, 1, 1, 1, 1, 1, 7, 11, 1, 9, 1, 8, 8, 1, 7, 1, 5,\n",
+ " 1, 7, 7, 1, 1, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 11, 1, 1, 1, 1, 1, 11, 1, 8, 1, 1, 8, 1, 1, 7, 1,\n",
+ " 5, 1], device='cuda:0')"
]
},
+ "execution_count": 12,
"metadata": {},
- "output_type": "display_data"
+ "output_type": "execute_result"
}
],
"source": [
@@ -403,8 +527,7 @@
"\n",
"probabilities = torch.nn.functional.softmax(outputs, 1)\n",
"predictions = torch.argmax(probabilities, axis=1)\n",
- "\n",
- "display(predictions)"
+ "predictions"
]
},
{
@@ -413,60 +536,60 @@
"source": [
"The predictions are returned as the encoded values of `cell_type` label. To recover the original cell type labels as strings, we decode using the same `LabelEncoder` used for training.\n",
"\n",
- "At inference time, if the model inputs are not obtained via an `ExperimentAxisQueryIterDataPipe`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below."
+ "At inference time, if the model inputs are not obtained via an `ExperimentDataset`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below."
]
},
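+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of that save/load round-trip (the file path is illustrative):\n",
+ "\n",
+ "```python\n",
+ "import pickle\n",
+ "\n",
+ "# At training time: save the fitted encoder alongside the model weights\n",
+ "with open(\"cell_type_encoder.pkl\", \"wb\") as f:\n",
+ "    pickle.dump(cell_type_encoder, f)\n",
+ "\n",
+ "# At inference time: load it back, then decode predictions as below\n",
+ "with open(\"cell_type_encoder.pkl\", \"rb\") as f:\n",
+ "    cell_type_encoder = pickle.load(f)\n",
+ "```"
+ ]
+ },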
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "array(['leukocyte', 'basal cell', 'basal cell', 'basal cell',\n",
- " 'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n",
- " 'epithelial cell', 'basal cell', 'keratinocyte', 'leukocyte',\n",
- " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n",
- " 'keratinocyte', 'keratinocyte', 'leukocyte', 'basal cell',\n",
- " 'basal cell', 'epithelial cell', 'epithelial cell', 'basal cell',\n",
- " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n",
- " 'keratinocyte', 'leukocyte', 'keratinocyte', 'keratinocyte',\n",
- " 'keratinocyte', 'leukocyte', 'keratinocyte', 'epithelial cell',\n",
+ "array(['basal cell', 'basal cell', 'basal cell', 'basal cell',\n",
+ " 'keratinocyte', 'basal cell', 'vein endothelial cell',\n",
+ " 'basal cell', 'fibroblast', 'keratinocyte', 'basal cell',\n",
" 'basal cell', 'basal cell', 'leukocyte', 'basal cell',\n",
- " 'epithelial cell', 'leukocyte', 'epithelial cell', 'basal cell',\n",
- " 'vein endothelial cell', 'basal cell', 'keratinocyte',\n",
- " 'basal cell', 'basal cell', 'epithelial cell', 'epithelial cell',\n",
- " 'basal cell', 'vein endothelial cell', 'basal cell', 'fibroblast',\n",
- " 'leukocyte', 'epithelial cell', 'basal cell', 'leukocyte',\n",
- " 'vein endothelial cell', 'leukocyte', 'basal cell', 'leukocyte',\n",
- " 'basal cell', 'leukocyte', 'basal cell', 'epithelial cell',\n",
- " 'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n",
- " 'keratinocyte', 'epithelial cell', 'basal cell', 'basal cell',\n",
- " 'leukocyte', 'basal cell', 'keratinocyte',\n",
- " 'capillary endothelial cell', 'basal cell', 'keratinocyte',\n",
- " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n",
- " 'keratinocyte', 'basal cell', 'leukocyte', 'basal cell',\n",
- " 'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n",
- " 'basal cell', 'basal cell', 'basal cell', 'keratinocyte',\n",
- " 'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n",
- " 'epithelial cell', 'keratinocyte', 'leukocyte', 'epithelial cell',\n",
- " 'basal cell', 'epithelial cell', 'basal cell', 'epithelial cell',\n",
- " 'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n",
- " 'basal cell', 'leukocyte', 'epithelial cell', 'basal cell',\n",
- " 'basal cell', 'keratinocyte', 'leukocyte', 'basal cell',\n",
+ " 'basal cell', 'basal cell', 'vein endothelial cell', 'basal cell',\n",
+ " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n",
+ " 'basal cell', 'keratinocyte', 'epithelial cell', 'basal cell',\n",
" 'basal cell', 'basal cell', 'basal cell', 'leukocyte',\n",
- " 'basal cell'], dtype=object)"
+ " 'basal cell', 'leukocyte', 'leukocyte', 'basal cell', 'basal cell',\n",
+ " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n",
+ " 'basal cell', 'basal cell', 'basal cell', 'vein endothelial cell',\n",
+ " 'basal cell', 'basal cell', 'keratinocyte', 'basal cell',\n",
+ " 'basal cell', 'basal cell', 'keratinocyte', 'epithelial cell',\n",
+ " 'leukocyte', 'epithelial cell', 'basal cell', 'basal cell',\n",
+ " 'basal cell', 'basal cell', 'basal cell', 'pericyte', 'basal cell',\n",
+ " 'basal cell', 'basal cell', 'basal cell', 'leukocyte',\n",
+ " 'epithelial cell', 'basal cell', 'basal cell', 'pericyte',\n",
+ " 'keratinocyte', 'basal cell', 'basal cell', 'keratinocyte',\n",
+ " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n",
+ " 'basal cell', 'basal cell', 'keratinocyte',\n",
+ " 'vein endothelial cell', 'basal cell', 'pericyte', 'basal cell',\n",
+ " 'leukocyte', 'leukocyte', 'basal cell', 'keratinocyte',\n",
+ " 'basal cell', 'epithelial cell', 'basal cell', 'keratinocyte',\n",
+ " 'keratinocyte', 'basal cell', 'basal cell', 'keratinocyte',\n",
+ " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n",
+ " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n",
+ " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n",
+ " 'basal cell', 'basal cell', 'vein endothelial cell', 'basal cell',\n",
+ " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n",
+ " 'vein endothelial cell', 'basal cell', 'leukocyte', 'basal cell',\n",
+ " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n",
+ " 'keratinocyte', 'basal cell', 'epithelial cell', 'basal cell'],\n",
+ " dtype=object)"
]
},
+ "execution_count": 13,
"metadata": {},
- "output_type": "display_data"
+ "output_type": "execute_result"
}
],
"source": [
"predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())\n",
- "\n",
- "display(predicted_cell_types)"
+ "predicted_cell_types"
]
},
{
@@ -478,7 +601,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -502,15 +625,15 @@
" \n",
" \n",
" \n",
" \n",
" \n",
- " actual cell type \n",
+ " true cell type \n",
" predicted cell type \n",
"
predicted cell type | \n", + "basal cell | \n", + "epithelial cell | \n", + "fibroblast | \n", + "keratinocyte | \n", + "leukocyte | \n", + "pericyte | \n", + "vein endothelial cell | \n", + "
---|---|---|---|---|---|---|---|
true cell type | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
basal cell | \n", + "59 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
capillary endothelial cell | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | 1 | \n", + "
epithelial cell | \n", + "11 | \n", + "6 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
fibroblast | \n", + "\n", + " | \n", + " | 1 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " |
keratinocyte | \n", + "15 | \n", + "\n", + " | \n", + " | 13 | \n", + "\n", + " | \n", + " | \n", + " |
leukocyte | \n", + "1 | \n", + "\n", + " | \n", + " | \n", + " | 13 | \n", + "\n", + " | \n", + " |
pericyte | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " | 3 | \n", + "\n", + " |
vein endothelial cell | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | 5 | \n", + "