diff --git a/tdc/model_server/model_loaders/scvi.py b/tdc/model_server/model_loaders/scvi.py new file mode 100644 index 00000000..e43563fb --- /dev/null +++ b/tdc/model_server/model_loaders/scvi.py @@ -0,0 +1,10 @@ + +class scVILoader: + + def __init__(self): + pass + + def load(self): + """load scVI model + calling download_scvi() and any other helper functions + """ \ No newline at end of file diff --git a/tdc/model_server/models/scvi.py b/tdc/model_server/models/scvi.py new file mode 100644 index 00000000..4bf2b0e9 --- /dev/null +++ b/tdc/model_server/models/scvi.py @@ -0,0 +1,20 @@ + +class scVI: + """class to load and perform inference w/ scvi + + adding any additional utils from scvi that facilitate inference / data processing + """ + + def __init__(self): + import scvi + pass + + def forward(self, **kwargs): + """ + loads self.model if needed + calls inference on these arguments + """ + + def load(self): + """import the model loader + -> then, LOAD the MODEL CLASS and return it and also save it into self.model""" \ No newline at end of file diff --git a/tdc/model_server/tdc_hf.py b/tdc/model_server/tdc_hf.py index 6571fbcd..16a18a83 100644 --- a/tdc/model_server/tdc_hf.py +++ b/tdc/model_server/tdc_hf.py @@ -14,7 +14,7 @@ 'CYP3A4_Veith-AttentiveFP', ] -model_hub = ["Geneformer", "scGPT"] +model_hub = ["Geneformer", "scGPT", "scVI"] class tdc_hf_interface: @@ -66,6 +66,9 @@ def load(self): AutoModel.register(ScGPTConfig, ScGPTModel) model = AutoModel.from_pretrained("tdc/scGPT") return model + elif self.model_name == "scVI": + # import scVI model and return the model class + pass raise Exception("Not implemented yet!") def load_deeppurpose(self, save_path):