Skip to content

Commit

Permalink
Added decoding function to convert back to normal values
Browse files Browse the repository at this point in the history
  • Loading branch information
dmccrevan committed Dec 13, 2018
1 parent c479c38 commit 3420998
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
19 changes: 19 additions & 0 deletions src/modules/generationtools/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,22 @@ def categorical_convert(col):
# sample from the distributions and return that value
return col.apply(lambda x: distributions[x].rvs()), limits


def undo_cat(col, lim):
    """Decode a numerically-encoded column back to its categorical values.

    Each value in ``col`` is mapped to the category whose upper limit is the
    smallest one strictly greater than the value.

    Arguments:
        col {Column} -- The dataframe's column holding encoded numeric values
        lim {Dict} -- Maps each category's upper limit to the category value

    Returns:
        Column -- The new column with the decoded categorical values
    """
    # Sort once, up front: the first-match scan below is only correct when
    # limits are visited in ascending order, so do not rely on the
    # insertion order of the dict handed in by the caller.
    ordered = sorted(lim.items(), key=lambda item: item[0])

    def cat_decode(x):
        # NaN compares False against everything; keep missing values missing.
        if x != x:
            return None
        for upper, category in ordered:
            if x < upper:
                return category
        # Values at or beyond the largest limit are clamped to the last
        # category, mirroring how values below the smallest limit already
        # fall into the first one.
        return ordered[-1][1] if ordered else None

    return col.apply(cat_decode)
15 changes: 7 additions & 8 deletions src/modules/generationtools/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from scipy.stats import expon, truncnorm, beta, uniform, norm
from .cleandata import MissingValues, DatetimeToEPOCH
from .categorical import identify, categorical_convert
from .categorical import identify, categorical_convert, undo_cat

def sample(f, sigma):
'''
Expand Down Expand Up @@ -57,15 +57,13 @@ def synthesize_table(file_in, file_out, lines = 0):
df = MissingValues(df)
df = DatetimeToEPOCH(df)

cat_cols = {}
limits = {}
counter = 0
for col in df:
if(identify(df[col])):
new_col, limit = categorical_convert(df[col])
limits[counter] = limit
counter += 1
limits[col] = limit
df[col] = new_col

# calculate distributions and covariances using tools in model_generation.py
dists, pvalues, params = mg.findBestDistribution(df)
f = (dists, params)
Expand All @@ -78,8 +76,9 @@ def synthesize_table(file_in, file_out, lines = 0):
new_df = pd.DataFrame(columns = list(df))
for k in range(lines):
new_df.loc[k] = sample(f, sigma)
new_df.to_csv(file_out, index = False)

for col in new_df.columns:
if col in limits:
new_df[col] = undo_cat(new_df[col], limits[col])

if __name__ == "__main__":
synthesize_table("test.csv", "s_test.csv")
new_df.to_csv(file_out, index = False)

0 comments on commit 3420998

Please sign in to comment.