forked from L2-Regulasyon/Teknofest2023
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_data.py
49 lines (39 loc) · 1.29 KB
/
generate_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import argparse
from pathlib import Path

import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import shuffle

from utils.constants import TARGET_DICT, TARGET_DICT_FASHION
from utils.preprocess_utils import preprocess_text
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` is unsuitable for argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--fashion False`` would
    silently enable fashion mode.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` for true/t/yes/y/1, ``False`` for false/f/no/n/0 or an empty
        string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays False, matching what bool("") used to give.
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(
    description="Generating preprocessed text classification data"
)
parser.add_argument(
    "--data_path",
    required=True,
    type=str,
    help="Path to pandas-readable not-validation splitted dataset file.",
)
parser.add_argument("--data_name", required=True, type=str, help="name of the data")
parser.add_argument(
    "--fashion",
    type=_str2bool,
    # nargs="?" + const lets a bare `--fashion` mean True, while
    # `--fashion True` / `--fashion False` keep working with correct values.
    nargs="?",
    const=True,
    default=False,
    help="Is it category prediction",
)
opt = parser.parse_args()
# Load the raw dataset from the path given on the command line.
df = pd.read_csv(opt.data_path, sep=",")

# Choose the text column, label column and label mapping once, so the text
# selection and the label encoding below always agree on the task.
if opt.fashion:
    text_column, label_column, label_map = (
        "title",
        "related_product",
        TARGET_DICT_FASHION,
    )
else:
    text_column, label_column, label_map = (
        "description_text",
        "category",
        TARGET_DICT,
    )

df["text"] = df[text_column]
# NOTE(review): the return value is discarded, so this call only has an effect
# if preprocess_text modifies `df` in place -- confirm against
# utils.preprocess_utils.
preprocess_text(df)

# Length filtering: keep only rows whose text is at least 5 characters long.
df["text_len"] = df.text.str.len()
df = df[df.text_len >= 5].reset_index(drop=True)

# Label encoding. Labels that are absent from the mapping become NaN.
df["target"] = df[label_column].map(label_map)

# Shuffle with a fixed seed so the exported row order is reproducible.
df_shuffled = shuffle(df, random_state=42)

# Export the shuffled DataFrame. The target directory is created first because
# to_csv does not create missing parent directories.
output_path = Path("../data/processed") / f"{opt.data_name}.csv"
output_path.parent.mkdir(parents=True, exist_ok=True)
df_shuffled.to_csv(output_path, index=False)