Skip to content

Commit

Permalink
wikipediaを間引く
Browse files Browse the repository at this point in the history
  • Loading branch information
yuiseki committed May 2, 2024
1 parent d017728 commit 85e95e2
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def prepare_train_data(dataset_id):
if "dataset_load_config" in train_config:
dataset_load_config = train_config["dataset_load_config"]
data = load_dataset(dataset_id, dataset_load_config, split="train", num_proc=32)
if dataset_load_config == "20231101.ja" or dataset_load_config == "20231101.vi" or dataset_load_config == "20231101.es":
if dataset_load_config == "20231101.ja" or dataset_load_config == "20231101.vi" or dataset_load_config == "20231101.es" or dataset_load_config == "20231101.it":
data = data.filter(lambda item, idx: idx % 3 == 0, with_indices=True)
if dataset_load_config == "20231101.de":
if dataset_load_config == "20231101.de" or dataset_load_config == "20231101.fr":
data = data.filter(lambda item, idx: idx % 5 == 0, with_indices=True)
else:
data = load_dataset(dataset_id, split="train", num_proc=32)
Expand Down Expand Up @@ -162,7 +162,8 @@ def prepare_train_data(dataset_id):
lambda x: simple_template_for_train(x[input_field_name], x[output_field_name]),
axis=1,
)

# keep only text field
data = data_df[["text"]]
data = Dataset.from_pandas(data_df)
data = data.train_test_split(seed=42, test_size=0.2)
print(len(data["train"]))
Expand Down Expand Up @@ -281,7 +282,7 @@ def load_model_and_tokenizer(model_id):
args=training_arguments,
tokenizer=tokenizer,
packing=False,
max_seq_length=512,
max_seq_length=1024,
)

#
Expand Down

0 comments on commit 85e95e2

Please sign in to comment.