From f8879b3ec171b17d16bed8a72b6fe80f4690a0cc Mon Sep 17 00:00:00 2001 From: Wen-Tse Chen Date: Wed, 20 Dec 2023 00:22:57 -0500 Subject: [PATCH] fix test w/o gpu bug --- openrl/envs/nlp/rewards/intent.py | 9 ++++++--- openrl/envs/nlp/rewards/kl_penalty.py | 9 +++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/openrl/envs/nlp/rewards/intent.py b/openrl/envs/nlp/rewards/intent.py index 0a0c4d3..2c82e96 100644 --- a/openrl/envs/nlp/rewards/intent.py +++ b/openrl/envs/nlp/rewards/intent.py @@ -41,6 +41,10 @@ def __init__( self.use_model_parallel = False if intent_model == "builtin_intent": + + self._device = "cpu" + self.use_data_parallel = False + from transformers import GPT2Config, GPT2LMHeadModel class TestTokenizer: @@ -66,6 +70,7 @@ def __init__(self, input_ids, attention_mask): self._model = GPT2LMHeadModel(config) else: + self._device = "cuda" model_path = data_abs_path(intent_model) self._tokenizer = AutoTokenizer.from_pretrained(intent_model) self._model = AutoModelForSequenceClassification.from_pretrained(model_path) @@ -81,12 +86,10 @@ def __init__(self, input_ids, attention_mask): with open(ds_config) as file: ds_config = json.load(file) - self._device = "cuda" - self._model = self._model.to("cuda") + self._model = self._model.to(self._device) self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config) self.use_fp16 = ds_config["fp16"]["enabled"] else: - self._device = "cuda" if self.use_model_parallel: self._model.parallelize() elif self.use_data_parallel: diff --git a/openrl/envs/nlp/rewards/kl_penalty.py b/openrl/envs/nlp/rewards/kl_penalty.py index 9516b78..3cfafd4 100644 --- a/openrl/envs/nlp/rewards/kl_penalty.py +++ b/openrl/envs/nlp/rewards/kl_penalty.py @@ -47,6 +47,10 @@ def __init__( # reference model if ref_model == "builtin_ref": + + self.device = "cpu" + self.use_data_parallel = False + from transformers import GPT2Config, GPT2LMHeadModel config = GPT2Config() @@ -77,8 +81,9 @@ def __init__( elif self.use_data_parallel: # else defaults to data parallel if self.use_half: self._ref_net = self._ref_net.half() - self._ref_net = torch.nn.DataParallel(self._ref_net) - self._ref_net = self._ref_net.to(self.device) + else: + self._ref_net = torch.nn.DataParallel(self._ref_net) + self._ref_net = self._ref_net.to(self.device) # alpha adjustment self._alpha = 0.2