From 3f171b3d2f3f5c60458eb72111a69e7585c25978 Mon Sep 17 00:00:00 2001 From: Nirmayi Date: Fri, 15 Nov 2024 09:44:47 +0100 Subject: [PATCH] add test defaults and attempt to use gpu --- src/methods/cell2location/config.vsh.yaml | 14 +++++++------- src/methods/cell2location/script.py | 5 +++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/methods/cell2location/config.vsh.yaml b/src/methods/cell2location/config.vsh.yaml index 68d6b9e..c888876 100644 --- a/src/methods/cell2location/config.vsh.yaml +++ b/src/methods/cell2location/config.vsh.yaml @@ -63,6 +63,8 @@ arguments: - name: "--max_epochs_st" type: integer default: 30000 + info: + test_default: 1000 description: Maximum number of epochs to train cell2location model for spatial mapping. resources: @@ -71,18 +73,16 @@ resources: engines: - type: docker - image: openproblems/base_python:1.0.0 + image: openproblems/base_pytorch_nvidia:1.0.0 setup: - type: python packages: - - scvi-tools==1.0.4 + - scvi-tools - cell2location - - jax==0.4.23 - - jaxlib==0.4.23 - - scipy<1.13 # The scipy.linalg functions tri, triu & tril are deprecated and will be removed in SciPy 1.13. - + - theano + - pymc runners: - type: executable - type: nextflow directives: - label: [hightime, midmem, midcpu] + label: [hightime, midmem, midcpu, gpu] diff --git a/src/methods/cell2location/script.py b/src/methods/cell2location/script.py index 7245bbd..86c8097 100644 --- a/src/methods/cell2location/script.py +++ b/src/methods/cell2location/script.py @@ -1,5 +1,9 @@ import anndata as ad import numpy as np +import os + +os.environ["THEANO_FLAGS"] = 'device=cuda,floatX=float32,force_device=True' + from cell2location.cluster_averages.cluster_averages import compute_cluster_averages from cell2location.models import Cell2location from cell2location.models import RegressionModel @@ -45,6 +49,7 @@ batch_key="batch_key", # cell type, covariate used for constructing signatures labels_key="cell_type", + # use_gpu=True ) sc_model = RegressionModel(input_single_cell) sc_model.train(max_epochs=par["max_epochs_sc"], batch_size=par["sc_batch_size"])