Skip to content

Commit

Permalink
fixed bugs in dataloaders and augmentations, adjusted models for new …
Browse files Browse the repository at this point in the history
…default in_channels
  • Loading branch information
Ciara-AI committed Apr 30, 2022
1 parent 74c5ab5 commit d68d38e
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion models/FCN.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def forward(self, x):
return x

class FullyConvNetwork(nn.Module):
def __init__(self, input_dims = 10, num_classes = 2):
def __init__(self, input_dims = 9, num_classes = 2):

"""
DESC
Expand Down
2 changes: 1 addition & 1 deletion models/resnet_tsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class ResNetTSC(nn.Module):
ResNetTSC Model Instance
"""

def __init__(self, num_classes=2, in_channels=10):
def __init__(self, num_classes=2, in_channels=9):
super(ResNetTSC, self).__init__()

# Configurations of Stages
Expand Down
7 changes: 4 additions & 3 deletions utils/DataLoaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def __init__(self, root):
self.transforms = torchvision.transforms.Compose([
RotateColsRandomly(X_POS_COL, Y_POS_COL),
RotateColsRandomly(X_VEL_COL, Y_VEL_COL),
RotateColsRandomly(factor = NOICE_FACTOR),
GaussianNoise(factor = NOICE_FACTOR),

])

def __len__(self):
Expand All @@ -77,13 +78,13 @@ def __getitem__(self, idx):
X = self.X[idx]
Y = self.Y[idx]

X = torch.tensor(X)
X = torch.from_numpy(X)
Y = torch.tensor(Y).long()

X = self.transforms(X)

# cast to tensor and return
return X, torch.tensor(Y).long()
return X, Y

def collate_fn(batch):

Expand Down
2 changes: 1 addition & 1 deletion utils/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, col_1, col_2):
def __call__(self, x):

# create an identity matrix of shape dim
dims = self.shape[2]
dims = x.shape[1]
A = torch.eye(dims)

# calculate random radians to rotate
Expand Down

0 comments on commit d68d38e

Please sign in to comment.