I added a depth input to ppo_rgb.py and I was wondering if this is the correct/proper implementation for processing both RGB and depth (as well as state) information.
The current implementation uses NatureCNN with a 16-channel input to process the RGBD tensor:
```python
class NatureCNN(nn.Module):
    def __init__(self, sample_obs):
        super().__init__()
        print("Initializing NatureCNN with observation keys:", sample_obs.keys())
        extractors = {}
        self.out_features = 0
        feature_size = 256
        if "rgbd" in sample_obs:  # rgbd instead of rgb
            # CNN for RGBD input
            cnn = nn.Sequential(
                nn.Conv2d(
                    in_channels=16,  # Full RGBD channels
                    out_channels=32,
                    kernel_size=8,
                    stride=4,
                    padding=0,
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=0
                ),
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0
                ),
                nn.ReLU(),
                nn.Flatten(),
            )
            # RGBD dimension calculation
            with torch.no_grad():
                n_flatten = cnn(sample_obs["rgbd"].float().permute(0, 3, 1, 2).cpu()).shape[1]
                fc = nn.Sequential(nn.Linear(n_flatten, feature_size), nn.ReLU())
            extractors["rgbd"] = nn.Sequential(cnn, fc)
            self.out_features += feature_size
        if "state" in sample_obs:
            # for state data we simply pass it through a single linear layer
            state_size = sample_obs["state"].shape[-1]
            extractors["state"] = nn.Linear(state_size, 256)
            self.out_features += 256
        assert len(extractors) > 0, f"No valid observations found in {sample_obs.keys()}"
        self.extractors = nn.ModuleDict(extractors)

    def forward(self, observations) -> torch.Tensor:
        encoded_tensor_list = []
        # self.extractors contain nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            obs = observations[key]
            if key == "rgbd":  # rgbd instead of rgb
                obs = obs.float().permute(0, 3, 1, 2)
                obs = obs / 255.0  # normalize values
            encoded_tensor_list.append(extractor(obs))
        return torch.cat(encoded_tensor_list, dim=1)
```
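As a sanity check on the "RGBD dimension calculation" step, the flattened size can also be worked out by hand from the kernel sizes and strides. The sketch below assumes 128x128 images, which is only an illustrative guess at the camera resolution; `conv_out` and `nature_cnn_flatten_size` are hypothetical helper names, not part of ppo_rgb.py.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of one spatial dimension after a convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def nature_cnn_flatten_size(height: int, width: int, out_channels: int = 64) -> int:
    """Flattened feature count after the three conv layers used above."""
    # (kernel_size, stride) for each conv layer; padding is 0 everywhere.
    layers = [(8, 4), (4, 2), (3, 1)]
    for kernel, stride in layers:
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
    return out_channels * height * width


# Assumed 128x128 input: 128 -> 31 -> 14 -> 12 per side, so 64 * 12 * 12.
print(nature_cnn_flatten_size(128, 128))  # 9216
```

Note that the result does not depend on the number of input channels: going from 3 to 16 channels only changes the weight shape of the first convolution, not `n_flatten`.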
I also set `obs_mode` to "rgbd" in `env_kwargs`, and `rgb`, `depth`, and `state` are all set to `True` when flattening the observations.

Is this the correct approach for handling rgb + depth + state input?
Should RGB and depth be processed through separate CNN pathways instead of a single 16-channel input?
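To make the second question concrete, this is roughly what I mean by separate pathways. It is only a sketch, not something taken from ppo_rgb.py: `TwoBranchEncoder` and `make_cnn` are hypothetical names, and the split assumes the 16 channels are laid out as repeating (R, G, B, D) groups, one per camera, which I have not verified. How depth should be scaled is left open here; the sketch only divides the RGB channels by 255.

```python
import torch
import torch.nn as nn


def make_cnn(in_channels: int) -> nn.Sequential:
    """Same conv stack as in NatureCNN, parameterized by input channels."""
    return nn.Sequential(
        nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, stride=1),
        nn.ReLU(),
        nn.Flatten(),
    )


class TwoBranchEncoder(nn.Module):
    """Hypothetical alternative: one conv stack for RGB, one for depth."""

    def __init__(self, num_cameras: int, image_size: int, feature_size: int = 256):
        super().__init__()
        self.num_cameras = num_cameras
        self.rgb_cnn = make_cnn(3 * num_cameras)
        self.depth_cnn = make_cnn(num_cameras)
        # Infer the combined flattened size with a dummy forward pass.
        with torch.no_grad():
            dummy_rgb = torch.zeros(1, 3 * num_cameras, image_size, image_size)
            dummy_depth = torch.zeros(1, num_cameras, image_size, image_size)
            n_flat = self.rgb_cnn(dummy_rgb).shape[1] + self.depth_cnn(dummy_depth).shape[1]
        self.fc = nn.Sequential(nn.Linear(n_flat, feature_size), nn.ReLU())

    def forward(self, rgbd: torch.Tensor) -> torch.Tensor:
        # rgbd: (B, H, W, 4 * num_cameras); ASSUMED layout is [R, G, B, D] per camera.
        b, h, w, _ = rgbd.shape
        x = rgbd.float().permute(0, 3, 1, 2).reshape(b, self.num_cameras, 4, h, w)
        rgb = x[:, :, :3].reshape(b, 3 * self.num_cameras, h, w) / 255.0
        depth = x[:, :, 3:].reshape(b, self.num_cameras, h, w)  # scaling left open
        feats = torch.cat([self.rgb_cnn(rgb), self.depth_cnn(depth)], dim=1)
        return self.fc(feats)


# Usage with dummy data: 4 cameras, 128x128 images (both assumed values).
encoder = TwoBranchEncoder(num_cameras=4, image_size=128)
features = encoder(torch.zeros(2, 128, 128, 16))
print(features.shape)  # torch.Size([2, 256])
```

The output size matches the 256 features of the single 16-channel pathway, so it could be dropped into the same `extractors["rgbd"]` slot for comparison.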
(The full file can be viewed here)
Thank you!