Skip to content

Commit

Permalink
add batched feature to normalizer
Browse files Browse the repository at this point in the history
  • Loading branch information
Zevrap-81 committed Dec 17, 2022
1 parent d6ce378 commit 51698d6
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions trainer/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,27 @@ def __init__(self, name:str):
self.range_= None
self.samples_seen= 0

def __call__(self, x):
return self.transform(x)
def __call__(self, x, inverse= False, batched=True):
if batched:
shape= x.shape
x= x.reshape(-1, shape[-1])

if not inverse:
x= self.transform(x)
else:
x= self.inverse(x)

if batched:
x= x.reshape(shape)

return x

def fit(self, x, batched=True):
if batched:
batch_size=x.shape[0]
shape= x.shape
x= x.reshape(-1, shape[-1])

def fit(self, x):
if isinstance(x, np.ndarray):
return self.fit_numpy(x)
else:
Expand Down Expand Up @@ -45,6 +62,8 @@ def fit_torch(self, x:torch.tensor):
self.samples_seen+=1

def transform(self, x):
self.__handling_zeros()

min, max= torch.from_numpy(self.min_), torch.from_numpy(self.max_)
range= torch.from_numpy(self.range_)

Expand All @@ -56,6 +75,11 @@ def inverse(self, y):
y= y * torch.from_numpy(self.range_).to(device) + torch.from_numpy(self.min_).to(device)
return y

def __handling_zeros(self):
#handling zeroes in range
constant_mask = self.range < 10 * np.finfo(self.range.dtype).eps
self.range[constant_mask] = 1.0

def save_state(self, dir:str= ""):
dir= osp.join(dir,self.name+"_norm.pkl")
with open(dir, "wb") as pickle_file:
Expand Down

0 comments on commit 51698d6

Please sign in to comment.