SACPolicy: Unsupported action space: Box(-1.0, 1.0, (1,), float32) #1232

Open

Vladimir19052002 opened this issue Dec 25, 2024 · 0 comments
Description

Initializing SACPolicy with a single-dimensional Box action space raises a ValueError ("Unsupported action space: Box(-1.0, 1.0, (1,), float32)").

Reproduction

Minimal example script:

import torch
import torch.nn as nn
from gym.spaces import Box
import numpy as np
from tianshou.policy import SACPolicy
import torch.optim as optim
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('TestLogger')

# Define Actor and Critic Networks
class SimpleActor(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.network(x)

class SimpleCritic(nn.Module):
    def __init__(self, input_dim, action_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim + action_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 1)
        )
    
    def forward(self, x, a):
        return self.network(torch.cat([x, a], dim=-1))

# Define action_space with shape=(1,)
action_space = Box(low=np.array([-1.0]), high=np.array([1.0]), dtype=np.float32)
logger.info(f"Test action_space: {action_space}")

# Initialize networks
actor = SimpleActor(input_dim=10)
critic = SimpleCritic(input_dim=10, action_dim=1)

# Initialize optimizers
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)
critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

# Initialize SACPolicy
try:
    policy = SACPolicy(
        actor=actor,
        actor_optim=actor_optimizer,
        critic=critic,
        critic_optim=critic_optimizer,
        action_space=action_space,
        tau=0.005,
        gamma=0.99,
        exploration_noise=0.1,
        action_scaling=False  # Ensuring no redundant scaling
    )
    logger.info("SACPolicy initialized successfully in minimal example.")
except ValueError as ve:
    logger.error(f"SACPolicy initialization failed in minimal example: {ve}")
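
A possible workaround, sketched below. The assumption (not confirmed in this issue) is that the installed Tianshou version validates action spaces with isinstance checks against gymnasium.spaces types, so the legacy gym.spaces.Box imported above fails that check and triggers the "Unsupported action space" ValueError. Constructing the space with gymnasium instead may avoid it:

import numpy as np
# Assumption: Tianshou type-checks against gymnasium.spaces, not legacy gym.spaces
from gymnasium.spaces import Box

# Same single-dimensional action space, built from gymnasium
action_space = Box(low=np.array([-1.0]), high=np.array([1.0]), dtype=np.float32)

# Passing this space to the SACPolicy(...) call above should then satisfy the
# isinstance check, assuming that check is where the ValueError originates.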