
Commit

Update tasks.py
Daniel Khashabi authored Sep 22, 2020
1 parent ea8e43c commit 7bf0653
Showing 1 changed file with 22 additions and 2 deletions.
tasks.py: 24 changes (22 additions, 2 deletions)
@@ -50,7 +50,7 @@
 
 DATA_DIR = f"gs://unifiedqa/data/"
 
-def generic_dataset_preprocessor(ds):
+def dataset_preprocessor(ds):
   def normalize_text(text):
     """Lowercase and remove quotes from a TensorFlow string."""
     text = tf.strings.lower(text)
@@ -66,14 +66,34 @@ def to_inputs_and_targets(ex):
   return ds.map(to_inputs_and_targets,
                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
-def get_downloaded_data_path(data_dir1, split):
+def get_path(data_dir1, split):
   tsv_path = {
       "train": os.path.join(data_dir1, "train.tsv"),
       "dev": os.path.join(data_dir1, "dev.tsv"),
       "test": os.path.join(data_dir1, "test.tsv")
   }
   return tsv_path[split]
+
+
+def dataset_fn(split, shuffle_files=False, dataset=""):
+  # We only have one file for each split.
+  del shuffle_files
+
+  # Load lines from the text file as examples.
+  ds = tf.data.TextLineDataset(get_path(DATA_DIR + dataset, split))
+  # Split each "<question>\t<answer>" example into (question, answer) tuple.
+  print(" >>>> about to read csv . . . ")
+  ds = ds.map(
+      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
+                        field_delim="\t", use_quote_delim=False),
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  # print(" >>>> after reading csv . . . ")
+  # Map each tuple to a {"question": ... "answer": ...} dict.
+  ds = ds.map(lambda *ex: dict(zip(["inputs", "targets"], ex)))
+  # print(" >>>> after mapping . . . ")
+  return ds
+
+
 for dataset in DATASETS:
   print(f" >>>> reading dataset: {dataset}")
   t5.data.set_tfds_data_dir_override(DATA_DIR + dataset)
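A note on the added dataset_fn: the final map emits dicts keyed "inputs" and "targets" (the inline comment above it mentions "question"/"answer", but those are not the keys the lambda produces). As a minimal sketch, assuming eager TF 2.x, read access to the gs://unifiedqa/data/ bucket, and a placeholder dataset name that is not part of this commit, the function can be exercised on its own:

# Illustrative only; "arc_easy" is a placeholder dataset directory name.
ds = dataset_fn("dev", dataset="arc_easy")
for ex in ds.take(2):
  # Each element is a dict like {"inputs": <question>, "targets": <answer>}.
  print(ex["inputs"].numpy().decode(), "->", ex["targets"].numpy().decode())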

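The rest of the loop is collapsed in this view, so the actual task registration is not shown. Purely as an illustration of how a function-based dataset like this is typically wired into T5 (the task name, splits, and metric below are assumptions, and the TaskRegistry.add keyword arguments follow the t5 Colab-style API of that era, which may differ across t5 versions):

# Hypothetical sketch, not the file's actual code.
t5.data.TaskRegistry.add(
    f"{dataset}_task",                                  # placeholder task name
    dataset_fn=functools.partial(dataset_fn, dataset=dataset),
    splits=["train", "dev", "test"],                    # matches get_path's keys
    text_preprocessor=[dataset_preprocessor],
    metric_fns=[t5.evaluation.metrics.accuracy],        # assumed metric choice
)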