Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rxrx1 research running scripts #309

Merged
merged 46 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
668e286
Add rxrx1 dataset and fedavg
sanaAyrml Jan 6, 2025
56193f8
Delete extra file
sanaAyrml Jan 6, 2025
c806b91
update running scripts
sanaAyrml Jan 6, 2025
33f8bd9
Update running scripts
sanaAyrml Jan 6, 2025
fccced0
Update client
sanaAyrml Jan 6, 2025
b43f982
update rxrx1 dataset
sanaAyrml Jan 6, 2025
6de30e1
Delete dataset extra print
sanaAyrml Jan 7, 2025
a7c3ff8
Add CUBLAS_WORKSPACE_CONFIG for experiment
sanaAyrml Jan 7, 2025
75eb831
Update number of clients
sanaAyrml Jan 7, 2025
3d03729
Add evaluation files
sanaAyrml Jan 7, 2025
d306951
Add rxrx1 ditto experiments
sanaAyrml Jan 8, 2025
f80ca11
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2025
a2da4b1
Merge branch 'main' into sa_rxrx1_research
sanaAyrml Jan 9, 2025
3b56645
Update typing with new changes in main
sanaAyrml Jan 9, 2025
e0c6d49
Add caching option to dataloader
sanaAyrml Jan 9, 2025
2145be8
Update evaluation script
sanaAyrml Jan 9, 2025
624a529
Update evaluation script
sanaAyrml Jan 9, 2025
ba746ea
Ignoring a vulnerability without a fix yet
emersodb Jan 9, 2025
9cab150
address david's comments
sanaAyrml Jan 10, 2025
a2fd6ae
Add loading and unloading for dataset cache
sanaAyrml Jan 13, 2025
377372d
Increase memmory for experiments
sanaAyrml Jan 13, 2025
b322165
Add centeral training
sanaAyrml Jan 13, 2025
b0810aa
Update run script for centeral
sanaAyrml Jan 13, 2025
cce4ec2
Add time for centeral script
sanaAyrml Jan 13, 2025
ea0ee35
add preprocess file to rxrx1 dataset
sanaAyrml Jan 13, 2025
ba5d824
Add dataset download and preprocessing scripts
sanaAyrml Jan 13, 2025
9e8e570
Update dataset to Tensor dataset
sanaAyrml Jan 14, 2025
7e7434a
Update docstrings
sanaAyrml Jan 14, 2025
4347b20
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
edd5c29
Merge branch 'main' into sa_rxrx1_research
sanaAyrml Jan 14, 2025
acaf941
Update load_data
sanaAyrml Jan 15, 2025
1c79841
make tensor save one by one
sanaAyrml Jan 15, 2025
31edd2a
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2025
3f2ee19
Update preprocessing
sanaAyrml Jan 15, 2025
9d76d17
Merge branch 'sa_rxrx1_research' of https://github.com/VectorInstitut…
sanaAyrml Jan 15, 2025
f611765
Update image loading error
sanaAyrml Jan 15, 2025
186a930
Change targets type
sanaAyrml Jan 15, 2025
5f661aa
A few small fixes to the download script and additona to gitignore
emersodb Jan 16, 2025
ba8434f
reverting the directory path to previous one on the cluster
emersodb Jan 16, 2025
05afbed
Small vulnerability ignore
emersodb Jan 16, 2025
46daf5c
Merge pull request #315 from VectorInstitute/dbe/rxrx1_debugging
sanaAyrml Jan 16, 2025
2fd74ed
Address new set of David's comments
sanaAyrml Jan 16, 2025
a2d8f40
Merge branch 'main' into sa_rxrx1_research
sanaAyrml Jan 21, 2025
de54458
locking poetry files
sanaAyrml Jan 21, 2025
8a12718
Merge branch 'main' into sa_rxrx1_research
sanaAyrml Jan 22, 2025
afdcb24
Update Poetry
sanaAyrml Jan 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ settings.json
**/datasets/nnunet_raw/**
**/datasets/nnunet_preprocessed/**
**/datasets/cifar_partitioned_data/**
**/datasets/rxrx1/rxrx1_v1.0/**

# logs

Expand Down
28 changes: 28 additions & 0 deletions fl4health/datasets/rxrx1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Fluorescent Microscopy Images Dataset Download and Preprocessing

This repository provides a set of scripts to download and preprocess RxRx1 datasets for use in federated learning experiments. This dataset include 6-channel fluorescent microscopy images of cells treated with different compounds. The dataset is provided by Recursion Pharmaceuticals and is available on the [RxRx1 Kaggle page](https://www.rxrx.ai/rxrx1).

## Getting Started

To start using these datasets, follow the steps below.


### Downloading the Datasets
To use the datasets for this project, run the provided shell script to download and unzip the required files.

```sh
sh fl4health/datasets/rxrx1/download.sh
```


### Preprocessing the Datasets

Once the datasets are downloaded, preprocess them to generate the required metadata file and prepare the training and testing tensors for each client participating in the federated learning experiments. The following command preprocesses the RxRx1 datasets:

```sh
python fl4health/datasets/rxrx1/preprocess.py --data_dir <path_to_rxrx1_data>
```

### Using the Datasets

After preprocessing, the datasets are ready to be used in the federated learning settings. For examples please refer to the [Rxrx1 experiments](research/rxrx1) directory.
45 changes: 45 additions & 0 deletions fl4health/datasets/rxrx1/download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
echo "RxRx1 dataset download."
# Define the URL and the target directory and file name
URL="https://storage.googleapis.com/rxrx/rxrx1"
METADATA_URL="https://storage.googleapis.com/rxrx/rxrx1/rxrx1-metadata.zip"
DIRECTORY="/projects/fl4health/datasets/rxrx1_v1.0/"
IMAGE_FILE_NAME="rxrx1-images.zip"
METADATA_FILE="rxrx1-metadata.zip"
IMAGE_FILE_PATH=${DIRECTORY}${IMAGE_FILE_NAME}
METADATA_FILE_PATH=${DIRECTORY}${METADATA_FILE}

# Create the directory if it doesn't exist
mkdir -p "$DIRECTORY"

# Check if the file already exists
if [ -f "$IMAGE_FILE_PATH" ]; then
echo "File $IMAGE_FILE already exists. No download needed."
else
echo "Downloading $IMAGE_FILE_NAME"
wget -O "$IMAGE_FILE_PATH" "$URL/$IMAGE_FILE_NAME"
if [ $? -eq 0 ]; then
echo "Download completed successfully."
else
echo "Download failed."
fi
fi

mkdir -p ${DIRECTORY}images/
unzip ${IMAGE_FILE_PATH} -d ${DIRECTORY}images/

# Check if the file already exists
if [ -f "$METADATA_FILE_PATH" ]; then
echo "File $METADATA_FILE already exists. No download needed."
else
echo "Downloading $METADATA_FILE"
wget -O "$METADATA_FILE_PATH" "$URL/$METADATA_FILE"
if [ $? -eq 0 ]; then
echo "Download completed successfully."
else
echo "Download failed."
fi
fi

unzip ${METADATA_FILE_PATH} -d ${DIRECTORY}

echo "Download completed."
170 changes: 170 additions & 0 deletions fl4health/datasets/rxrx1/load_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import copy
import os
import pickle
from collections import defaultdict
from collections.abc import Callable
from logging import INFO
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from flwr.common.logger import log
from torch.utils.data import DataLoader, Subset

from fl4health.utils.dataset import TensorDataset


def construct_rxrx1_tensor_dataset(
metadata: pd.DataFrame,
data_path: Path,
client_num: int,
dataset_type: str,
transform: Callable | None = None,
) -> tuple[TensorDataset, dict[int, int]]:
"""
Construct a TensorDataset for rxrx1 data.

Args:
metadata (DataFrame): A DataFrame containing image metadata.
data_path (Path): Root directory which the image data should be loaded.
client_num (int): Client number to load data for.
dataset_type (str): 'train' or 'test' to specify dataset type.
transform (Callable | None): Transformation function to apply to the images. Defaults to None.

Returns:
tuple[TensorDataset, dict[int, int]]: A TensorDataset containing the processed images and label map.

"""

label_map = {label: idx for idx, label in enumerate(sorted(metadata["sirna_id"].unique()))}
original_label_map = {new_label: original_label for original_label, new_label in label_map.items()}
metadata = metadata[metadata["dataset"] == dataset_type]
targets_tensor = torch.Tensor(list(metadata["sirna_id"].map(label_map))).type(torch.long)
data_list = []
for index in range(len(targets_tensor)):
with open(
os.path.join(data_path, f"clients/{dataset_type}_data_{client_num+1}/image_{index}.pkl"), "rb"
) as file:
data_list.append(torch.Tensor(pickle.load(file)).unsqueeze(0))
data_tensor = torch.cat(data_list)
return TensorDataset(data_tensor, targets_tensor, transform), original_label_map


def label_frequency(dataset: TensorDataset | Subset, original_label_map: dict[int, int]) -> None:
"""
Prints the frequency of each label in the dataset.

Args:
dataset (TensorDataset | Subset): The dataset to analyze.
original_label_map (dict[int, int]): A mapping of the original labels to their new labels.

"""
# Extract metadata and label map
if isinstance(dataset, TensorDataset):
targets = dataset.targets
elif isinstance(dataset, Subset):
assert isinstance(dataset.dataset, TensorDataset), "Subset dataset must be an TensorDataset instance."
targets = dataset.dataset.targets
else:
raise TypeError("Dataset must be of type TensorDataset or Subset containing an TensorDataset.")

# Count label frequencies
label_to_indices = defaultdict(list)
assert isinstance(targets, torch.Tensor)
for idx, label in enumerate(targets): # Assumes dataset[idx] returns (data, label)
label_to_indices[label].append(idx)

# Print frequency of labels their names
for label, count in label_to_indices.items():
assert isinstance(label, int)
original_label = original_label_map.get(label)
log(INFO, f"Label {label} (original: {original_label}): {len(count)} samples")


def create_splits(
dataset: TensorDataset, seed: int | None = None, train_fraction: float = 0.8
) -> tuple[list[int], list[int]]:
"""
Splits the dataset into training and validation sets.

Args:
dataset (Dataset): The dataset to split.
train_fraction (float): Fraction of data to use for training.

Returns:
Tuple: (train_dataset, val_dataset)
"""

# Group indices by label
label_to_indices = defaultdict(list)
assert isinstance(dataset.targets, torch.Tensor)
for idx, label in enumerate(dataset.targets): # Assumes dataset[idx] returns (data, label)
label_to_indices[label.item()].append(idx)

# Stratified splitting
train_indices, val_indices = [], []
for label, indices in label_to_indices.items():
if seed is not None:
np_generator = np.random.default_rng(seed)
np_generator.shuffle(indices)
else:
np.random.shuffle(indices)
split_point = int(len(indices) * train_fraction)
train_indices.extend(indices[:split_point])
val_indices.extend(indices[split_point:])
if len(val_indices) == 0:
log(INFO, "Warning: Validation set is empty. Consider changing the train_fraction parameter.")

return train_indices, val_indices


def load_rxrx1_data(
data_path: Path,
client_num: int,
batch_size: int,
seed: int | None = None,
train_val_split: float = 0.8,
num_workers: int = 0,
) -> tuple[DataLoader, DataLoader, dict[str, int]]:

# Read the CSV file
data = pd.read_csv(f"{data_path}/clients/meta_data_{client_num+1}.csv")

dataset, _ = construct_rxrx1_tensor_dataset(data, data_path, client_num, "train")

train_indices, val_indices = create_splits(dataset, seed=seed, train_fraction=train_val_split)
train_set = copy.deepcopy(dataset)
sanaAyrml marked this conversation as resolved.
Show resolved Hide resolved
train_set.data = train_set.data[train_indices]
assert train_set.targets is not None
train_set.targets = train_set.targets[train_indices]

validation_set = copy.deepcopy(dataset)
sanaAyrml marked this conversation as resolved.
Show resolved Hide resolved
validation_set.data = validation_set.data[val_indices]
assert validation_set.targets is not None
validation_set.targets = validation_set.targets[val_indices]

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
validation_loader = DataLoader(validation_set, batch_size=batch_size)
num_examples = {
"train_set": len(train_set.data),
"validation_set": len(validation_set.data),
}

return train_loader, validation_loader, num_examples


def load_rxrx1_test_data(
data_path: Path, client_num: int, batch_size: int, num_workers: int = 0
) -> tuple[DataLoader, dict[str, int]]:

# Read the CSV file
data = pd.read_csv(f"{data_path}/clients/meta_data_{client_num+1}.csv")

dataset, _ = construct_rxrx1_tensor_dataset(data, data_path, client_num, "test")

evaluation_loader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
)
num_examples = {"eval_set": len(dataset.data)}
return evaluation_loader, num_examples
121 changes: 121 additions & 0 deletions fl4health/datasets/rxrx1/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import argparse
import os
import pickle
from pathlib import Path
from typing import Any

import pandas as pd
import torch
from PIL import Image
from torchvision.transforms import ToTensor


def filter_and_save_data(metadata: pd.DataFrame, top_sirna_ids: list[int], cell_type: str, output_path: Path) -> None:
"""
Filters data for the given cell type and frequency of their sirna_id and saves it to a CSV file.

Args:
metadata (pd.DataFrame): Metadata containing information about all images.
top_sirna_ids (list[int]): Top sirna_id values to filter by.
cell_type (str): Cell type to filter by.
output_path (Path): Path to save the filtered metadata.
"""
filtered_metadata = metadata[(metadata["sirna_id"].isin(top_sirna_ids)) & (metadata["cell_type"] == cell_type)]
filtered_metadata.to_csv(output_path, index=False)


def load_image(row: dict[str, Any], root: Path) -> torch.Tensor:
"""
Load an image tensor for a given row of metadata.

Args:
row (dict[str, Any]): A row of metadata containing experiment, plate, well, and site information.
root (Path): Root directory containing the image files.

Returns:
torch.Tensor: The loaded image tensor.
"""
experiment = row["experiment"]
plate = row["plate"]
well = row["well"]
site = row["site"]

images = []
# Rxrx1 originally consists of 6 channels, but to reduce the computational cost, we only use 3 channels
# following previous works such as https://github.com/p-lambda/wildYe.
for channel in range(1, 4):
sanaAyrml marked this conversation as resolved.
Show resolved Hide resolved
image_path = os.path.join(root, f"images/{experiment}/Plate{plate}/{well}_s{site}_w{channel}.png")
if not Path(image_path).exists():
raise FileNotFoundError(f"Image not found at {image_path}")
image = ToTensor()(Image.open(image_path).convert("L"))
images.append(image)

# Concatenate the three channels into one tensor
return torch.cat(images, dim=0)


def process_data(metadata: pd.DataFrame, input_dir: Path, output_dir: Path, client_num: int, type_data: str) -> None:
"""
Process the entire dataset, loading image tensors for each row.

Args:
metadata (pd.DataFrame): Metadata containing information about all images.
input_dir (Path): Input directory containing the image files.
output_dir (Path): Output directory containing the image files.
client_num (int): Client number to load data for.
type_data (str): 'train' or 'test' to specify dataset type.
"""
for i, row in metadata.iterrows():
image_tensor = load_image(row.to_dict(), Path(input_dir))
save_to_pkl(image_tensor, os.path.join(output_dir, f"{type_data}_data_{client_num+1}", f"image_{i}.pkl"))


def save_to_pkl(data: torch.Tensor, output_path: str) -> None:
"""
Save data to a pickle file.

Args:
data (torch.Tensor): Data to save.
output_path (str): Path to the output pickle file.
"""
with open(output_path, "wb") as f:
pickle.dump(data, f)


def main(dataset_dir: Path) -> None:
metadata_file = os.path.join(dataset_dir, "metadata.csv")
output_dir = os.path.join(dataset_dir, "clients")

os.makedirs(output_dir, exist_ok=True)

data = pd.read_csv(metadata_file)

# Get the top 50 `sirna_id`s by frequency
top_sirna_ids = data["sirna_id"].value_counts().head(50).index.tolist()

# Define cell types to distribute data based on them for each client
cell_types = ["RPE", "HUVEC", "HEPG2", "U2OS"]
output_files = [os.path.join(output_dir, f"meta_data_{i+1}.csv") for i in range(len(cell_types))]

# Filter and save data for each client
for cell_type, output_path in zip(cell_types, output_files):
filter_and_save_data(data, top_sirna_ids, cell_type, Path(output_path))

for i, metadata_path in enumerate(output_files):
metadata = pd.read_csv(metadata_path)

# Split the metadata into train and test datasets
train_metadata = metadata[metadata["dataset"] == "train"]
test_metadata = metadata[metadata["dataset"] == "test"]

process_data(train_metadata, dataset_dir, Path(output_dir), i, "train")
process_data(test_metadata, dataset_dir, Path(output_dir), i, "test")


if __name__ == "__main__":
# Argument parsing
parser = argparse.ArgumentParser(description="Filter dataset by the most frequent sirna_id and cell_type.")
parser.add_argument("dataset_dir", type=str, help="Path to the dataset directory containing metadata.csv")

args = parser.parse_args()
main(args.dataset_dir)
Loading
Loading