
Commit 717c01e: update
rghosh8 committed Feb 13, 2025 (1 parent: abeb00f)
Showing 1 changed file: multi_turn_reward_for_RLHF/main.py (105 additions, 80 deletions).

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import GymWrapper
from tensordict.nn import TensorDictModule
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

# Define the Dialogue Environment in TorchRL Format
class DialogueEnvTorchRL(EnvBase):
"""TorchRL-compatible multi-turn dialogue environment that simulates conversations."""

def __init__(self):
super().__init__()
self.turns = 5 # Each dialogue lasts 5 turns
self.current_turn = 0
self.conversation = []
self.action_spec = torch.arange(10) # 10 discrete actions (dialogue responses)
self.observation_spec = torch.zeros(100) # Fixed-size state representation

def reset(self):
def _reset(self):
"""Resets the environment for a new dialogue."""
self.current_turn = 0
self.conversation = []
return "Hi, how can I help you today?" # Starting dialogue
return {"observation": self._encode_state("Hi, how can I help you today?")}

    def _step(self, action):
        """Takes an action (a response) and advances the conversation."""
        action_text = f"Action {action.item()}"
        self.conversation.append(action_text)
        self.current_turn += 1

        if self.current_turn < self.turns:
            # Generate the next response from the environment (placeholder)
            next_state = f"Response {self.current_turn}: How about this?"
            done = False
        else:
            next_state = "Conversation ended."
            done = True

        reward = self._human_feedback(action_text)
        return {"observation": self._encode_state(next_state), "reward": reward, "done": done}

    def _human_feedback(self, action):
        """Simulates human feedback by returning a random reward."""
        return np.random.choice([1, -1])  # 1 for positive feedback, -1 for negative

    def _encode_state(self, state, size=100):
        """Encodes state into a tensor format (pads or truncates)."""
        state_tensor = torch.tensor([ord(c) for c in state], dtype=torch.float32)
        if state_tensor.size(0) < size:
            padded_tensor = torch.cat([state_tensor, torch.zeros(size - state_tensor.size(0))])
        else:
            padded_tensor = state_tensor[:size]
        return padded_tensor.unsqueeze(0)  # Add batch dimension
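
# Note: _encode_state maps a string to a (1, 100) float tensor of character codes,
# zero-padded or truncated to length 100; this fixed-size encoding is the "observation"
# consumed by the policy and value networks below.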

# Define Policy Network
class PolicyNetwork(nn.Module):
"""Policy network that defines the agent's behavior."""

Expand All @@ -53,80 +71,87 @@ def forward(self, x):
x = torch.relu(self.fc1(x))
return self.fc2(x)

# Define Value Network for PPO
class ValueNetwork(nn.Module):
"""Value network for estimating the state value."""

    def __init__(self, input_size=100, hidden_size=128):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Forward pass through the network."""
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
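
# Note: ValueNetwork maps a (batch, 100) observation encoding to a (batch, 1) state-value
# estimate; it is wrapped in a ValueOperator and used by GAE and ClipPPOLoss below.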

# Training Setup using PPO in TorchRL
def train_rlhf_torchrl(num_episodes=1000, batch_size=32):
    """Trains the policy network with PPO, using simulated human feedback as the reward signal."""

    # Instantiate environment
    env = DialogueEnvTorchRL()

    # Create policy and value networks
    policy_model = PolicyNetwork(input_size=100, hidden_size=128, output_size=10)
    value_model = ValueNetwork(input_size=100, hidden_size=128)

    # Wrap the policy so it reads "observation" and writes "logits", then build a
    # categorical actor on top of those logits
    policy = ProbabilisticActor(
        module=TensorDictModule(policy_model, in_keys=["observation"], out_keys=["logits"]),
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=torch.distributions.Categorical,
        return_log_prob=True,
    )

    # Value function
    value_operator = ValueOperator(
        module=value_model,
        in_keys=["observation"]
    )

    # Optimizers
    policy_optimizer = optim.Adam(policy.parameters(), lr=1e-3)
    value_optimizer = optim.Adam(value_operator.parameters(), lr=1e-3)

    # Setup collector
    collector = SyncDataCollector(
        env, policy, frames_per_batch=batch_size, total_frames=num_episodes * batch_size
    )

    # Replay buffer
    buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(max_size=10000)
    )

    # Advantage estimator and PPO loss
    advantage_module = GAE(value_network=value_operator, gamma=0.99, lmbda=0.95)
    loss_module = ClipPPOLoss(
        actor_network=policy,
        critic_network=value_operator,
        clip_epsilon=0.2
    )

    for episode, batch in enumerate(collector):
        # Compute GAE advantages on the collected batch, then store it in the replay buffer
        advantage_module(batch)
        buffer.extend(batch)

        # Sample a mini-batch from the replay buffer
        sampled_batch = buffer.sample(batch_size)

        # Compute loss and update policy
        loss = loss_module(sampled_batch)
        policy_optimizer.zero_grad()
        loss["loss_objective"].backward()
        policy_optimizer.step()

        # Update value function
        value_optimizer.zero_grad()
        loss["loss_critic"].backward()
        value_optimizer.step()

print(f"Episode {episode + 1}/{num_episodes}, Loss: {loss['loss_objective'].item()}")

if __name__ == "__main__":
    train_rlhf_torchrl(num_episodes=1000, batch_size=32)
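
For readers who want to see what the two TorchRL objective modules compute, the sketch below spells out the GAE advantage estimate and the clipped PPO surrogate in plain PyTorch. It is a minimal, self-contained illustration with made-up inputs (rewards, values, dones, log-probabilities); none of these names come from the repository, and the script above delegates both computations to GAE and ClipPPOLoss.

import torch

def gae_advantages(rewards, values, dones, gamma=0.99, lmbda=0.95):
    """Generalized Advantage Estimation for one T-step trajectory (no bootstrap after the last step)."""
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else torch.tensor(0.0)
        not_done = 1.0 - dones[t].float()
        # TD residual: r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Recursive accumulation: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        gae = delta + gamma * lmbda * not_done * gae
        advantages[t] = gae
    return advantages

def clipped_ppo_loss(new_log_probs, old_log_probs, advantages, clip_epsilon=0.2):
    """Clipped surrogate policy loss (analogous to the 'loss_objective' term of ClipPPOLoss)."""
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
    return -torch.min(unclipped, clipped).mean()

# Toy usage with random data for a 5-step dialogue episode
rewards = torch.tensor([1.0, -1.0, 1.0, 1.0, -1.0])
values = torch.randn(5)
dones = torch.tensor([False, False, False, False, True])
adv = gae_advantages(rewards, values, dones)
loss = clipped_ppo_loss(torch.randn(5), torch.randn(5), adv)
print(adv.shape, loss.item())

In train_rlhf_torchrl these quantities roughly correspond to what advantage_module(batch) writes into the collected batch (an "advantage" entry) and to the policy term returned by loss_module(sampled_batch).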
