Skip to content

Commit

Permalink
add test defaults and attempt to use gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
sainirmayi committed Nov 15, 2024
1 parent 7cee7a3 commit 3f171b3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/methods/cell2location/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ arguments:
- name: "--max_epochs_st"
type: integer
default: 30000
info:
test_default: 1000
description: Maximum number of epochs to train cell2location model for spatial mapping.

resources:
Expand All @@ -71,18 +73,16 @@ resources:

engines:
- type: docker
image: openproblems/base_python:1.0.0
image: openproblems/base_pytorch_nvidia:1.0.0
setup:
- type: python
packages:
- scvi-tools==1.0.4
- scvi-tools
- cell2location
- jax==0.4.23
- jaxlib==0.4.23
- scipy<1.13 # The scipy.linalg functions tri, triu & tril are deprecated and will be removed in SciPy 1.13.

- theano
- pymc
runners:
- type: executable
- type: nextflow
directives:
label: [hightime, midmem, midcpu]
label: [hightime, midmem, midcpu, gpu]
5 changes: 5 additions & 0 deletions src/methods/cell2location/script.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import anndata as ad
import numpy as np
import os

os.environ["THEANO_FLAGS"] = 'device=cuda,floatX=float32,force_device=True'

from cell2location.cluster_averages.cluster_averages import compute_cluster_averages
from cell2location.models import Cell2location
from cell2location.models import RegressionModel
Expand Down Expand Up @@ -45,6 +49,7 @@
batch_key="batch_key",
# cell type, covariate used for constructing signatures
labels_key="cell_type",
# use_gpu=True
)
sc_model = RegressionModel(input_single_cell)
sc_model.train(max_epochs=par["max_epochs_sc"], batch_size=par["sc_batch_size"])
Expand Down

0 comments on commit 3f171b3

Please sign in to comment.