diff --git a/src/modules/generationtools/categorical.py b/src/modules/generationtools/categorical.py index 303ce6b..38d8201 100644 --- a/src/modules/generationtools/categorical.py +++ b/src/modules/generationtools/categorical.py @@ -76,3 +76,22 @@ def categorical_convert(col): # sample from the distributions and return that value return col.apply(lambda x: distributions[x].rvs()), limits + +def undo_cat(col, lim): + """Convert the categorical column to normalized valus + + Arguments: + col {Column} -- The dataframe's column + lim {Dict} -- The dictionary containing the limits per catergorical column + + Returns: + Column -- The new column + """ + + + def cat_decode(x, limits): + for k, v in limits.items(): + if x < k: + return v + + return col.apply(lambda x: cat_decode(x, lim)) \ No newline at end of file diff --git a/src/modules/generationtools/synthesize.py b/src/modules/generationtools/synthesize.py index fa1731b..3f86a5c 100644 --- a/src/modules/generationtools/synthesize.py +++ b/src/modules/generationtools/synthesize.py @@ -9,7 +9,7 @@ import numpy as np from scipy.stats import expon, truncnorm, beta, uniform, norm from .cleandata import MissingValues, DatetimeToEPOCH -from .categorical import identify, categorical_convert +from .categorical import identify, categorical_convert, undo_cat def sample(f, sigma): ''' @@ -57,15 +57,13 @@ def synthesize_table(file_in, file_out, lines = 0): df = MissingValues(df) df = DatetimeToEPOCH(df) + cat_cols = {} limits = {} - counter = 0 for col in df: if(identify(df[col])): new_col, limit = categorical_convert(df[col]) - limits[counter] = limit - counter += 1 + limits[col] = limit df[col] = new_col - # calculate distributions and covariances using tools in model_generation.py dists, pvalues, params = mg.findBestDistribution(df) f = (dists, params) @@ -78,8 +76,9 @@ def synthesize_table(file_in, file_out, lines = 0): new_df = pd.DataFrame(columns = list(df)) for k in range(lines): new_df.loc[k] = sample(f, sigma) - new_df.to_csv(file_out, index = False) + for col in new_df.columns: + if col in limits: + new_df[col] = undo_cat(new_df[col], limits[col]) -if __name__ == "__main__": - synthesize_table("test.csv", "s_test.csv") \ No newline at end of file + new_df.to_csv(file_out, index = False)