Merge pull request #57 from mitre/household_perf_again
Additional tweaks to improve performance of household inference
dehall authored Apr 18, 2023 · 2 parents 4619b41 + 03f6d6e · commit f6d11fc
Showing 3 changed files with 243 additions and 115 deletions.
187 changes: 104 additions & 83 deletions households.py
@@ -8,14 +8,13 @@
import sys
from datetime import datetime
from pathlib import Path
from random import shuffle
from zipfile import ZipFile

import pandas as pd

from definitions import TIMESTAMP_FMT
from derive_subkey import derive_subkey
from households.matching import addr_parse, get_household_matches
from households.matching import get_household_matches

HEADERS = ["HOUSEHOLD_POSITION", "PII_POSITIONS"]
HOUSEHOLD_PII_HEADERS = [
@@ -76,6 +75,10 @@ def parse_arguments():
" Smaller numbers may result in out of memory errors. Larger numbers"
" may increase runtime. Default is 4",
)
parser.add_argument(
"--pairsfile",
help="Location of matching pairs file",
)
parser.add_argument(
"--debug",
action="store_true",
@@ -108,53 +111,29 @@ def parse_source_file(source_file, debug=False):
if debug:
print(f"[{datetime.now()}] Start loading PII file")

# force all columns to be strings, even if they look numeric
df = pd.read_csv(source_file, dtype=str)

# break out the address into number, street, suffix, etc,
# so we can prefilter matches based on those
addr_cols = df.apply(
explode_address,
axis="columns",
result_type="expand",
# dtype=str means force all columns to be strings even if they look numeric
# keep_default_na keeps empty cells as empty string, not a NaN
# usecols means only read the given column names,
# aka don't read the columns that are never used here: given_name, DOB, sex
df = pd.read_csv(
source_file,
dtype=str,
keep_default_na=False,
usecols=[
"record_id",
"family_name",
"phone_number",
"household_street_address",
"household_zip",
],
)
df = pd.concat([df, addr_cols], axis="columns")

if debug:
print(f"[{datetime.now()}] Done pre-processing PII file")
print(f"[{datetime.now()}] Done loading PII file")

return df


def explode_address(row):
# this addr_parse function is relatively slow so only run it once per row.
# by caching the exploded dict this way we ensure
# that we have it in the right form in all the right places its needed
parsed = addr_parse(row.household_street_address)
parsed["exploded_address"] = parsed.copy()
parsed["exploded_address"][
"household_street_address"
] = row.household_street_address
return parsed


def write_households_pii(output_rows, household_time):
shuffle(output_rows)
timestamp = household_time.strftime(TIMESTAMP_FMT)
hh_pii_path = Path("temp-data") / f"households_pii-{timestamp}.csv"
with open(
hh_pii_path,
"w",
newline="",
encoding="utf-8",
) as house_csv:
print(f"Writing households PII to {hh_pii_path}")
writer = csv.writer(house_csv)
writer.writerow(HOUSEHOLD_PII_HEADERS)
for output_row in output_rows:
writer.writerow(output_row)


# Simple breadth-first-search to turn a graph-like structure of pairs
# into a list representing the ids in the household
def bfs_traverse_matches(pos_to_pairs, position):
@@ -187,58 +166,99 @@ def get_default_pii_csv(dirname="temp-data"):
return source_file


def write_mapping_file(pos_pid_rows, hid_pat_id_rows, args):
def write_pii_and_mapping_file(pos_pid_rows, hid_pat_id_rows, household_time, args):
if args.sourcefile:
source_file = Path(args.sourcefile)
else:
source_file = get_default_pii_csv()
print(f"PII Source: {str(source_file)}")
pii_lines = parse_source_file(source_file, args.debug)
output_rows = []

# pos_to_pairs is a dict of:
# (patient position) --> [matching pairs that include that patient]
# so it can be traversed sort of like a graph from any given patient
# note the key is patient position within the pii_lines dataframe
pos_to_pairs = get_household_matches(
pii_lines, args.split_factor, args.debug, args.pairsfile
)

mapping_file = Path(args.mappingfile)

n_households = 0
with open(mapping_file, "w", newline="", encoding="utf-8") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(HEADERS)
already_added = set()

# pos_to_pairs is a dict of:
# (patient position) --> [matching pairs that include that patient]
# so it can be traversed sort of like a graph from any given patient
# note the key is patient position within the pii_lines dataframe
pos_to_pairs = get_household_matches(pii_lines, args.split_factor, args.debug)
mapping_writer = csv.writer(csvfile)
mapping_writer.writerow(HEADERS)

if args.debug:
print(f"[{datetime.now()}] Assembling output file")

hclk_position = 0
# Match households
for position, line in pii_lines.iterrows():
if position in already_added:
continue
already_added.add(position)

if position in pos_to_pairs:
pat_clks = bfs_traverse_matches(pos_to_pairs, position)
pat_ids = list(map(lambda p: pii_lines.at[p, "record_id"], pat_clks))
already_added.update(pat_clks)
else:
pat_clks = [position]
pat_ids = [line[0]]

string_pat_clks = [str(int) for int in pat_clks]
pat_string = ",".join(string_pat_clks)
writer.writerow([hclk_position, pat_string])
n_households += 1
pos_pid_rows.append([hclk_position, line[0]])
for patid in pat_ids:
hid_pat_id_rows.append([hclk_position, patid])
# note pat_ids_str will be quoted by the csv writer if needed
pat_ids_str = ",".join(pat_ids)
output_row = [line[2], line[5], line[6], line[7], pat_ids_str]
hclk_position += 1
output_rows.append(output_row)
return output_rows, n_households
timestamp = household_time.strftime(TIMESTAMP_FMT)
hh_pii_path = Path("temp-data") / f"households_pii-{timestamp}.csv"
with open(
hh_pii_path,
"w",
newline="",
encoding="utf-8",
) as hh_pii_csv:
print(f"Writing households PII to {hh_pii_path}")
pii_writer = csv.writer(hh_pii_csv)
pii_writer.writerow(HOUSEHOLD_PII_HEADERS)

pii_lines["written_to_file"] = False
hclk_position = 0
lines_processed = 0
five_percent = int(len(pii_lines) / 20)
# Match households
for position, line in pii_lines.sample(frac=1).iterrows():
# sample(frac=1) shuffles the entire dataframe
# note that "position" is the index and still relative to the original

lines_processed += 1

if args.debug and (lines_processed % five_percent) == 0:
print(
f"[{datetime.now()}] Processing pii lines"
f" - {lines_processed}/{len(pii_lines)}"
)

if line["written_to_file"]:
continue
line["written_to_file"] = True

if position in pos_to_pairs:
pat_positions = bfs_traverse_matches(pos_to_pairs, position)
# map those row numbers to PATIDs
pat_ids = list(
map(lambda p: pii_lines.at[p, "record_id"], pat_positions)
)
# mark all these rows as written to file
pii_lines.loc[pat_positions, ["written_to_file"]] = True
else:
pat_positions = [position]
pat_ids = [line[0]]

string_pat_positions = [str(p) for p in pat_positions]
pat_string = ",".join(string_pat_positions)
mapping_writer.writerow([hclk_position, pat_string])
n_households += 1

if args.testrun:
pos_pid_rows.append([hclk_position, line[0]])
for patid in pat_ids:
hid_pat_id_rows.append([hclk_position, patid])

# note pat_ids_str will be quoted by the csv writer if needed
pat_ids_str = ",".join(pat_ids)
output_row = [
line["family_name"],
line["phone_number"],
line["household_street_address"],
line["household_zip"],
pat_ids_str,
]
hclk_position += 1
pii_writer.writerow(output_row)
return n_households


def write_scoring_file(hid_pat_id_rows):
@@ -297,8 +317,9 @@ def infer_households(args, household_time):
hid_pat_id_rows = []
os.makedirs(Path("output") / "households", exist_ok=True)
os.makedirs("temp-data", exist_ok=True)
output_rows, n_households = write_mapping_file(pos_pid_rows, hid_pat_id_rows, args)
write_households_pii(output_rows, household_time)
n_households = write_pii_and_mapping_file(
pos_pid_rows, hid_pat_id_rows, household_time, args
)
if args.testrun:
write_scoring_file(hid_pat_id_rows)
write_hid_hh_pos_map(pos_pid_rows)