DPL-048 fix root sample ids #528

Closed
wants to merge 15 commits
7 changes: 7 additions & 0 deletions .gitignore
@@ -132,3 +132,10 @@ dmypy.json
.pyre/

tests/data/reports/*

# Data fixes
*/data-fixes/*/data/
*/data-fixes/*/test-data
*/data-fixes/*/constants.py

*.DS_Store
35 changes: 26 additions & 9 deletions lighthouse/data-fixes/DPL-048/data_fix_and_save.py
@@ -1,15 +1,32 @@
# get the root_sample_ids, fix them, write these to a CSV to be used with the 'write_data' script (which inserts the fixed IDs into the DBs)
# get the data, fix it, write these to a CSV to be used with the 'write_data' script (which inserts the fixed data into the DBs)
import pandas as pd
import argparse

from data_getters import get_data
from data_helpers import remove_everything_after_first_underscore

def save_data():
data = get_data()
print("Editing the root_sample_ids...")
data = data.rename(columns={"root_sample_id": "original_root_sample_id"})
data["root_sample_id"] = data["original_root_sample_id"].apply(remove_everything_after_first_underscore)
print("Adding the root_sample_ids to a CSV file.")
data.to_csv('data-fixes/test-data/root_sample_ids.csv', index=False)
from constants import (
COLUMN_NAME,
ORIGINAL_COLUMN_NAME
)

def save_data(input_filename, output_filename):
if input_filename:
data = pd.read_csv(input_filename)
else:
data = get_data()

print("Editing the data...")
data = data.rename(columns={COLUMN_NAME: ORIGINAL_COLUMN_NAME})
data[COLUMN_NAME] = data[ORIGINAL_COLUMN_NAME].apply(remove_everything_after_first_underscore)
print("Adding the data to a CSV file.")
data.to_csv(output_filename, index=False)

if __name__ == "__main__":
save_data()
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", required=False)
Contributor:

required=False? I think you later rely on there being a value. Might be wrong.

Author:

If you don't give an input_file, argparse returns None when you ask for it on line 30. That None gets passed into save_data(), which checks whether it exists: if it isn't None, the script reads the file; otherwise it goes to the DB to get the data. This argument was really only for testing, so I could check that the script fixes the data correctly and saves it to CSV when given some dummy data. In reality it should go to the DB, because that's the data we're trying to fix.

Contributor:

Ahhh I see. That makes sense! Good stuff.

parser.add_argument("--output_file", required=True)
args = parser.parse_args()
input_filename = vars(args)["input_file"]
output_filename = vars(args)["output_file"]
save_data(input_filename=input_filename, output_filename=output_filename)
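A minimal sketch of the argparse behaviour discussed in the review thread above: an omitted optional argument comes back as None, which is what lets save_data() fall back to the DB. The invocation and file name below are hypothetical.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--input_file", required=False)   # optional: None when omitted
parser.add_argument("--output_file", required=True)

# Hypothetical run with no --input_file, as in production:
args = parser.parse_args(["--output_file", "fixed_ids.csv"])
print(args.input_file)    # None -> save_data() calls get_data() and hits the DB
print(args.output_file)   # 'fixed_ids.csv'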
18 changes: 11 additions & 7 deletions lighthouse/data-fixes/DPL-048/data_getters.py
@@ -1,25 +1,29 @@
# get the root_sample_ids from MLWH - any root_sample_ids containing an underscore
# get the data from MLWH - use the SQL_MLWH_GET_MALFORMED_DATA constant in constants.py to give the SQL call for the data
import sqlalchemy
import pandas as pd

from constants import MYSQL_DB_CONN_STRING, MLWH_DB, SQL_MLWH_GET_MALFORMED_ROOT_IDS
from constants import (
MYSQL_DB_CONN_STRING,
MLWH_DB,
SQL_MLWH_GET_MALFORMED_DATA
)

def get_data() -> pd.DataFrame:
print("Attempting to connect to DB.")
print("Attempting to connect to MLWH.")
try:
sql_engine = sqlalchemy.create_engine(
f"mysql+pymysql://{MYSQL_DB_CONN_STRING}/{MLWH_DB}", pool_recycle=3600
)
db_connection = sql_engine.connect()
print("Connected to the DB... getting data.")
data = pd.read_sql(SQL_MLWH_GET_MALFORMED_ROOT_IDS, db_connection)
print("Connected to MLWH... getting data.")
data = pd.read_sql(SQL_MLWH_GET_MALFORMED_DATA, db_connection)
print("Got the data.")
except Exception as e:
print("Error while connecting to MySQL")
print("Error while connecting to MLWH.")
print(e)
return None
finally:
if db_connection is not None:
print("Closing DB connection.")
print("Closing MLWH connection.")
db_connection.close()
return data
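constants.py is gitignored (see the .gitignore hunk above), so it never appears in this diff. The sketch below is a hypothetical reconstruction inferred only from the names these scripts import and from values visible in the old code (e.g. the "samples" collection and the "Root Sample ID" field); every value is a placeholder assumption.

# constants.py -- hypothetical sketch; all values are assumptions
COLUMN_NAME = "root_sample_id"
ORIGINAL_COLUMN_NAME = "original_root_sample_id"
MONGO_COLUMN_NAME = "Root Sample ID"

MYSQL_HOST = "localhost"
MYSQL_USER = "user"
MYSQL_PWD = "password"
MYSQL_DB_CONN_STRING = f"{MYSQL_USER}:{MYSQL_PWD}@{MYSQL_HOST}"
MLWH_DB = "mlwh"
MLWH_TABLE = "lighthouse_sample"

MONGO_DB_HOST = "localhost"
MONGO_DB_USER = "user"
MONGO_DB_PASSWORD = "password"
MONGO_DB_AUTH_SOURCE = "admin"
MONGO_DB = "lighthouse"
MONGO_TABLE = "samples"

# per the comment in data_getters.py: any root_sample_id containing an underscore
SQL_MLWH_GET_MALFORMED_DATA = f"SELECT * FROM {MLWH_TABLE} WHERE root_sample_id LIKE '%\\_%'"

fixed_data_file = "data-fixes/DPL-048/data/fixed_data.csv"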
59 changes: 39 additions & 20 deletions lighthouse/data-fixes/DPL-048/data_writers.py
@@ -4,7 +4,23 @@
import pandas as pd
import argparse

from constants import (LOCALHOST, MONGO_DB, MONGO_DB_CLIENT, MONGO_TABLE, MYSQL_USER, MYSQL_PWD, MLWH_DB, fixed_samples_file)
from constants import (
COLUMN_NAME,
ORIGINAL_COLUMN_NAME,
MONGO_COLUMN_NAME,
MYSQL_HOST,
MONGO_DB,
MONGO_DB_HOST,
MONGO_DB_AUTH_SOURCE,
MONGO_DB_USER,
MONGO_DB_PASSWORD,
MONGO_TABLE,
MYSQL_USER,
MYSQL_PWD,
MLWH_DB,
MLWH_TABLE,
fixed_data_file
)

def write_data_to_db(data: pd.DataFrame, database: str):
if database.lower() == "mongo":
@@ -18,48 +18,51 @@ def write_data_to_db(data: pd.DataFrame, database: str):
def write_to_mongo(data):
print("Attempting to connect to Mongo DB...")
try:
client = pymongo.MongoClient(MONGO_DB_CLIENT)
client = pymongo.MongoClient(MONGO_DB_HOST, username=MONGO_DB_USER, password=MONGO_DB_PASSWORD, authSource=MONGO_DB_AUTH_SOURCE)
db = client[MONGO_DB]
table = db[MONGO_TABLE]
print("Loading in the data...")
for index, row in data.iterrows():
root_sample_id = row["root_sample_id"]
original_root_sample_id = row["original_root_sample_id"]
new_value = row[COLUMN_NAME]
original_value = row[ORIGINAL_COLUMN_NAME]

update_query = { "Root Sample ID": original_root_sample_id }
new_value = { "$set": { "Root Sample ID": root_sample_id } }
update_query = { MONGO_COLUMN_NAME: original_value }
new_value_query = { "$set": { MONGO_COLUMN_NAME: new_value } }

table.update_many(update_query, new_value)
table.update_many(update_query, new_value_query)
print("Data loaded in successfully.")
except Exception as e:
print("Error while connecting to MongoDB")
print("Error while connecting to Mongo DB.")
print(e)
return None

def write_to_mysql(data):
print("Attempting to connect to MLWH...")
try:
db_connection = mysql.connector.connect(host = LOCALHOST,
db_connection = mysql.connector.connect(host = MYSQL_HOST,
database = MLWH_DB,
user = MYSQL_USER,
password = MYSQL_PWD)
password = MYSQL_PWD,
port = '3436')
print("Loading in the data...")
cursor = db_connection.cursor()
for index, row in data.iterrows():
root_sample_id = row["root_sample_id"]
original_root_sample_id = row["original_root_sample_id"]
new_value = row[COLUMN_NAME]
original_value = row[ORIGINAL_COLUMN_NAME]
update_query = (
f"UPDATE lighthouse_sample"
f" SET root_sample_id = '{root_sample_id}'"
f" WHERE root_sample_id = '{original_root_sample_id}'"
f"UPDATE {MLWH_TABLE}"
f" SET {COLUMN_NAME} = '{new_value}'"
f" WHERE {COLUMN_NAME} = '{original_value}'"
)
cursor.execute(update_query)
rows_updated = cursor.rowcount
db_connection.commit()
try:
cursor.execute(update_query)
db_connection.commit()
except Exception:
pass
cursor.close()
print("Data loaded in successfully.")
except Exception as e:
print("Error while connecting to MySQL")
print("Error while connecting to MLWH.")
print(e)
return None
finally:
@@ -73,5 +92,5 @@ def write_to_mysql(data):
parser.add_argument("--db")
args = parser.parse_args()
db = vars(args)["db"]
data = pd.read_csv(fixed_samples_file)
data = pd.read_csv(fixed_data_file)
write_data_to_db(data, db)
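The UPDATE above interpolates values directly into the SQL string, and the inner try/except silently swallows any statement that fails. Below is a hedged sketch of a parameterized variant (an alternative, not what this PR ships): placeholders can bind the values, though identifiers like the table and column names still have to be formatted in, since placeholders only bind values.

update_query = (
    f"UPDATE {MLWH_TABLE}"
    f" SET {COLUMN_NAME} = %s"
    f" WHERE {COLUMN_NAME} = %s"
)
# mysql.connector binds %s to the supplied values, so an ID containing a
# quote character cannot break (or alter) the statement.
cursor.execute(update_query, (new_value, original_value))
db_connection.commit()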
64 changes: 59 additions & 5 deletions lighthouse/data-fixes/DPL-048/test_helpers.py
@@ -2,8 +2,22 @@
import pandas as pd
import sqlalchemy
import pymongo
import mysql.connector

from constants import MYSQL_DB_CONN_STRING, MLWH_DB, MONGO_DB, MONGO_DB_CLIENT, SQL_GET_ALL_DATA, malformed_csv, control_csv, skippable_csv
from constants import (
MYSQL_DB_CONN_STRING,
MYSQL_PWD,
MYSQL_USER,
MYSQL_HOST,
MLWH_DB,
MONGO_DB,
MONGO_DB_CLIENT,
MONGO_TABLE,
SQL_GET_ALL_DATA,
malformed_csv,
control_csv,
skippable_csv
)

def print_data():
try:
@@ -18,12 +32,12 @@ def print_data():
print('got data')
print(data)
except Exception as e:
print("Error while connecting to MySQL")
print("Error while connecting to MLWH.")
print(e)
return None
finally:
if db_connection is not None:
print("Closing mlwh connection")
print("Closing MLWH connection.")
db_connection.close()

def populate_local_db(database):
@@ -46,12 +60,12 @@ def populate_mongo(data):
try:
client = pymongo.MongoClient(MONGO_DB_CLIENT)
db = client[MONGO_DB]
table = db["samples"]
table = db[MONGO_TABLE]
data_dict = data.to_dict()
print(data_dict)
table.insert_many([data_dict])
except Exception as e:
print("Error while connecting to MongoDB")
print("Error while connecting to Mongo DB.")
print(e)
return None

@@ -73,3 +87,43 @@ def populate_mysql(data):
print("Closing mlwh connection")
db_connection.close()
return None

def find_duplicate_root_sample_ids(root_sample_ids):
print("Attempting to connect to DB.")
try:
db_connection = mysql.connector.connect(host = MYSQL_HOST,
database = MLWH_DB,
user = MYSQL_USER,
password = MYSQL_PWD,
port = '3436')
full_data = pd.DataFrame()
print("Loading the data...")
cursor = db_connection.cursor()
for index, row in root_sample_ids.iterrows():
if index % 1000 == 0:
print("Reached index "+ str(index))
root_sample_id = row["root_sample_id"]
plate_barcode = row["plate_barcode"]
coordinate = row["coordinate"]
select_query = (
f"SELECT root_sample_id, plate_barcode, coordinate"
f" FROM lighthouse_sample"
f" WHERE root_sample_id = '{root_sample_id}'"
f" AND plate_barcode = '{plate_barcode}'"
f" AND coordinate = '{coordinate}'"
)
cursor.execute(select_query)
db_data = cursor.fetchall()
data_row = pd.DataFrame(db_data)
full_data = pd.concat([full_data, data_row])
cursor.close()
print("Data loaded in successfully.")
except Exception as e:
print("Error while connecting to MySQL")
print(e)
return None
finally:
if db_connection is not None:
print("Closing DB connection.")
db_connection.close()
return full_data
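find_duplicate_root_sample_ids issues one SELECT per input row, hence the progress print every 1000 rows. Below is a sketch of a set-based alternative using a single row-constructor IN query; this is an assumed optimisation, not code from the PR, and a very large input would need chunking to stay under MySQL's packet limit.

keys = [
    (row["root_sample_id"], row["plate_barcode"], row["coordinate"])
    for _, row in root_sample_ids.iterrows()
]
placeholders = ", ".join(["(%s, %s, %s)"] * len(keys))
select_query = (
    "SELECT root_sample_id, plate_barcode, coordinate"
    " FROM lighthouse_sample"
    f" WHERE (root_sample_id, plate_barcode, coordinate) IN ({placeholders})"
)
# One round trip instead of one per row; values are bound, not interpolated.
cursor.execute(select_query, [value for key in keys for value in key])
full_data = pd.DataFrame(
    cursor.fetchall(),
    columns=["root_sample_id", "plate_barcode", "coordinate"],
)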