diff --git a/sf_examples/nethack/models/__init__.py b/sf_examples/nethack/models/__init__.py
index c8907f9b4..3c1302e75 100644
--- a/sf_examples/nethack/models/__init__.py
+++ b/sf_examples/nethack/models/__init__.py
@@ -1,11 +1,4 @@
-from sf_examples.nethack.models.chaotic_dwarf import ChaoticDwarvenGPT5
-from sf_examples.nethack.models.scaled import ScaledNet
-from sf_examples.nethack.models.simba import SimbaActorEncoder, SimbaCriticEncoder
+from sf_examples.nethack.models.vit import ViTActorEncoder, ViTCriticEncoder
 
-MODELS = [
-    ChaoticDwarvenGPT5,
-    ScaledNet,
-    SimbaActorEncoder,
-    SimbaCriticEncoder,
-]
+MODELS = [ViTActorEncoder, ViTCriticEncoder]
 MODELS_LOOKUP = {c.__name__: c for c in MODELS}
diff --git a/sf_examples/nethack/models/chaotic_dwarf.py b/sf_examples/nethack/models/chaotic_dwarf.py
deleted file mode 100644
index 8a3f2d19c..000000000
--- a/sf_examples/nethack/models/chaotic_dwarf.py
+++ /dev/null
@@ -1,346 +0,0 @@
-"""Adapted from Chaos Dwarf in Nethack Challenge Starter Kit:
-https://github.com/Miffyli/nle-sample-factory-baseline
-
-MIT License
-
-Copyright (c) 2021 Anssi
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-""" - -import torch -from nle import nethack -from torch import nn -from torch.nn import functional as F - -from sample_factory.algo.utils.torch_utils import calc_num_elements -from sample_factory.model.encoder import Encoder -from sample_factory.utils.typing import Config, ObsSpace - - -class MessageEncoder(nn.Module): - def __init__(self): - super(MessageEncoder, self).__init__() - self.hidden_dim = 128 - self.msg_fwd = nn.Sequential( - nn.Linear(nethack.MESSAGE_SHAPE[0], 128), - nn.ELU(inplace=True), - nn.Linear(128, self.hidden_dim), - nn.ELU(inplace=True), - ) - - def forward(self, message): - return self.msg_fwd(message / 255.0) - - -class BLStatsEncoder(nn.Module): - def __init__(self): - super(BLStatsEncoder, self).__init__() - self.hidden_dim = 128 + nethack.BLSTATS_SHAPE[0] - self.blstats_fwd = nn.Sequential( - nn.Linear(nethack.BLSTATS_SHAPE[0], 128), - nn.ELU(inplace=True), - nn.Linear(128, 128), - nn.ELU(inplace=True), - ) - - normalization_stats = torch.tensor( - [ - 1.0 / 79.0, # hero col - 1.0 / 21, # hero row - 0.0, # strength pct - 1.0 / 10, # strength - 1.0 / 10, # dexterity - 1.0 / 10, # constitution - 1.0 / 10, # intelligence - 1.0 / 10, # wisdom - 1.0 / 10, # charisma - 0.0, # score - 1.0 / 10, # hitpoints - 1.0 / 10, # max hitpoints - 0.0, # depth - 1.0 / 1000, # gold - 1.0 / 10, # energy - 1.0 / 10, # max energy - 1.0 / 10, # armor class - 0.0, # monster level - 1.0 / 10, # experience level - 1.0 / 100, # experience points - 1.0 / 1000, # time - 1.0, # hunger_state - 1.0 / 10, # carrying capacity - 0.0, # carrying capacity - 0.0, # level number - 0.0, # condition bits - 0.0, # alignment bits - ], - requires_grad=False, - ) - self.register_buffer("normalization_stats", normalization_stats) - - self.blstat_range = (-5, 5) - - def forward(self, blstats): - norm_bls = torch.clip( - blstats * self.normalization_stats, - self.blstat_range[0], - self.blstat_range[1], - ) - - return torch.cat([self.blstats_fwd(norm_bls), norm_bls], dim=-1) - - -class TopLineEncoder(nn.Module): - def __init__(self): - super(TopLineEncoder, self).__init__() - self.hidden_dim = 128 - self.i_dim = nethack.NLE_TERM_CO * 256 - - self.msg_fwd = nn.Sequential( - nn.Linear(self.i_dim, 128), - nn.ELU(inplace=True), - nn.Linear(128, self.hidden_dim), - nn.ELU(inplace=True), - ) - - def forward(self, message): - # Characters start at 33 in ASCII and go to 128. 
96 = 128 - 32 - message_normed = F.one_hot((message).long(), 256).reshape(-1, self.i_dim).float() - return self.msg_fwd(message_normed) - - -class BottomLinesEncoder(nn.Module): - def __init__(self): - super(BottomLinesEncoder, self).__init__() - self.conv_layers = [] - w = nethack.NLE_TERM_CO * 2 - for in_ch, out_ch, filter, stride in [[2, 32, 8, 4], [32, 64, 4, 1]]: - self.conv_layers.append(nn.Conv1d(in_ch, out_ch, filter, stride=stride)) - self.conv_layers.append(nn.ELU(inplace=True)) - w = conv_outdim(w, filter, padding=0, stride=stride) - - self.conv_net = nn.Sequential(*self.conv_layers) - self.fwd_net = nn.Sequential( - nn.Linear(w * out_ch, 128), - nn.ELU(), - nn.Linear(128, 128), - nn.ELU(), - ) - self.hidden_dim = 128 - - def forward(self, bottom_lines): - B, D = bottom_lines.shape - # ASCII 32: ' ', ASCII [33-128]: visible characters - chars_normalised = (bottom_lines - 32) / 96 - - # ASCII [45-57]: -./01234556789 - numbers_mask = (bottom_lines > 44) * (bottom_lines < 58) - digits_normalised = numbers_mask * (bottom_lines - 47) / 10 - - # Put in different channels & conv (B, 2, D) - x = torch.stack([chars_normalised, digits_normalised], dim=1) - return self.fwd_net(self.conv_net(x).view(B, -1)) - - -def conv_outdim(i_dim, k, padding=0, stride=1, dilation=1): - """Return the dimension after applying a convolution along one axis""" - return int(1 + (i_dim + 2 * padding - dilation * (k - 1) - 1) / stride) - - -class InverseModel(nn.Module): - def __init__(self, h_dim, action_space): - super(InverseModel, self).__init__() - self.h_dim = h_dim * 2 - self.action_space = action_space - - self.fwd_model = nn.Sequential( - nn.Linear(self.h_dim, 128), - nn.ELU(inplace=True), - nn.Linear(128, 128), - nn.ELU(inplace=True), - nn.Linear(128, action_space), - ) - - def forward(self, obs): - T, B, *_ = obs.shape - x = torch.cat([obs[:-1], obs[1:]], dim=-1) - pred_a = self.fwd_model(x) - off_by_one = torch.ones((1, B, self.action_space), device=x.device) * -1 - return torch.cat([pred_a, off_by_one], dim=0) - - -class CharColorEncoder(nn.Module): - def __init__( - self, - screen_shape, - char_edim: int = 16, - color_edim: int = 16, - ): - super().__init__() - conv_layers = [] - - self.h, self.w = screen_shape - self.char_edim = char_edim - self.color_edim = color_edim - self.hidden_dim = 512 - - self.conv_filters = [ - [char_edim + color_edim, 32, (3, 5), (1, 2), (1, 2)], - [32, 64, (3, 5), (1, 2), (1, 2)], - [64, 128, 3, 1, 1], - [128, 128, 3, 1, 1], - ] - - for ( - in_channels, - out_channels, - filter_size, - stride, - dilation, - ) in self.conv_filters: - conv_layers.append( - nn.Conv2d( - in_channels, - out_channels, - filter_size, - stride=stride, - dilation=dilation, - ) - ) - conv_layers.append(nn.ELU(inplace=True)) - - self.conv_head = nn.Sequential(*conv_layers) - self.out_size = calc_num_elements(self.conv_head, (char_edim + color_edim,) + screen_shape) - - self.fc_head = nn.Sequential(nn.Linear(self.out_size, self.hidden_dim), nn.ELU(inplace=True)) - - self.char_embeddings = nn.Embedding(256, self.char_edim) - self.color_embeddings = nn.Embedding(128, self.color_edim) - - def forward(self, chars, colors): - chars, colors = self._embed(chars, colors) # 21 x 80 - x = self._stack(chars, colors) - x = self.conv_head(x) - x = x.view(-1, self.out_size) - x = self.fc_head(x) - return x - - def _embed(self, chars, colors): - chars = selectt(self.char_embeddings, chars.long(), True) - colors = selectt(self.color_embeddings, colors.long(), True) - return chars, colors - - def _stack(self, 
chars, colors): - obs = torch.cat([chars, colors], dim=-1) - return obs.permute(0, 1, 4, 2, 3).flatten(1, 2).contiguous() - - -class ChaoticDwarvenGPT5(Encoder): - def __init__(self, cfg: Config, obs_space: ObsSpace): - super().__init__(cfg) - self.obs_keys = list(sorted(obs_space.keys())) # always the same order - self.encoders = nn.ModuleDict() - - self.use_tty_only = cfg.use_tty_only - self.use_prev_action = cfg.use_prev_action - - screen_shape = obs_space["tty_chars"].shape - self.screen_encoder = CharColorEncoder( - (screen_shape[0] - 3, screen_shape[1]), - char_edim=cfg.char_edim, - color_edim=cfg.color_edim, - ) - - # top and bottom encoders - self.topline_encoder = TopLineEncoder() - self.bottomline_encoder = torch.jit.script(BottomLinesEncoder()) - - if self.use_prev_action: - self.num_actions = obs_space["prev_actions"].n - self.prev_actions_dim = self.num_actions - else: - self.num_actions = None - self.prev_actions_dim = 0 - - self.encoder_out_size = sum( - [ - self.topline_encoder.hidden_dim, - self.bottomline_encoder.hidden_dim, - self.screen_encoder.hidden_dim, - self.prev_actions_dim, - ] - ) - - def get_out_size(self) -> int: - return self.encoder_out_size - - def forward(self, obs_dict): - B, H, W = obs_dict["tty_chars"].shape - # to process images with CNNs we need channels dim - C = 1 - - # Take last channel for now - topline = obs_dict["tty_chars"][:, 0].contiguous() - bottom_line = obs_dict["tty_chars"][:, -2:].contiguous() - - # Blstats - blstats_rep = self.bottomline_encoder(bottom_line.float(memory_format=torch.contiguous_format).view(B, -1)) - - encodings = [ - self.topline_encoder(topline.float(memory_format=torch.contiguous_format).view(B, -1)), - blstats_rep, - ] - - # Main obs encoding - tty_chars = ( - obs_dict["tty_chars"][:, 1:-2] - .contiguous() - .float(memory_format=torch.contiguous_format) - .view(B, C, H - 3, W) - ) - tty_colors = obs_dict["tty_colors"][:, 1:-2].contiguous().view(B, C, H - 3, W) - encodings.append(self.screen_encoder(tty_chars, tty_colors)) - - if self.use_prev_action: - prev_actions = obs_dict["prev_actions"].long().view(B) - encodings.append(torch.nn.functional.one_hot(prev_actions, self.num_actions)) - - return torch.cat(encodings, dim=1) - - -def selectt(embedding_layer, x, use_index_select): - """Use index select instead of default forward to possible speed up embedding.""" - if use_index_select: - # Access weight through the embedding layer - return nn.functional.embedding(x, embedding_layer.weight) - else: - # Use standard embedding forward - return embedding_layer(x) - - -if __name__ == "__main__": - # Test the screen encoder - encoder = CharColorEncoder( - (21 - 3, 80), - char_edim=16, - color_edim=16, - ) - tty_chars = torch.zeros(160, 1, 21, 80) - tty_colors = torch.zeros(160, 1, 21, 80) - print(encoder(tty_chars, tty_colors).shape) diff --git a/sf_examples/nethack/models/crop.py b/sf_examples/nethack/models/crop.py deleted file mode 100644 index 2a4a8d80f..000000000 --- a/sf_examples/nethack/models/crop.py +++ /dev/null @@ -1,66 +0,0 @@ -import logging - -try: - import torch - from torch import nn - from torch.nn import functional as F -except ImportError: - logging.exception("PyTorch not found. 
Please install the agent dependencies with " '`pip install "nle[agent]"`') - - -def _step_to_range(delta, num_steps): - """Range of `num_steps` integers with distance `delta` centered around zero.""" - return delta * torch.arange(-num_steps // 2, num_steps // 2) - - -class Crop(nn.Module): - """Helper class for NetHackNet below.""" - - def __init__(self, height, width, height_target, width_target): - super(Crop, self).__init__() - self.width = width - self.height = height - self.width_target = width_target - self.height_target = height_target - width_grid = _step_to_range(2 / (self.width - 1), self.width_target)[None, :].expand(self.height_target, -1) - height_grid = _step_to_range(2 / (self.height - 1), height_target)[:, None].expand(-1, self.width_target) - - # "clone" necessary, https://github.com/pytorch/pytorch/issues/34880 - self.register_buffer("width_grid", width_grid.clone()) - self.register_buffer("height_grid", height_grid.clone()) - - def forward(self, inputs, coordinates): - """Calculates centered crop around given x,y coordinates. - Args: - inputs [B x H x W] - coordinates [B x 2] x,y coordinates - Returns: - [B x H' x W'] inputs cropped and centered around x,y coordinates. - """ - assert inputs.shape[1] == self.height - assert inputs.shape[2] == self.width - - inputs = inputs[:, None, :, :].float() - - x = coordinates[:, 0] - y = coordinates[:, 1] - - # NOTE: Need to do -self.width/2 + 1/2 here to cancel things out correctly - # with the width_grid below for both even and odd input dimensions. - x_shift = 2 / (self.width - 1) * (x.float() - self.width / 2 + 1 / 2) - y_shift = 2 / (self.height - 1) * (y.float() - self.height / 2 + 1 / 2) - - grid = torch.stack( - [ - self.width_grid[None, :, :] + x_shift[:, None, None], - self.height_grid[None, :, :] + y_shift[:, None, None], - ], - dim=3, - ) - - # NOTE: Location x, y in grid tells you the shift from the cursor - # coordinates. The reason we do all this 2/(self.width - 1) stuff is because - # the inverse of this happens in the below F.grid_sample function. 
The F.grid_sample - # implementation is here: https://github.com/pytorch/pytorch/blob/f064c5aa33483061a48994608d890b968ae53fb5/aten/src/THNN/generic/SpatialGridSamplerBilinear.c#L41 - - return torch.round(F.grid_sample(inputs, grid, align_corners=True)).squeeze(1).long() diff --git a/sf_examples/nethack/models/scaled.py b/sf_examples/nethack/models/scaled.py deleted file mode 100644 index 4ae2cce63..000000000 --- a/sf_examples/nethack/models/scaled.py +++ /dev/null @@ -1,382 +0,0 @@ -"""Adapted from Scaling Laws for Imitation Learning in NetHack: -https://arxiv.org/abs/2307.09423 - -Credit to Jens Tuyls -""" - -import math -from typing import List - -import torch -from nle import nethack # noqa: E402 -from nle.nethack.nethack import TERMINAL_SHAPE -from torch import nn -from torch.nn import functional as F - -from sample_factory.model.encoder import Encoder -from sample_factory.utils.typing import Config, ObsSpace -from sf_examples.nethack.models.crop import Crop -from sf_examples.nethack.models.utils import interleave - -PAD_CHAR = 0 -NUM_CHARS = 256 - - -class ScaledNet(Encoder): - def __init__(self, cfg: Config, obs_space: ObsSpace): - super().__init__(cfg) - - self.obs_keys = list(sorted(obs_space.keys())) # always the same order - self.encoders = nn.ModuleDict() - - self.use_prev_action = cfg.use_prev_action - self.msg_hdim = cfg.msg_hdim - self.h_dim = cfg.h_dim - self.il_mode = False - self.scale_cnn_channels = 1 - self.num_lstm_layers = 1 - self.num_fc_layers = 2 - self.num_screen_fc_layers = 1 - self.color_edim = cfg.color_edim - self.char_edim = cfg.char_edim - self.crop_dim = 9 - self.crop_out_filters = 8 - self.crop_num_layers = 5 - self.crop_inter_filters = 16 - self.crop_padding = 1 - self.crop_kernel_size = 3 - self.crop_stride = 1 - self.use_crop = cfg.use_crop - self.use_resnet = cfg.use_resnet - self.use_crop_norm = cfg.use_crop_norm - self.action_embedding_dim = 32 - self.obs_frame_stack = 1 - self.num_res_blocks = 2 - self.num_res_layers = 2 - self.screen_shape = TERMINAL_SHAPE - self.screen_kernel_size = cfg.screen_kernel_size - self.no_max_pool = cfg.no_max_pool - self.screen_conv_blocks = cfg.screen_conv_blocks - self.blstats_hdim = cfg.blstats_hdim if cfg.blstats_hdim else cfg.h_dim - self.fc_after_cnn_hdim = cfg.fc_after_cnn_hdim if cfg.fc_after_cnn_hdim else cfg.h_dim - - # NOTE: -3 because we cut the topline and bottom two lines - if self.use_crop: - self.crop = Crop(self.screen_shape[0] - 3, self.screen_shape[1], self.crop_dim, self.crop_dim) - crop_in_channels = [self.char_edim + self.color_edim] + [self.crop_inter_filters] * ( - self.crop_num_layers - 1 - ) - crop_out_channels = [self.crop_inter_filters] * (self.crop_num_layers - 1) + [self.crop_out_filters] - conv_extract_crop = [] - norm_extract_crop = [] - for i in range(self.crop_num_layers): - conv_extract_crop.append( - nn.Conv2d( - in_channels=crop_in_channels[i], - out_channels=crop_out_channels[i], - kernel_size=(self.crop_kernel_size, self.crop_kernel_size), - stride=self.crop_stride, - padding=self.crop_padding, - ) - ) - norm_extract_crop.append(nn.BatchNorm2d(crop_out_channels[i])) - - if self.use_crop_norm: - self.extract_crop_representation = nn.Sequential( - *interleave(conv_extract_crop, norm_extract_crop, [nn.ELU()] * len(conv_extract_crop)) - ) - else: - self.extract_crop_representation = nn.Sequential( - *interleave(conv_extract_crop, [nn.ELU()] * len(conv_extract_crop)) - ) - self.crop_out_dim = self.crop_dim**2 * self.crop_out_filters - else: - self.crop_out_dim = 0 - - 
self.topline_encoder = TopLineEncoder(msg_hdim=self.msg_hdim) - self.bottomline_encoder = BottomLinesEncoder(h_dim=self.blstats_hdim // 4) - - self.screen_encoder = CharColorEncoderResnet( - (self.screen_shape[0] - 3, self.screen_shape[1]), - h_dim=self.fc_after_cnn_hdim, - num_fc_layers=self.num_screen_fc_layers, - scale_cnn_channels=self.scale_cnn_channels, - color_edim=self.color_edim, - char_edim=self.char_edim, - obs_frame_stack=self.obs_frame_stack, - num_res_blocks=self.num_res_blocks, - num_res_layers=self.num_res_layers, - kernel_size=self.screen_kernel_size, - no_max_pool=self.no_max_pool, - screen_conv_blocks=self.screen_conv_blocks, - ) - - if self.use_prev_action: - self.num_actions = obs_space["prev_actions"].n - self.prev_actions_dim = self.num_actions - else: - self.num_actions = None - self.prev_actions_dim = 0 - - self.out_dim = sum( - [ - self.topline_encoder.msg_hdim, - self.bottomline_encoder.h_dim, - self.screen_encoder.h_dim, - self.prev_actions_dim, - self.crop_out_dim, - ] - ) - - fc_layers = [nn.Linear(self.out_dim, self.h_dim), nn.ReLU()] - for _ in range(self.num_fc_layers - 1): - fc_layers.append(nn.Linear(self.h_dim, self.h_dim)) - fc_layers.append(nn.ReLU()) - self.fc = nn.Sequential(*fc_layers) - - self.encoder_out_size = self.h_dim - - def get_out_size(self) -> int: - return self.encoder_out_size - - def _select(self, embed, x): - # Work around slow backward pass of nn.Embedding, see - # https://github.com/pytorch/pytorch/issues/24912 - out = embed.weight.index_select(0, x.reshape(-1)) - return out.reshape(x.shape + (-1,)) - - def forward(self, obs_dict): - B, H, W = obs_dict["tty_chars"].shape - # to process images with CNNs we need channels dim - C = 1 - - # Take last channel for now - topline = obs_dict["tty_chars"][:, 0].contiguous() - bottom_line = obs_dict["tty_chars"][:, -2:].contiguous() - - # Blstats - blstats_rep = self.bottomline_encoder(bottom_line.float(memory_format=torch.contiguous_format).view(B, -1)) - - encodings = [ - self.topline_encoder(topline.float(memory_format=torch.contiguous_format).view(B, -1)), - blstats_rep, - ] - - # Main obs encoding - tty_chars = ( - obs_dict["tty_chars"][:, 1:-2] - .contiguous() - .float(memory_format=torch.contiguous_format) - .view(B, C, H - 3, W) - ) - tty_colors = obs_dict["tty_colors"][:, 1:-2].contiguous().view(B, C, H - 3, W) - tty_cursor = obs_dict["tty_cursor"].contiguous().view(B, -1) - encodings.append(self.screen_encoder(tty_chars, tty_colors)) - - # Previous action encoding - if self.use_prev_action: - encodings.append(torch.nn.functional.one_hot(obs_dict["prev_actions"].long(), self.num_actions).view(B, -1)) - - # Crop encoding - if self.use_crop: - # very important! 
otherwise we'll mess with tty_cursor below - # uint8 is needed for -1 operation to work properly 0 -> 255 - tty_cursor = tty_cursor.clone().to(torch.uint8) - tty_cursor[:, 0] -= 1 # adjust y position for cropping below - tty_cursor = tty_cursor.flip(-1) # flip (y, x) to be (x, y) - crop_tty_chars = self.crop(tty_chars[..., -1, :, :], tty_cursor) - crop_tty_colors = self.crop(tty_colors[..., -1, :, :], tty_cursor) - crop_chars = selectt(self.screen_encoder.char_embeddings, crop_tty_chars.long(), True) - crop_colors = selectt(self.screen_encoder.color_embeddings, crop_tty_colors.long(), True) - crop_obs = torch.cat([crop_chars, crop_colors], dim=-1) - encodings.append(self.extract_crop_representation(crop_obs.permute(0, 3, 1, 2).contiguous()).view(B, -1)) - - encodings = self.fc(torch.cat(encodings, dim=1)) - - return encodings - - -class CharColorEncoderResnet(nn.Module): - """ - Inspired by network from IMPALA https://arxiv.org/pdf/1802.01561.pdf - """ - - def __init__( - self, - screen_shape, - h_dim: int = 512, - scale_cnn_channels: int = 1, - num_fc_layers: int = 1, - char_edim: int = 16, - color_edim: int = 16, - obs_frame_stack: int = 1, - num_res_blocks: int = 2, - num_res_layers: int = 2, - kernel_size: int = 3, - no_max_pool: bool = False, - screen_conv_blocks: int = 3, - ): - super(CharColorEncoderResnet, self).__init__() - - self.h, self.w = screen_shape - self.h_dim = h_dim - self.num_fc_layers = num_fc_layers - self.char_edim = char_edim - self.color_edim = color_edim - self.no_max_pool = no_max_pool - self.screen_conv_blocks = screen_conv_blocks - - self.blocks = [] - - self.conv_params = [ - [ - char_edim * obs_frame_stack + color_edim * obs_frame_stack, - int(16 * scale_cnn_channels), - kernel_size, - num_res_blocks, - ], - [int(16 * scale_cnn_channels), int(32 * scale_cnn_channels), kernel_size, num_res_blocks], - [int(32 * scale_cnn_channels), int(32 * scale_cnn_channels), kernel_size, num_res_blocks], - ] - - self.conv_params = self.conv_params[: self.screen_conv_blocks] - - for in_channels, out_channels, filter_size, num_res_blocks in self.conv_params: - block = [] - # Downsample - block.append(nn.Conv2d(in_channels, out_channels, filter_size, stride=1, padding=(filter_size // 2))) - if not self.no_max_pool: - block.append(nn.MaxPool2d(kernel_size=3, stride=2)) - self.h = math.floor((self.h - 1 * (3 - 1) - 1) / 2 + 1) # from PyTorch Docs - self.w = math.floor((self.w - 1 * (3 - 1) - 1) / 2 + 1) # from PyTorch Docs - - # Residual block(s) - for _ in range(num_res_blocks): - block.append(ResBlock(out_channels, out_channels, filter_size, num_res_layers)) - self.blocks.append(nn.Sequential(*block)) - - self.conv_net = nn.Sequential(*self.blocks) - self.out_size = self.h * self.w * out_channels - - fc_layers = [nn.Linear(self.out_size, self.h_dim), nn.ELU(inplace=True)] - for _ in range(self.num_fc_layers - 1): - fc_layers.append(nn.Linear(self.h_dim, self.h_dim)) - fc_layers.append(nn.ELU(inplace=True)) - self.fc_head = nn.Sequential(*fc_layers) - - self.char_embeddings = nn.Embedding(256, self.char_edim) - self.color_embeddings = nn.Embedding(128, self.color_edim) - - def forward(self, chars, colors): - chars, colors = self._embed(chars, colors) # 21 x 80 - x = self._stack(chars, colors) - x = self.conv_net(x) - x = x.view(-1, self.out_size) - x = self.fc_head(x) - return x - - def _embed(self, chars, colors): - chars = selectt(self.char_embeddings, chars.long(), True) - colors = selectt(self.color_embeddings, colors.long(), True) - return chars, colors - - def _stack(self, 
chars, colors): - obs = torch.cat([chars, colors], dim=-1) - return obs.permute(0, 1, 4, 2, 3).flatten(1, 2).contiguous() - - -class ResBlock(nn.Module): - def __init__(self, in_channels: int, out_channels: int, filter_size: int, num_layers: int): - super(ResBlock, self).__init__() - layers = [] - for _ in range(num_layers): - layers.append(nn.Conv2d(in_channels, out_channels, filter_size, stride=1, padding=(filter_size // 2))) - layers.append(nn.BatchNorm2d(out_channels)) - layers.append(nn.ELU(inplace=True)) - - self.net = nn.Sequential(*layers) - - def forward(self, x): - return self.net(x) + x - - -class BottomLinesEncoder(nn.Module): - """ - Adapted from https://github.com/dungeonsdatasubmission/dungeonsdata-neurips2022/blob/67139262966aa11555cf7aca15723375b36fbe42/experiment_code/hackrl/models/offline_chaotic_dwarf.py - """ - - def __init__(self, h_dim: int = 128, scale_cnn_channels: int = 1): - super(BottomLinesEncoder, self).__init__() - self.conv_layers = [] - w = nethack.NLE_TERM_CO * 2 - for in_ch, out_ch, filter, stride in [ - [2, int(32 * scale_cnn_channels), 8, 4], - [int(32 * scale_cnn_channels), int(64 * scale_cnn_channels), 4, 1], - ]: - self.conv_layers.append(nn.Conv1d(in_ch, out_ch, filter, stride=stride)) - self.conv_layers.append(nn.ELU(inplace=True)) - w = conv_outdim(w, filter, padding=0, stride=stride) - - self.h_dim = h_dim - - self.out_dim = w * out_ch - self.conv_net = nn.Sequential(*self.conv_layers) - self.fwd_net = nn.Sequential( - nn.Linear(self.out_dim, self.h_dim), - nn.ELU(), - nn.Linear(self.h_dim, self.h_dim), - nn.ELU(), - ) - - def forward(self, bottom_lines): - B, D = bottom_lines.shape - # ASCII 32: ' ', ASCII [33-128]: visible characters - chars_normalised = (bottom_lines - 32) / 96 - - # ASCII [45-57]: -./01234556789 - numbers_mask = (bottom_lines > 44) * (bottom_lines < 58) - digits_normalised = numbers_mask * (bottom_lines - 47) / 10 # why subtract 47 here and not 48? - - # Put in different channels & conv (B, 2, D) - x = torch.stack([chars_normalised, digits_normalised], dim=1) - return self.fwd_net(self.conv_net(x).view(B, -1)) - - -class TopLineEncoder(nn.Module): - """ - This class uses a one-hot encoding of the ASCII characters - as features that get fed into an MLP. - Adapted from https://github.com/dungeonsdatasubmission/dungeonsdata-neurips2022/blob/67139262966aa11555cf7aca15723375b36fbe42/experiment_code/hackrl/models/offline_chaotic_dwarf.py - """ - - def __init__(self, msg_hdim: int): - super(TopLineEncoder, self).__init__() - self.msg_hdim = msg_hdim - self.i_dim = nethack.NLE_TERM_CO * 256 - - self.msg_fwd = nn.Sequential( - nn.Linear(self.i_dim, self.msg_hdim), - nn.ELU(inplace=True), - nn.Linear(self.msg_hdim, self.msg_hdim), - nn.ELU(inplace=True), - ) - - def forward(self, message): - # Characters start at 33 in ASCII and go to 128. 
96 = 128 - 32 - message_normed = F.one_hot((message).long(), 256).reshape(-1, self.i_dim).float() - return self.msg_fwd(message_normed) - - -def conv_outdim(i_dim, k, padding=0, stride=1, dilation=1): - """Return the dimension after applying a convolution along one axis""" - return int(1 + (i_dim + 2 * padding - dilation * (k - 1) - 1) / stride) - - -def selectt(embedding_layer, x, use_index_select): - """Use index select instead of default forward to possible speed up embedding.""" - if use_index_select: - out = embedding_layer.weight.index_select(0, x.view(-1)) - # handle reshaping x to 1-d and output back to N-d - return out.view(x.shape + (-1,)) - else: - return embedding_layer(x) diff --git a/sf_examples/nethack/models/simba.py b/sf_examples/nethack/models/simba.py deleted file mode 100644 index 5d25c8747..000000000 --- a/sf_examples/nethack/models/simba.py +++ /dev/null @@ -1,167 +0,0 @@ -import torch -import torch.nn as nn - -from sample_factory.algo.utils.torch_utils import calc_num_elements -from sample_factory.model.encoder import Encoder -from sample_factory.utils.typing import Config, ObsSpace - - -class ResBlock(nn.Module): - def __init__(self, hidden_dim, kernel_size=3, padding=1, stride=1): - super().__init__() - self.norm1 = nn.BatchNorm2d(hidden_dim) - self.linear1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=kernel_size, padding=padding, stride=stride) - self.relu = nn.ReLU() - self.linear2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=kernel_size, padding=padding, stride=stride) - - def forward(self, x): - residual = x - x = self.norm1(x) - x = self.linear1(x) - x = self.relu(x) - x = self.linear2(x) - return x + residual - - -class ResNet(nn.Module): - def __init__( - self, - input_dim, - hidden_dim, - num_blocks, - kernel_size: int = 3, - padding: int = 1, - stride: int = 1, - ): - super().__init__() - self.input_layer = nn.Conv2d(input_dim, hidden_dim, kernel_size=kernel_size, padding=padding, stride=stride) - self.res_blocks = nn.ModuleList( - [ResBlock(hidden_dim, kernel_size=kernel_size, padding=padding, stride=stride) for _ in range(num_blocks)] - ) - self.final_norm = nn.BatchNorm2d(hidden_dim) - - def forward(self, x): - x = self.input_layer(x) - - for res_block in self.res_blocks: - x = res_block(x) - - x = self.final_norm(x) - return x - - -class SimbaEncoder(nn.Module): - def __init__( - self, - obs_space, - *, - char_edim, - color_edim, - hidden_dim, - num_blocks, - kernel_size=3, - padding=1, - ): - super().__init__() - - self.char_embeddings = nn.Embedding(256, char_edim) - self.color_embeddings = nn.Embedding(128, color_edim) - self.resnet = ResNet( - char_edim + color_edim, - hidden_dim=hidden_dim, - num_blocks=num_blocks, - kernel_size=kernel_size, - padding=padding, - ) - - screen_shape = obs_space["tty_chars"].shape - self.out_size = calc_num_elements(self.resnet, (char_edim + color_edim,) + screen_shape) - - self.fc_head = nn.Sequential( - nn.Linear(self.out_size, hidden_dim), nn.ReLU(inplace=True), nn.LayerNorm(hidden_dim) - ) - - def forward(self, obs): - chars = obs["tty_chars"] - colors = obs["tty_colors"] - chars, colors = self._embed(chars, colors) - x = self._stack(chars, colors) - x = self.resnet(x) - x = x.view(-1, self.out_size) - x = self.fc_head(x) - return x - - def _embed(self, chars, colors): - chars = selectt(self.char_embeddings, chars.long(), True) - colors = selectt(self.color_embeddings, colors.long(), True) - return chars, colors - - def _stack(self, chars, colors): - obs = torch.cat([chars, colors], dim=-1) - return 
obs.permute(0, 3, 1, 2).contiguous()
-
-
-def selectt(embedding_layer, x, use_index_select):
-    """Use index select instead of default forward to possible speed up embedding."""
-    if use_index_select:
-        # Access weight through the embedding layer
-        return nn.functional.embedding(x, embedding_layer.weight)
-    else:
-        # Use standard embedding forward
-        return embedding_layer(x)
-
-
-class SimbaActorEncoder(Encoder):
-    def __init__(self, cfg: Config, obs_space: ObsSpace):
-        super().__init__(cfg)
-
-        self.model = SimbaEncoder(
-            obs_space=obs_space,
-            char_edim=self.cfg.actor_char_edim,
-            color_edim=self.cfg.actor_color_edim,
-            hidden_dim=self.cfg.actor_hidden_dim,
-            num_blocks=self.cfg.actor_num_blocks,
-        )
-
-    def forward(self, x):
-        return self.model(x)
-
-    def get_out_size(self):
-        return self.cfg.actor_hidden_dim
-
-
-class SimbaCriticEncoder(Encoder):
-    def __init__(self, cfg: Config, obs_space: ObsSpace):
-        super().__init__(cfg)
-
-        self.model = SimbaEncoder(
-            obs_space=obs_space,
-            char_edim=self.cfg.critic_char_edim,
-            color_edim=self.cfg.critic_color_edim,
-            hidden_dim=self.cfg.critic_hidden_dim,
-            num_blocks=self.cfg.critic_num_blocks,
-        )
-
-    def forward(self, x):
-        return self.model(x)
-
-    def get_out_size(self):
-        return self.cfg.critic_hidden_dim
-
-
-if __name__ == "__main__":
-    from sample_factory.algo.utils.env_info import extract_env_info
-    from sample_factory.algo.utils.make_env import make_env_func_batched
-    from sample_factory.utils.attr_dict import AttrDict
-    from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components
-
-    register_nethack_components()
-    cfg = parse_nethack_args(argv=["--env=nethack_score"])
-
-    env = make_env_func_batched(cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0))
-    env_info = extract_env_info(env, cfg)
-
-    obs, info = env.reset()
-    encoder = SimbaCriticEncoder(cfg, env_info.obs_space)
-    x = encoder(obs)
-    print(x.shape)
diff --git a/sf_examples/nethack/models/utils.py b/sf_examples/nethack/models/utils.py
deleted file mode 100644
index d30a8bba8..000000000
--- a/sf_examples/nethack/models/utils.py
+++ /dev/null
@@ -1,2 +0,0 @@
-def interleave(*args):
-    return [val for pair in zip(*args) for val in pair]
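
Note on the surviving __init__.py: MODELS_LOOKUP is a name-to-class registry over MODELS. A minimal sketch of how such a lookup is typically resolved from a config value; the cfg.model field name is an assumption (not shown in this patch), and the ViT encoders are assumed to keep the (cfg, obs_space) constructor signature of the encoders they replace.

    # Hypothetical usage sketch, not part of the patch.
    from sf_examples.nethack.models import MODELS_LOOKUP

    def make_encoder(cfg, obs_space):
        # cfg.model is an assumed config field holding a class name, e.g. "ViTActorEncoder".
        try:
            encoder_cls = MODELS_LOOKUP[cfg.model]
        except KeyError:
            raise ValueError(f"Unknown model {cfg.model}; choose one of {list(MODELS_LOOKUP)}")
        # Instantiate with the same signature the removed encoders used: Encoder(cfg, obs_space).
        return encoder_cls(cfg, obs_space)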
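
Note on the removed selectt helpers: all three deleted model files route embedding lookups through the layer's weight (via nn.functional.embedding or weight.index_select) instead of calling the nn.Embedding module, to work around a slow backward pass (see pytorch/pytorch#24912 referenced in scaled.py). A small self-contained sketch checking that the variants agree, independent of the deleted files:

    # Standalone check of the index-select trick used by the removed selectt helpers.
    import torch
    from torch import nn

    emb = nn.Embedding(256, 16)
    x = torch.randint(0, 256, (4, 21, 80))  # e.g. a batch of tty_chars

    a = emb(x)                                                            # default forward
    b = nn.functional.embedding(x, emb.weight)                            # chaotic_dwarf.py / simba.py variant
    c = emb.weight.index_select(0, x.reshape(-1)).view(x.shape + (-1,))   # scaled.py variant

    assert torch.equal(a, b) and torch.equal(a, c)  # all three give the same (4, 21, 80, 16) tensor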