diff --git a/tdc/model_server/tdc_hf.py b/tdc/model_server/tdc_hf.py index 223012d1..57e25473 100644 --- a/tdc/model_server/tdc_hf.py +++ b/tdc/model_server/tdc_hf.py @@ -57,8 +57,7 @@ def load(self): raise Exception("this model is not in the TDC model hub GH repo.") elif self.model_name == "Geneformer": from transformers import AutoModelForMaskedLM - model = AutoModelForMaskedLM.from_pretrained( - "ctheodoris/Geneformer") + model = AutoModelForMaskedLM.from_pretrained("tdc/Geneformer") return model elif self.model_name == "scGPT": from transformers import AutoModel