
[Question/Clarification] Proper implementation for adding Depth state to ppo_rgb.py #794

Open
CreativeNick opened this issue Jan 7, 2025 · 0 comments


@CreativeNick (Contributor)

I added a depth input to ppo_rgb.py, and I was wondering whether this is the correct/proper way to process RGB and depth (as well as state) observations together.

The current implementation uses NatureCNN with a 16-channel input to process the RGBD tensor:

import torch
import torch.nn as nn

class NatureCNN(nn.Module):
    def __init__(self, sample_obs):
        super().__init__()
        print("Initializing NatureCNN with observation keys:", sample_obs.keys())

        extractors = {}

        self.out_features = 0
        feature_size = 256

        if "rgbd" in sample_obs:  # "rgbd" key instead of "rgb"
            # CNN for RGBD input
            cnn = nn.Sequential(
                nn.Conv2d(
                    in_channels=16,  # Full RGBD channels
                    out_channels=32,
                    kernel_size=8,
                    stride=4,
                    padding=0,
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=0
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0
                ),
                nn.ReLU(),
                nn.Flatten(),
            )

            # RGBD dimension calculation
            with torch.no_grad():
                n_flatten = cnn(sample_obs["rgbd"].float().permute(0,3,1,2).cpu()).shape[1]
                fc = nn.Sequential(nn.Linear(n_flatten, feature_size), nn.ReLU())
            extractors["rgbd"] = nn.Sequential(cnn, fc)
            self.out_features += feature_size

        if "state" in sample_obs:
            # for state data we simply pass it through a single linear layer
            state_size = sample_obs["state"].shape[-1]
            extractors["state"] = nn.Linear(state_size, 256)
            self.out_features += 256

        assert len(extractors) > 0, f"No valid observations found in {sample_obs.keys()}"
        self.extractors = nn.ModuleDict(extractors)

    def forward(self, observations) -> torch.Tensor:
        encoded_tensor_list = []
        # self.extractors contain nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            obs = observations[key]
            if key == "rgbd":  # "rgbd" key instead of "rgb"
                obs = obs.float().permute(0, 3, 1, 2)
                obs = obs / 255.0  # scales all channels by 255, including the depth channels
            encoded_tensor_list.append(extractor(obs))
        return torch.cat(encoded_tensor_list, dim=1)
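For reference, the flattened size that the torch.no_grad() probe above computes can also be worked out by hand from the conv hyperparameters. A small sketch (assuming a 128x128 input; the actual resolution depends on your camera config):

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output spatial size of a Conv2d along one dimension (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 128  # assumed input resolution; adjust to your camera settings
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:  # the three conv layers above
    size = conv_out(size, kernel, stride)

n_flatten = 64 * size * size  # 64 output channels from the last conv
print(size, n_flatten)  # 128 -> 31 -> 14 -> 12 spatial, so 64 * 12 * 12 = 9216
```

Note that the flattened size does not depend on the number of input channels, only on the spatial resolution, so the same arithmetic applies whether the input has 3, 4, or 16 channels.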

I also set obs_mode to "rgbd" in env_kwargs:

# env setup
env_kwargs = dict(
    obs_mode="rgbd",
    render_mode=args.render_mode,
    sim_backend="gpu"
)

rgb, depth, and state are all set to True when flattening:

envs = FlattenRGBDObservationWrapper(envs, rgb=True, depth=True, state=True)
eval_envs = FlattenRGBDObservationWrapper(eval_envs, rgb=True, depth=True, state=True)

Is this the correct approach for handling rgb + depth + state input?
Should RGB and depth be processed through separate CNN pathways instead of a single 16-channel input?
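In case it helps the discussion, here is a sketch of what the two-pathway variant might look like. This is hypothetical code, not from ppo_rgb.py, and it assumes a single camera with 3 RGB channels plus 1 depth channel (my 16-channel tensor would need different channel counts). Note it only scales the RGB pathway by 255, leaving depth in its raw units:

```python
import torch
import torch.nn as nn

def make_cnn(in_channels: int) -> nn.Sequential:
    # Same NatureCNN conv stack as above, parameterized by input channel count.
    return nn.Sequential(
        nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, stride=1),
        nn.ReLU(),
        nn.Flatten(),
    )

class TwoPathwayEncoder(nn.Module):
    """Hypothetical encoder: separate CNNs for RGB and depth, concatenated."""

    def __init__(self, rgb_channels=3, depth_channels=1, image_size=128, feature_size=256):
        super().__init__()
        self.rgb_channels = rgb_channels
        self.rgb_cnn = make_cnn(rgb_channels)
        self.depth_cnn = make_cnn(depth_channels)
        # Probe each pathway with a dummy input to size the linear heads.
        with torch.no_grad():
            n_rgb = self.rgb_cnn(torch.zeros(1, rgb_channels, image_size, image_size)).shape[1]
            n_depth = self.depth_cnn(torch.zeros(1, depth_channels, image_size, image_size)).shape[1]
        self.rgb_fc = nn.Sequential(nn.Linear(n_rgb, feature_size), nn.ReLU())
        self.depth_fc = nn.Sequential(nn.Linear(n_depth, feature_size), nn.ReLU())
        self.out_features = 2 * feature_size

    def forward(self, rgbd: torch.Tensor) -> torch.Tensor:
        # rgbd: (B, H, W, C); first rgb_channels are RGB in [0, 255], the rest depth.
        x = rgbd.float().permute(0, 3, 1, 2)
        rgb = x[:, : self.rgb_channels] / 255.0  # scale only RGB to [0, 1]
        depth = x[:, self.rgb_channels :]        # depth kept in raw units
        return torch.cat(
            [self.rgb_fc(self.rgb_cnn(rgb)), self.depth_fc(self.depth_cnn(depth))],
            dim=1,
        )
```

With the defaults, an input batch of shape (B, 128, 128, 4) produces a (B, 512) feature tensor (256 per pathway), which could then be concatenated with the state features as in the current NatureCNN.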

(The full file can be viewed here)

Thank you!
