diff --git a/nesta/core/batchables/general/nih/curate/run.py b/nesta/core/batchables/general/nih/curate/run.py
new file mode 100644
index 00000000..e437ad94
--- /dev/null
+++ b/nesta/core/batchables/general/nih/curate/run.py
@@ -0,0 +1,443 @@
+"""
+run.py (general.nih.curate)
+===========================
+
+Curate NiH data, ready for ingestion to the general ES endpoint.
+"""
+
+from nesta.core.luigihacks.elasticsearchplus import _null_empty_str
+from nesta.core.luigihacks.elasticsearchplus import _clean_up_lists
+from nesta.core.luigihacks.elasticsearchplus import _remove_padding
+from nesta.core.luigihacks.elasticsearchplus import _country_detection
+from nesta.packages.geo_utils.lookup import get_us_states_lookup
+from nesta.packages.geo_utils.lookup import get_continent_lookup
+from nesta.packages.geo_utils.lookup import get_country_continent_lookup
+from nesta.packages.geo_utils.lookup import get_eu_countries
+from nesta.packages.geo_utils.country_iso_code import country_iso_code
+from nesta.packages.geo_utils.geocode import _geocode
+from nesta.core.orms.orm_utils import db_session, get_mysql_engine
+from nesta.core.orms.orm_utils import insert_data
+from nesta.core.orms.orm_utils import object_to_dict
+
+# ORMs
+from nesta.core.orms.nih_orm import Projects, Abstracts
+from nesta.core.orms.nih_orm import TextDuplicate
+from nesta.core.orms.general_orm import NihProject, Base
+
+from ast import literal_eval
+import boto3
+from collections import defaultdict
+from datetime import datetime
+import dateutil.parser
+from itertools import groupby, chain
+import json
+import logging
+from operator import attrgetter
+import os
+from sqlalchemy.orm import load_only
+
+
+# Alias these fields, as they're verbose and used a lot
+PK_ID = Projects.application_id
+CORE_ID = Projects.base_core_project_num
+DATETIME_COLS = [c for c in Projects.__table__.columns
+                 if c.type.python_type is datetime]
+
+# Group different types of field together, as they will be treated
+# in a common way on aggregation
+DATETIME_FIELDS = [c.name for c in DATETIME_COLS]
+FLAT_FIELDS = ["application_id", "base_core_project_num", "fy",
+               "org_city", "org_country", "org_name", "org_state",
+               "org_zipcode", "project_title", "ic_name", "phr",
+               "abstract_text"]
+LIST_FIELDS = ["clinicaltrial_ids", "clinicaltrial_titles", "patent_ids",
+               "patent_titles", "pmids", "project_terms"]
+
+
+# Geo names edge cases
+CTRY_LOOKUP = {'Korea Rep Of': 'Korea, Republic of',
+               'Russia': 'Russian Federation',
+               'Congo Dem Rep': 'Congo, The Democratic Republic of the',
+               "Cote D'ivoire": "Côte d'Ivoire",
+               'Dominican Rep': 'Dominican Republic',
+               'Eswatini': 'Swaziland',
+               'Fed Micronesia': 'Micronesia, Federated States of',
+               'Papua N Guinea': 'Papua New Guinea',
+               'St Kitts/nevis': 'Saint Kitts and Nevis',
+               'St Lucia': 'Saint Lucia',
+               'Tanzania U Rep': 'Tanzania',
+               'Trinidad/toba': 'Trinidad and Tobago'}
+
+
+def get_projects_by_appl_id(engine, appl_ids, nrows=None,
+                            pull_relationships=None):
+    """Get NiH projects by application ID (i.e. the primary key)
+    and assign each project to its own group (i.e. a group size of 1).
+    This method is meant for projects with a NULL core ID, which therefore
+    can't be grouped into a family of projects. Note that
+    `pull_relationships` is a dummy argument so that this function
+    can be used interchangeably with `group_projects_by_core_id`.
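+
+    Args:
+        engine: A SqlAlchemy engine pointing at the NiH database.
+        appl_ids: Application IDs (primary keys) to retrieve.
+        nrows (int): If given, limit the number of rows retrieved.
+        pull_relationships: Dummy argument, ignored (see above).
+    Returns:
+        groups (list): One single-project group (a list containing
+                       one dict) per retrieved application ID.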
+ """ + # Get all projects in the given set of IDs + filter_stmt = PK_ID.in_(appl_ids) + with db_session(engine) as sess: + q = sess.query(Projects).filter(filter_stmt).order_by(PK_ID) + results = q.limit(nrows).all() + # "Fake" single project groups + groups = [[object_to_dict(obj, shallow=False, + properties=True)] + for obj in results] + return groups + + +def group_projects_by_core_id(engine, core_ids, nrows=None, + pull_relationships=False): + """Get NiH projects by the base core project number ("core id"), + and then group projects by this core id. If `pull_relationships` + is True, then also unbundle any SqlAlchemy "relationship" objects; + although this isn't required (and therefore substantially speeds + things up) when, for example, only IDs are required.""" + # Get all projects in the given set of IDs + filter_stmt = CORE_ID.in_(core_ids) + with db_session(engine) as sess: + q = sess.query(Projects).filter(filter_stmt).order_by(CORE_ID) + results = q.limit(nrows).all() + # Group the results by the core project number + groups = [[object_to_dict(obj, shallow=not pull_relationships, + properties=True) + for obj in group] + for _, group in + groupby(results, attrgetter('base_core_project_num'))] + return groups + + +def get_sim_weights(dupes, appl_ids): + """Retrieve the similarity weights for this project""" + sim_weights = defaultdict(list) + for d in dupes: + appl_id_1 = d['application_id_1'] + appl_id_2 = d['application_id_2'] + # Referring to Note a) in `retrieve_similar_projects`, + # determine which ID is the PK for the similar project + id_ = appl_id_1 if appl_id_1 not in appl_ids else appl_id_2 + sim_weights[id_].append(d['weight']) + # Match against the largest weight, if the similar project + # has been retrieved multiple times + sim_weights = {id_: max(weights) + for id_, weights in sim_weights.items()} + return sim_weights + + +def retrieve_similar_projects(engine, appl_ids): + """Retrieve all projects which are similar to those in this + project group. Some of the similar projects will be retrieved + multiple times if matched to multiple projects in the group. + `appl_ids` is the set of IDs in this group. + """ + # Note a) the TextDuplicate table doesn't double-count + # application IDs, so the application IDs of this group + # could be in either application_id_1 or application_id_2 + either = (TextDuplicate.application_id_1.in_(appl_ids) | + TextDuplicate.application_id_2.in_(appl_ids)) + both = (TextDuplicate.application_id_1.in_(appl_ids) & + TextDuplicate.application_id_2.in_(appl_ids)) + # We want either application_id_1 or application_id_2, but + # not both, since in such a case both projects would already be + # in the same group. + filter_stmt = (either & ~both) + with db_session(engine) as session: + dupes = session.query(TextDuplicate).filter(filter_stmt).all() + dupes = [object_to_dict(obj, shallow=True) for obj in dupes] + + # Get the similarity weights for this project + sim_weights = get_sim_weights(dupes, appl_ids) + sim_ids = set(sim_weights.keys()) + + # Retrieve only the required fields by project id + filter_stmt = PK_ID.in_(sim_ids) + query_fields = [PK_ID, CORE_ID, Projects.fy, *DATETIME_COLS] + with db_session(engine) as session: + q = session.query(*query_fields).filter(filter_stmt) + sim_projs = [{field.name: value + for field, value in zip(query_fields, values)} + for values in q.all()] + return sim_projs, sim_weights + + +def earliest_date(project): + """Determine the earliest date, among all the date fields + in this project. 
Returns `datetime.min` if no date is found."""
+    year = project['fy']
+    # Try to find a date
+    dates = []
+    for f in DATETIME_FIELDS:
+        date = project[f]
+        if type(date) is str:
+            date = dateutil.parser.parse(date)
+        elif date is None:
+            continue
+        dates.append(date)
+    min_date = datetime.min  # default value if no date fields present
+    if len(dates) > 0:
+        min_date = min(dates)
+    # Otherwise, fall back on the year field
+    elif year is not None:
+        min_date = datetime(year=year, month=1, day=1)
+    return min_date
+
+
+def retrieve_similar_proj_ids(engine, appl_ids):
+    """Retrieve similar projects, expand each similar
+    project into its group using the core ID.
+    Then extract only the most recent PK ID from each group,
+    and group these PK IDs by their similarity score, in order
+    to have lists of "near duplicates", "very similar" and
+    "fairly similar" IDs. `appl_ids` is the set of IDs in this group,
+    from which similar projects are to be found.
+    """
+    # Retrieve similar projects
+    projs, weights = retrieve_similar_projects(engine, appl_ids)
+    # Retrieve core IDs in order to perform groupby on the
+    # similar projects
+    groups = []
+    core_ids = set()
+    for proj in projs:
+        core_id = proj["base_core_project_num"]
+        # Around 3% of projs have no core id, and hence no group
+        if core_id is None:
+            groups.append([proj])
+        else:
+            core_ids.add(core_id)
+    groups += group_projects_by_core_id(engine, core_ids)
+
+    # Return just the PK of the most recent project in each group
+    pk_weights = {}
+    for group in groups:
+        # Get the most recent project
+        sorted_group = sorted(group, key=earliest_date, reverse=True)
+        pk0 = sorted_group[0]['application_id']
+        # Get the maximum similarity of any project in the group
+        pks = set(proj['application_id'] for proj in group)
+        max_weight = max(weights[pk] for pk in pks if pk in weights)
+        pk_weights[pk0] = max_weight
+
+    # Group projects by their similarity
+    similar_projs = group_projs_by_similarity(pk_weights)
+    return similar_projs
+
+
+def group_projs_by_similarity(pk_weights,
+                              ranges = {'near_duplicate_ids': (0.8, 1),
+                                        'very_similar_ids': (0.65, 0.8),
+                                        'fairly_similar_ids': (0.4, 0.65)}):
+    """Group projects by range of similarity. Ranges have been
+    hand-selected, and are clearly subject to optimisation."""
+    grouped_projs = {label: [pk for pk, weight in pk_weights.items()
+                             if weight > lower and weight <= upper]
+                     for label, (lower, upper) in ranges.items()}
+    return grouped_projs
+
+
+def combine(func, list_of_dict, key):
+    """Apply the given function over the values retrieved
+    by the given key for each item in a list of dictionaries"""
+    values = [_dict[key] for _dict in list_of_dict
+              if _dict[key] is not None]
+    if len(values) == 0:
+        return None
+    return func(values)
+
+
+def first_non_null(values):
+    """Return the first non-null value in the list"""
+    for v in values:
+        if v is None:
+            continue
+        return v
+    return None
+
+
+def join_and_dedupe(values):
+    """Flatten the list and deduplicate"""
+    return list(set(chain(*values)))
+
+
+def format_us_zipcode(zipcode):
+    """NiH US postcodes have wildly inconsistent formatting,
+    leading to geocoding errors.
If the postcode is longer
+    than 5 chars, it should be in the format XXXXX-XXXX,
+    or XXXXX, even if the first 5 chars require zero-padding."""
+    ndigits = len(zipcode)
+    # Only apply the procedure to numeric postcodes
+    if not zipcode.isnumeric():
+        return zipcode
+    # e.g. 123456789 --> 12345-6789
+    # or   3456789 --> 00345-6789
+    if ndigits > 5:
+        start, end = zipcode[:-4].zfill(5), zipcode[-4:]
+        return f'{start}-{end}'
+    # e.g. 12345 --> 12345
+    # or   345 --> 00345
+    else:
+        return zipcode.zfill(5)
+
+
+def geocode(city, state, country, postalcode):
+    """Apply the OSM geocoding for as many fields as possible."""
+    kwargs = {'city': city,
+              'state': state,
+              'country': country,
+              'postalcode': postalcode}
+    # Ditch null kwargs
+    kwargs = {k: v for k, v in kwargs.items()
+              if v is not None}
+    if len(kwargs) == 0:
+        return None
+    # Try with the postal code (doesn't always work, but when
+    # it does it gives more accurate results)
+    coords = _geocode(**kwargs)
+    # Otherwise, try removing the postcode
+    if coords is None and 'postalcode' in kwargs:
+        del kwargs['postalcode']
+        coords = _geocode(**kwargs)
+    # If still no results, try a plain query (tends to give
+    # very coarse resolution)
+    if coords is None:
+        coords = _geocode(q=', '.join(kwargs.values()))
+    return coords
+
+
+def aggregate_group(group):
+    """Aggregate fields from all projects in this group into a
+    single curated pseudo-project."""
+    # Sort by most recent first
+    group = list(sorted(group, key=earliest_date, reverse=True))
+    project = {"grouped_ids": [p['application_id'] for p in group],
+               "grouped_titles": [p['project_title'] for p in group]}
+
+    # Extract the first non-null fields directly from FLAT_FIELDS
+    for field in FLAT_FIELDS:
+        project[field] = combine(first_non_null, group, field)
+    # Concat list fields
+    for field in LIST_FIELDS:
+        project[field] = combine(join_and_dedupe, group, field)
+    # Specific aggregations
+    project["project_start"] = combine(min, group, "project_start")
+    project["project_end"] = combine(max, group, "project_end")
+    project["total_cost"] = combine(sum, group, "total_cost")
+
+    # Extra specific aggregations for yearly funds
+    yearly_groups = defaultdict(list)
+    for proj in group:
+        date = earliest_date(proj)
+        if date == datetime.min:  # i.e.
no date found + continue + yearly_groups[date.year].append(proj) + # Combine by year group + yearly_funds = [{"year": year, + "project_start": combine(min, yr_group, "project_start"), + "project_end": combine(max, yr_group, "project_end"), + "total_cost": combine(sum, yr_group, "total_cost")} + for year, yr_group in yearly_groups.items()] + project["yearly_funds"] = sorted(yearly_funds, key=lambda x: x['year']) + return project + + +def extract_geographies(row): + """Infer standard geographic info for this row""" + + # Lookup helpers (note, all are lru_cached) + states_lookup = get_us_states_lookup() + ctry_continent_lookup = get_country_continent_lookup() + continent_lookup = get_continent_lookup() + eu_countries = get_eu_countries() + + # If country name is badly formatted, reassign + ctry = row['org_country'] + if ctry in CTRY_LOOKUP: + ctry = CTRY_LOOKUP[ctry] + + # Perform lookups + iso2 = None + if ctry is not None: + iso_info = country_iso_code(ctry) + iso2 = iso_info.alpha_2 + row['org_country'] = iso_info.name # Standardise country naming + row['iso2'] = iso2 + row['is_eu'] = iso2 in eu_countries + row['state_name'] = states_lookup[row['org_state']] + continent_iso2 = ctry_continent_lookup[iso2] + row['continent_iso2'] = continent_iso2 + row['continent_name'] = continent_lookup[continent_iso2] + + # Clean zip code if US + if iso2 == 'US' and row['org_zipcode'] is not None: + row['org_zipcode'] = format_us_zipcode(row['org_zipcode']) + + # Retrieve lat / lon for this org + row['coordinates'] = geocode(city=row['org_city'], + state=row['state_name'], + country=row['org_country'], + postalcode=row['org_zipcode']) + return row + + +def apply_cleaning(row): + """Curate raw data for ingestion to MySQL.""" + row = _country_detection(row, 'country_mentions') + row = _remove_padding(row) + row = _null_empty_str(row) + row = _clean_up_lists(row) + return row + + +def run(): + test = literal_eval(os.environ["BATCHPAR_test"]) + using_core_ids = literal_eval(os.environ["BATCHPAR_using_core_ids"]) + bucket = os.environ['BATCHPAR_bucket'] + batch_file = os.environ['BATCHPAR_batch_file'] + db_name = os.environ["BATCHPAR_db_name"] + os.environ["MYSQLDB"] = os.environ["BATCHPAR_config"] + + # Database setup + engine = get_mysql_engine("MYSQLDB", "mysqldb", db_name) + + # Retrieve list of core ids from s3 + nrows = 1000 if test else None + s3 = boto3.resource('s3') + obj = s3.Object(bucket, batch_file) + core_ids = json.loads(obj.get()['Body']._raw_stream.read()) + logging.info(f"{len(core_ids)} ids retrieved from s3") + + # Get the groups for this batch. + # Around 3% of core ids are null, and so these are retrieved + # in batches of application id instead, and otherwise pull + # in the projects with the non-null core id as these can + # be aggregated together. 
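+    # Both getters share the same call signature, so the
+    # `using_core_ids` flag (set per-dataset via batchable_kwargs
+    # in curate_config.yaml) simply selects between them.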
+ data_getter = (group_projects_by_core_id if using_core_ids + else get_projects_by_appl_id) + groups = data_getter(engine, core_ids, pull_relationships=True) + + # Curate each group + data = [] + for group in groups: + appl_ids = [proj['application_id'] for proj in group] + similar_projs = retrieve_similar_proj_ids(engine, appl_ids) + project = aggregate_group(group) + geographies = extract_geographies(project) + row = {**project, **geographies, **similar_projs} + row = apply_cleaning(row) + data.append(row) + + # Insert data into the database + insert_data("MYSQLDB", "mysqldb", db_name, Base, + NihProject, data, low_memory=True) + logging.info("Batch job complete.") + + +if __name__ == "__main__": + log_stream_handler = logging.StreamHandler() + logging.basicConfig(handlers=[log_stream_handler, ], + level=logging.INFO, + format="%(asctime)s:%(levelname)s:%(message)s") + run() diff --git a/nesta/core/batchables/general/nih/curate/tests/test_nih_curate.py b/nesta/core/batchables/general/nih/curate/tests/test_nih_curate.py new file mode 100644 index 00000000..f7e6a004 --- /dev/null +++ b/nesta/core/batchables/general/nih/curate/tests/test_nih_curate.py @@ -0,0 +1,358 @@ +from unittest import mock + +from nesta.core.batchables.general.nih.curate import run + +PATH = "nesta.core.batchables.general.nih.curate.run.{}" +dt = run.datetime + +@mock.patch(PATH.format("db_session")) +@mock.patch(PATH.format("object_to_dict")) +def test_get_projects_by_appl_id(mocked_obj2dict, + mocked_db_session): + appl_ids = ['a', 'b', 1, 2, 3] + + # Mock the session and query + mocked_session = mock.Mock() + q = mocked_session.query().filter().order_by().limit() + q.all.return_value = appl_ids # <-- will just return the input + # Assign the session to the context manager + mocked_db_session().__enter__.return_value = mocked_session + + # Just return the value itself + mocked_obj2dict.side_effect = lambda obj, shallow, properties: obj + + # Test that single-member groups are created + groups = run.get_projects_by_appl_id(None, appl_ids) + assert groups == [[id_] for id_ in appl_ids] + + +def _result_factory(value): + m = mock.Mock() + m.base_core_project_num = value + return m + + +@mock.patch(PATH.format("db_session")) +@mock.patch(PATH.format("object_to_dict")) +def test_group_projects_by_core_id(mocked_obj2dict, + mocked_db_session): + core_ids = ['a', 1, 'b', 'b', 1, 2, 1] + results = [_result_factory(v) for v in core_ids] + groups = [[{'base_core_project_num': 'a'}], # Group 1 + [{'base_core_project_num': 1}, # Group 2 + {'base_core_project_num': 1}, + {'base_core_project_num': 1}], + [{'base_core_project_num': 'b'}, # Group 3 + {'base_core_project_num': 'b'}], + [{'base_core_project_num': 2}]] # Group 4 + + # Mock the session and query + mocked_session = mock.Mock() + q = mocked_session.query().filter().order_by().limit() + q.all.return_value = results # <-- will just return the input + # Assign the session to the context manager + mocked_db_session().__enter__.return_value = mocked_session + + # Just return the value itself + mocked_obj2dict.side_effect = lambda obj, shallow, properties: obj + + # Test that single-member groups are created + groups = run.group_projects_by_core_id(None, core_ids) + assert groups == groups + + +def test_get_sim_weights(): + appl_ids = [1, 2, 3, 4] + dupes = [{'application_id_1': 1, + 'application_id_2': 5, + 'weight': 0.4}, + {'application_id_1': 1, + 'application_id_2': 6, + 'weight': 0.9}, + {'application_id_1': 2, + 'application_id_2': 6, + 'weight': 0.8}, + 
{'application_id_1': 3, + 'application_id_2': 5, + 'weight': 0.3}] + # The max weight of ids not in `appl_ids` + sim_weights = {5: 0.4, 6: 0.9} + + assert run.get_sim_weights(dupes, appl_ids) == sim_weights + +@mock.patch(PATH.format("db_session")) +@mock.patch(PATH.format("object_to_dict")) +@mock.patch(PATH.format("get_sim_weights")) +def test_retrieve_similar_projects(mocked_get_sim_weights, + mocked_obj2dict, + mocked_db_session): + sim_weights = {5: 0.4, 6: 0.9} + mocked_get_sim_weights.return_value = sim_weights + sim_ids = [(id,) for id in set(sim_weights.keys())] + sim_projs = [{"application_id": id} for id, in sim_ids] + + # Mock the session and query + mocked_session = mock.MagicMock() + q = mocked_session.query().filter() + q.all.return_value = sim_ids # <-- will just return the input + # Assign the session to the context manager + mocked_db_session().__enter__.return_value = mocked_session + + # Just return the value itself + mocked_obj2dict.side_effect = lambda obj, shallow: obj + assert run.retrieve_similar_projects(None, []) == (sim_projs, sim_weights) + + +def test_earliest_date_good_dates(): + project = {'fy': 2020, + 'project_start': '2022-1-20', + 'project_end': None, + 'award_notice_date': '2021-1-20', + 'budget_end': '2021-1-20', + 'budget_start': None} + assert run.earliest_date(project) == dt(year=2021, month=1, day=20) + + +def test_earliest_date_only_year(): + project = {'fy': 2020, + 'project_start': None, + 'project_end': None, + 'award_notice_date': None, + 'budget_end': None, + 'budget_start': None} + assert run.earliest_date(project) == dt(year=2020, month=1, day=1) + + +def test_earliest_date_no_dates(): + project = {'fy': None, + 'project_start': None, + 'project_end': None, + 'award_notice_date': None, + 'budget_end': None, + 'budget_start': None} + assert run.earliest_date(project) == dt.min + + +@mock.patch(PATH.format('retrieve_similar_projects')) +@mock.patch(PATH.format('group_projects_by_core_id')) +@mock.patch(PATH.format('earliest_date')) +def test_retrieve_similar_proj_ids(mocked_earliest_date, + mocked_group_projects, + mocked_rsp): + projs = [{'application_id': 1, + 'base_core_project_num': None}, + {'application_id': 2, + 'base_core_project_num': 'two'}, + {'application_id': 3, + 'base_core_project_num': 'three'}, + {'application_id': 4, + 'base_core_project_num': None}] + weights = {1: 0.5, 2: 0.9, 3: 0.05, 4: 0.5, + 22: 0.1, 33: 0.7} + groups = [[{'application_id': 22, + 'base_core_project_num': 'two'}, + {'application_id': 2, + 'base_core_project_num': 'two'}], + [{'application_id': 33, + 'base_core_project_num': 'three'}, + {'application_id': 3, + 'base_core_project_num': 'three'}]] + + mocked_rsp.return_value = projs, weights + mocked_group_projects.return_value = groups + # The following will pick 22 and 33 from their groups, because + # they are the largest value (instead of fully implementing + # `earliest_date` in this test) + mocked_earliest_date.side_effect = lambda x: x['application_id'] + # Note that 22 picks up the weight of 2, and 33 keeps its own + # weight since, it is the largest weight in the group that wins + expected = {'near_duplicate_ids': [22], + 'very_similar_ids': [33], + 'fairly_similar_ids': [1, 4]} + assert run.retrieve_similar_proj_ids(None, None) == expected + + +def test_combine(): + list_of_dict = [{'a': 1}, {'a': -1}, {'a': None}] + assert run.combine(max, list_of_dict, 'a') == 1 + assert run.combine(min, list_of_dict, 'a') == -1 + + +def test_first_non_null(): + values = [None, None, 'foo', None, 'bar', None] + 
assert run.first_non_null(values) == 'foo' + + +def test_join_and_dedupe(): + values = [[None, None, 'foo', None, 'bar', None], + [None, None, 'foo', None, 'baz', None]] + expected = ['foo', 'bar', None, 'baz'] + found = run.join_and_dedupe(values) + assert len(expected) == len(found) + assert set(expected) == set(found) + + +def test_format_us_zipcode(): + assert run.format_us_zipcode('123456789') == '12345-6789' + assert run.format_us_zipcode('23456789') == '02345-6789' + assert run.format_us_zipcode('3456789') == '00345-6789' + assert run.format_us_zipcode('456789') == '00045-6789' + assert run.format_us_zipcode('56789') == '56789' + assert run.format_us_zipcode('6789') == '06789' + assert run.format_us_zipcode('789') == '00789' + assert run.format_us_zipcode('89') == '00089' + assert run.format_us_zipcode('9') == '00009' + + assert run.format_us_zipcode('anything else') == 'anything else' + assert run.format_us_zipcode('?') == '?' + + +@mock.patch(PATH.format('_geocode')) +def test_geocode(mocked__geocode): + assert run.geocode(None, None, None, None) == None + + mocked__geocode.side_effect = [None, 'bar'] + assert run.geocode(None, None, None, postalcode='something') == 'bar' + + mocked__geocode.side_effect = [None, None, 'foo'] + assert run.geocode(None, None, None, postalcode='something') == 'foo' + + mocked__geocode.side_effect = [None, 'baz'] + assert run.geocode(None, None, country='something', + postalcode=None) == 'baz' + + +def test_aggregate_group(): + proj1 = {'application_id': 1, + 'base_core_project_num': 'first', + 'fy': 2001, + 'org_city': 'Kansas City', + 'org_country': 'United States', + 'org_name': 'Big Corp', + 'org_state': None, + 'org_zipcode': '123456789', + 'project_title': 'first title', + 'ic_name': None, + 'phr': None, + 'abstract_text': 'first abstract', + 'total_cost': 100, + # List fields + 'clinicaltrial_ids': [1,2,3], + 'clinicaltrial_titles': ['title 1', 'title 3'], + 'patent_ids': [2,3,4,5], + 'patent_titles': ['patent 1', 'patent 2'], + 'pmids': ['a', 'c', 'd'], + 'project_terms': ['AAA', 'CCC'], + # Date fields + 'project_start': '2022-1-20', + 'project_end': None, + 'award_notice_date': '2021-1-20', + 'budget_end': '2021-1-20', + 'budget_start': None} + + + proj2 = {'application_id': 2, + 'base_core_project_num': 'first', + 'fy': 2002, + 'org_city': 'Kansas City', + 'org_country': 'United States', + 'org_name': 'Big Corp', + 'org_state': None, + 'org_zipcode': '123456789', + 'project_title': 'second title', + 'ic_name': None, + 'phr': 'second phr', + 'abstract_text': 'second abstract', + 'total_cost': 200, + # List fields + 'clinicaltrial_ids': [1,2,4], + 'clinicaltrial_titles': ['title 1', 'title 2'], + 'patent_ids': [1,3,4,5], + 'patent_titles': ['patent 1', 'patent 3'], + 'pmids': ['a', 'c', 'b'], + 'project_terms': ['AAA', 'BBB'], + # Date fields + 'project_start': '1990-1-20', + 'project_end': None, + 'award_notice_date': '2021-1-20', + 'budget_end': '2021-1-20', + 'budget_start': None} + + proj3 = {'application_id': 2, + 'base_core_project_num': 'first', + 'fy': 2002, + 'org_city': 'Kansas City', + 'org_country': 'United States', + 'org_name': 'Big Corp', + 'org_state': 'third state', + 'org_zipcode': '123456789', + 'project_title': 'third title', + 'ic_name': 'ms third', + 'phr': None, + 'abstract_text': None, + 'total_cost': 300, + # List fields + 'clinicaltrial_ids': [1,2,4], + 'clinicaltrial_titles': ['title 0', 'title 2'], + 'patent_ids': [1,3,4,5], + 'patent_titles': ['patent 0', 'patent 3'], + 'pmids': ['a', 'c', 'e'], + 'project_terms': 
['AAA', 'DDD'], + # Date fields + 'project_start': '1999-1-20', + 'project_end': '2025-1-20', + 'award_notice_date': '2021-1-20', + 'budget_end': '2021-1-20', + 'budget_start': None} + + + group = [proj1, proj2, proj3] + aggregated_group = {'grouped_ids': [1, 2, 2], + 'grouped_titles': ['first title', 'third title', + 'second title'], + 'application_id': 1, + 'base_core_project_num': 'first', + 'fy': 2001, + 'org_city': 'Kansas City', + 'org_country': 'United States', + 'org_name': 'Big Corp', + 'org_state': 'third state', + 'org_zipcode': '123456789', + 'project_title': 'first title', + 'ic_name': 'ms third', + 'phr': 'second phr', + 'abstract_text': 'first abstract', + 'clinicaltrial_ids': [1, 2, 3, 4], + 'clinicaltrial_titles': ['title 0', 'title 1', + 'title 2', 'title 3'], + 'patent_ids': [1, 2, 3, 4, 5], + 'patent_titles': ['patent 0', 'patent 2', + 'patent 3', 'patent 1'], + 'pmids': ['a', 'b', 'c', 'd', 'e'], + 'project_terms': ['AAA', 'BBB', 'CCC', 'DDD'], + 'project_start': '1990-1-20', + 'project_end': '2025-1-20', + 'total_cost': 600, + 'yearly_funds': [{'year': 1990, + 'project_start': '1990-1-20', + 'project_end': None, + 'total_cost': 200}, + {'year': 1999, + 'project_start': '1999-1-20', + 'project_end': '2025-1-20', + 'total_cost': 300}, + {'year': 2021, + 'project_start': + '2022-1-20', + 'project_end': None, + 'total_cost': 100}]} + + # Check that all elements are the same + result = run.aggregate_group(group) + assert result.keys() == aggregated_group.keys() + for k, v in result.items(): + _v = aggregated_group[k] + if type(v) is list and type(v[0]) is not dict: + assert sorted(v) == sorted(_v) + else: + assert v == _v diff --git a/nesta/core/orms/general_orm.py b/nesta/core/orms/general_orm.py index 85692f0a..bdb05e60 100644 --- a/nesta/core/orms/general_orm.py +++ b/nesta/core/orms/general_orm.py @@ -6,11 +6,12 @@ ''' from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.dialects.mysql import VARCHAR, BIGINT, TEXT, JSON -from sqlalchemy.types import INT, DATE, DATETIME, BOOLEAN +from sqlalchemy.dialects.mysql import BIGINT, JSON +from sqlalchemy.types import INTEGER, DATE, DATETIME, BOOLEAN from sqlalchemy import Column from nesta.core.orms.crunchbase_orm import fixture as cb_fixture -from nesta.core.orms.crunchbase_orm import _TEXT +from nesta.core.orms.types import TEXT, VARCHAR + Base = declarative_base() @@ -30,15 +31,15 @@ class CrunchbaseOrg(Base): homepage_url = cb_fixture('url') last_funding_on = cb_fixture('happened_on') linkedin_url = cb_fixture('url') - long_description = Column(_TEXT) + long_description = Column(TEXT) name = cb_fixture('name') - num_exits = Column(INT) - num_funding_rounds = Column(INT) + num_exits = Column(INTEGER) + num_funding_rounds = Column(INTEGER) parent_id = cb_fixture('id_idx') primary_role = Column(VARCHAR(50)) region = cb_fixture('region') roles = cb_fixture('roles') - short_description = Column(VARCHAR(200, collation='utf8mb4_unicode_ci')) + short_description = Column(VARCHAR(200)) state_code = cb_fixture('state_code') status = Column(VARCHAR(9)) total_funding_usd = cb_fixture('monetary_amount') @@ -59,3 +60,50 @@ class CrunchbaseOrg(Base): is_eu = Column(BOOLEAN, nullable=False) state_name = cb_fixture('name') updated_at = cb_fixture('happened_on') + + +class NihProject(Base): + __tablename__ = 'curated_nih_projects' + + # Fields ported from nih_orm.Projects + application_id = Column(INTEGER, primary_key=True, autoincrement=False) + base_core_project_num = Column(VARCHAR(50), index=True) + fy = 
Column(INTEGER, index=True) + org_city = Column(VARCHAR(50), index=True) + org_country = Column(VARCHAR(50), index=True) + org_name = Column(VARCHAR(100), index=True) + org_state = Column(VARCHAR(2), index=True) + org_zipcode = Column(VARCHAR(10)) + project_title = Column(TEXT) + phr = Column(TEXT) + ic_name = Column(VARCHAR(100), index=True) + + # Fields from other ORMs in nih_orm + abstract_text = Column(TEXT) + + # New, joined or updated fields + clinicaltrial_ids = Column(JSON) + clinicaltrial_titles = Column(JSON) + currency = Column(VARCHAR(3), default="USD", index=True) + fairly_similar_ids = Column(JSON) + near_duplicate_ids = Column(JSON) + patent_ids = Column(JSON) + patent_titles = Column(JSON) + pmids = Column(JSON) + project_end = Column(DATETIME, index=True) + project_start = Column(DATETIME, index=True) + project_terms = Column(JSON) + grouped_ids = Column(JSON) + grouped_titles = Column(JSON) + total_cost = Column(BIGINT) + very_similar_ids = Column(JSON) + yearly_funds = Column(JSON) + + # Geographic fields + continent_iso2 = Column(VARCHAR(2), index=True) + continent_name = Column(TEXT) + coordinates = Column(JSON) + country_mentions = Column(JSON) + is_eu = Column(BOOLEAN, index=True, nullable=False) + iso2 = Column(VARCHAR(2), index=True) + state_name = Column(TEXT) diff --git a/nesta/core/orms/nih_orm.py b/nesta/core/orms/nih_orm.py index 9f410b9a..93e65868 100644 --- a/nesta/core/orms/nih_orm.py +++ b/nesta/core/orms/nih_orm.py @@ -6,17 +6,27 @@ ''' from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.dialects.mysql import VARCHAR as _VARCHAR -from sqlalchemy.dialects.mysql import TEXT as _TEXT from sqlalchemy.types import INTEGER, JSON, DATETIME, FLOAT from sqlalchemy import Column, Table, ForeignKey -from functools import partial +from sqlalchemy.orm import relationship +from sqlalchemy.ext.associationproxy import association_proxy +from nesta.core.orms.types import VARCHAR, TEXT + + +def getattr_(entity, attribute): + """Either unpack the attribute from every item in the entity + if the entity is a list, otherwise just return the attribute + from the entity. 
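+    For example, `getattr_(self.patents, "patent_id")` gives a list of
+    patent IDs, whereas `getattr_(self.abstract, "abstract_text")` gives
+    a single value.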
Returns None if the entity is either None + or empty.""" + if entity in (None, []): + return None + if isinstance(entity, list): + return [getattr(item, attribute) for item in entity] + return getattr(entity, attribute) -Base = declarative_base() -TEXT = _TEXT(collation='utf8mb4_unicode_ci') -VARCHAR = partial(_VARCHAR, collation='utf8mb4_unicode_ci') +Base = declarative_base() class Projects(Base): __tablename__ = 'nih_projects' @@ -65,15 +75,61 @@ class Projects(Base): direct_cost_amt = Column(INTEGER) indirect_cost_amt = Column(INTEGER) total_cost = Column(INTEGER) - subproject_id = Column(INTEGER) + subproject_id = Column(INTEGER, index=True) total_cost_sub_project = Column(INTEGER) nih_spending_cats = Column(JSON) + # Pseudo-FKs + abstract = relationship("Abstracts", uselist=False, + foreign_keys=[application_id], + primaryjoin=("Projects.application_id==" + "Abstracts.application_id")) + publications = relationship("LinkTables", uselist=True, + foreign_keys=[core_project_num], + primaryjoin=("Projects.core_project_num==" + "LinkTables.project_number")) + patents = relationship("Patents", uselist=True, + foreign_keys=[core_project_num], + primaryjoin=("Projects.core_project_num==" + "Patents.project_id")) + clinicalstudies = relationship("ClinicalStudies", uselist=True, + foreign_keys=[core_project_num], + primaryjoin=("Projects.core_project_num==" + "ClinicalStudies.core_project_number")) + + # Pseudo-fields (populated from relationships) + @property + def abstract_text(self): + return getattr_(self.abstract, "abstract_text") + + @property + def patent_ids(self): + return getattr_(self.patents, "patent_id") + + @property + def patent_titles(self): + return getattr_(self.patents, "patent_title") + + @property + def pmids(self): + return getattr_(self.publications, "pmid") + + @property + def clinicaltrial_ids(self): + return getattr_(self.clinicalstudies, "clinicaltrials_gov_id") + + @property + def clinicaltrial_titles(self): + return getattr_(self.clinicalstudies, "study") + + + class Abstracts(Base): __tablename__ = 'nih_abstracts' application_id = Column(INTEGER, primary_key=True, autoincrement=False) abstract_text = Column(TEXT) + class Publications(Base): diff --git a/nesta/core/orms/orm_utils.py b/nesta/core/orms/orm_utils.py index 21a2a386..d23d16df 100644 --- a/nesta/core/orms/orm_utils.py +++ b/nesta/core/orms/orm_utils.py @@ -50,13 +50,15 @@ def orm_column_names(_class): return columns -def object_to_dict(obj, shallow=False, found=None): +def object_to_dict(obj, shallow=False, properties=True, found=None): """Converts a nested SqlAlchemy object to a fully unpacked json object. Args: obj: A SqlAlchemy object (i.e. single 'row' of data) shallow (bool): Fully unpack nested objs via relationships. + properties (bool): Also retrieve all @property values as if they + were columns in the row object. found: For internal recursion, do not change the default. Returns: _obj (dict): An unpacked json-like dict object. 
@@ -64,9 +66,17 @@ def object_to_dict(obj, shallow=False, found=None): if found is None: # First time found = set() # Set up the mapper and retrieve shallow values - mapper = class_mapper(obj.__class__) + _class = obj.__class__ + mapper = class_mapper(_class) columns = [column.key for column in mapper.columns] out = dict(map(lambda c: _get_key_value(obj, c), columns)) + if properties: + # Work out if there are any properties + property_names = [name for name in dir(_class) + if type(getattr(_class, name)) is property] + # Then pull them out if they exist + for name in property_names: + out[name] = getattr(obj, name) # Shallow means ignore relationships relationships = {} if shallow else mapper.relationships for name, relation in relationships.items(): @@ -354,14 +364,14 @@ def cast_as_sql_python_type(field, data): def get_session(db_env, section, database, Base): """Return a database Session instance for the given credentials, and also setup the table structure for the intended Base ORM. - + Args: - db_env: See :obj:`get_mysql_engine` - section: See :obj:`get_mysql_engine` - database: See :obj:`get_mysql_engine` + db_env: See :obj:`get_mysql_engine` + section: See :obj:`get_mysql_engine` + database: See :obj:`get_mysql_engine` Base (:obj:`sqlalchemy.Base`): The Base ORM for this data. Returns: - session ((:obj:`sqlalchemy.Session`): A database Session instance + session ((:obj:`sqlalchemy.Session`): A database Session instance for the given credentials. """ engine = get_mysql_engine(db_env, section, database) @@ -578,7 +588,7 @@ def merge_duplicates(db_env, section, database, # Now merge the fields by taking the first non-null value objs = [] for pk, rows in pk_row_lookup.items(): - field_names = list(rows[0].keys()) + field_names = list(rows[0].keys()) merged_row = {} for col in field_names: value = None @@ -635,7 +645,7 @@ def insert_data(db_env, section, database, Base, # Drop existing objs if merging with db_session(engine) as session: - try_until_allowed(Base.metadata.create_all, + try_until_allowed(Base.metadata.create_all, session.get_bind()) if merge_non_null: session.execute(existing_objs) diff --git a/nesta/core/orms/tests/test_nih.py b/nesta/core/orms/tests/test_nih.py index 0da014ab..c1dcc1af 100644 --- a/nesta/core/orms/tests/test_nih.py +++ b/nesta/core/orms/tests/test_nih.py @@ -6,9 +6,27 @@ from nesta.core.orms.nih_orm import Publications from nesta.core.orms.nih_orm import Patents #from nesta.core.orms.nih_orm import LinkTables +from nesta.core.orms.nih_orm import getattr_ from nesta.core.orms.orm_utils import get_mysql_engine from sqlalchemy.exc import IntegrityError +def test__getattr(): + class AttributeDummy: + def __init__(self, a): + self.a = a + + class Dummy: + list_of_attrs = [AttributeDummy(1), AttributeDummy(7), + AttributeDummy(5)] + empty_list = [] + null_attr = None + single_attr = AttributeDummy('abc') + + assert getattr_(Dummy.list_of_attrs, 'a') == [1, 7, 5] + assert getattr_(Dummy.empty_list, 'a') == None + assert getattr_(Dummy.null_attr, 'a') == None + assert getattr_(Dummy.single_attr, 'a') == 'abc' + class TestMeetup(unittest.TestCase): '''Currently just a placeholder test to check that the schema compiles''' diff --git a/nesta/core/orms/types.py b/nesta/core/orms/types.py new file mode 100644 index 00000000..97da3757 --- /dev/null +++ b/nesta/core/orms/types.py @@ -0,0 +1,6 @@ +from sqlalchemy.dialects.mysql import VARCHAR as _VARCHAR +from sqlalchemy.dialects.mysql import TEXT as _TEXT +from functools import partial + +TEXT = 
_TEXT(collation='utf8mb4_unicode_ci') +VARCHAR = partial(_VARCHAR, collation='utf8mb4_unicode_ci') diff --git a/nesta/core/routines/datasets/nih/README b/nesta/core/routines/datasets/nih/README new file mode 100644 index 00000000..a33aa1b5 --- /dev/null +++ b/nesta/core/routines/datasets/nih/README @@ -0,0 +1,12 @@ +For future development, the order of tasks here is: + +1. nih_collect_task +2a. nih_drop_abstracts + 3a. nih_vectors_task + 4a. nih_dedupe_task +2b. nih_impute_base_id + +followed by (in the "general" project routine) + +5. general_curate_task + diff --git a/nesta/core/routines/datasets/nih/nih_impute_base_id.py b/nesta/core/routines/datasets/nih/nih_impute_base_id.py index 58ca00f8..d9326fd4 100644 --- a/nesta/core/routines/datasets/nih/nih_impute_base_id.py +++ b/nesta/core/routines/datasets/nih/nih_impute_base_id.py @@ -14,14 +14,12 @@ Any `core_project_num` failing this regex are ignored. ''' - import logging import luigi from datetime import datetime as dt from multiprocessing.dummy import Pool as ThreadPool from itertools import chain - from nesta.core.luigihacks.mysqldb import make_mysql_target from nesta.packages.nih.impute_base_id import retrieve_id_ranges from nesta.packages.nih.impute_base_id import impute_base_id_thread @@ -44,10 +42,11 @@ def output(self): def run(self): database = 'dev' if self.test else 'production' id_ranges = retrieve_id_ranges(database) + _id_ranges = map(lambda x: chain(x, [database]), id_ranges) + # Threading required because it takes 2-3 days to impute # all values on a single thread, or a few hours on 15 threads pool = ThreadPool(15) - _id_ranges = map(lambda x: chain(x, [database]), id_ranges) pool.starmap(impute_base_id_thread, _id_ranges) pool.close() pool.join() diff --git a/nesta/core/routines/datasets/nih/nih_root_task.py b/nesta/core/routines/datasets/nih/nih_root_task.py deleted file mode 100644 index f0a45014..00000000 --- a/nesta/core/routines/datasets/nih/nih_root_task.py +++ /dev/null @@ -1,65 +0,0 @@ -# TODO: update anything here with latest method (e.g. mysqltarget) -# TODO: set default batchable and runtime params where possible -# TODO: update orm, where required, incl lots of indexes -# TODO: update batchable as required -# TODO: write decent tests to check good dq -''' -Root Task (HealthMosaic) -======================== - -Luigi routine to collect NIH World RePORTER data -via the World ExPORTER data dump. The routine -transfers the data into the MySQL database before -processing and indexing the data to ElasticSearch. -''' - -import luigi -import datetime -import logging -from nesta.core.luigihacks.misctools import find_filepath_from_pathstub as f3p -import os - -from nesta.core.routines.nih.nih_data.nih_dedupe_task import DedupeTask - - -class RootTask(luigi.WrapperTask): - '''A dummy root task, which collects the database configurations - and executes the central task. - - Args: - date (datetime): Date used to label the outputs - db_config_path (str): Path to the MySQL database configuration - production (bool): Flag indicating whether running in testing - mode (False, default), or production mode (True). 
- ''' - date = luigi.DateParameter(default=datetime.date.today()) - db_config_path = luigi.Parameter(default="mysqldb.config") - production = luigi.BoolParameter(default=False) - drop_and_recreate = luigi.BoolParameter(default=False) - - def requires(self): - '''Collects the database configurations - and executes the central task.''' - _routine_id = "{}-{}".format(self.date, self.production) - - logging.getLogger().setLevel(logging.INFO) - yield DedupeTask(date=self.date, - drop_and_recreate=self.drop_and_recreate, - routine_id=_routine_id, - db_config_path=self.db_config_path, - process_batch_size=5000, - intermediate_bucket='nesta-production-intermediate', - test=(not self.production), - batchable=f3p("batchables/nih/" - "nih_dedupe"), - env_files=[f3p("nesta/"), - f3p("config/mysqldb.config"), - f3p("config/elasticsearch.yaml"), - f3p("nih.json")], - job_def="py37_amzn2", - job_name="NiHDedupeTask-%s" % _routine_id, - job_queue="HighPriority", - region_name="eu-west-2", - poll_time=10, - memory=1024, - max_live_jobs=20) diff --git a/nesta/core/routines/projects/general/curate_config.yaml b/nesta/core/routines/projects/general/curate_config.yaml new file mode 100644 index 00000000..90005ed7 --- /dev/null +++ b/nesta/core/routines/projects/general/curate_config.yaml @@ -0,0 +1,23 @@ +# Crunchbase +- dataset: companies + orm: crunchbase_orm + table_name: crunchbase_organizations + id_field: id + +# NiH (iterate over base_core_project_num) +- dataset: nih + orm: nih_orm + table_name: nih_projects + id_field: base_core_project_num + filter: "nih_projects.base_core_project_num IS NOT NULL" + batchable_kwargs: + using_core_ids: true + +# NiH (iterate over PK as base_core_project_num is null) +- dataset: nih + orm: nih_orm + table_name: nih_projects + id_field: application_id + filter: "nih_projects.base_core_project_num IS NULL" + batchable_kwargs: + using_core_ids: false diff --git a/nesta/core/routines/projects/general/general_curate.py b/nesta/core/routines/projects/general/general_curate.py index 32ad4cd5..5354d754 100644 --- a/nesta/core/routines/projects/general/general_curate.py +++ b/nesta/core/routines/projects/general/general_curate.py @@ -9,16 +9,64 @@ from nesta.core.luigihacks.sql2batchtask import Sql2BatchTask from nesta.core.luigihacks.misctools import find_filepath_from_pathstub as f3p from nesta.core.orms.crunchbase_orm import Organization as CrunchbaseOrg +from nesta.core.orms.nih_orm import Projects as NihProject +from nesta.core.orms.orm_utils import get_base_from_orm_name +from nesta.core.orms.orm_utils import get_class_by_tablename from nesta.core.luigihacks.misctools import get_config -from nesta.core.luigihacks.mysqldb import MySqlTarget +from nesta.core.luigihacks.mysqldb import make_mysql_target + +from sqlalchemy.sql import text as sql_text import luigi from datetime import datetime as dt import os +import pathlib +import yaml +from functools import lru_cache + S3_BUCKET='nesta-production-intermediate' ENV_FILES = ['mysqldb.config', 'nesta'] + +@lru_cache() +def read_config(): + """Read raw data from the config file""" + this_path = pathlib.Path(__file__).parent.absolute() + with open(this_path / 'curate_config.yaml') as f: + return yaml.safe_load(f) + + +def get_datasets(): + """Return unique values of the 'dataset' from the config""" + return list(set(item['dataset'] for item in read_config())) + + +def parse_config(): + """Yield this task's parameter fields from the config""" + config = read_config() + for item in config: + # Required fields + dataset = 
item['dataset'] + orm = item['orm'] + table_name = item['table_name'] + _id_field = item['id_field'] + + # Optional fields + filter = item.get('filter', None) + if filter is not None: + filter = sql_text(filter) + extra_kwargs = item.get('batchable_kwargs', {}) + + # Extract the actual ORM ID field + Base = get_base_from_orm_name(orm) + _class = get_class_by_tablename(Base, table_name) + id_field = getattr(_class, _id_field) + + # Yield this curation task's parameters + yield dataset, id_field, filter, extra_kwargs + + def kwarg_maker(dataset, routine_id): """kwarg factory for Sql2BatchTask tasks""" return dict(routine_id=f'{routine_id}_{dataset}', @@ -27,18 +75,14 @@ def kwarg_maker(dataset, routine_id): class CurateTask(luigi.Task): - process_batch_size = luigi.IntParameter(default=5000) + process_batch_size = luigi.IntParameter(default=1000) production = luigi.BoolParameter(default=False) date = luigi.DateParameter(default=dt.now()) + dataset = luigi.ChoiceParameter(default='all', + choices=['all'] + get_datasets()) def output(self): - test = not self.production - routine_id = f'General-Curate-Root-{self.date}-{test}' - db_config_path = os.environ['MYSQLDB'] - db_config = get_config(db_config_path, "mysqldb") - db_config["database"] = 'dev' if test else 'production' - db_config["table"] = f"{routine_id} " # Not a real table - return MySqlTarget(update_id=routine_id, **db_config) + return make_mysql_target(self) def requires(self): set_log_level(True) @@ -56,9 +100,13 @@ def requires(self): memory=2048, intermediate_bucket=S3_BUCKET) - params = (('companies', CrunchbaseOrg.id),) - for dataset, id_field in params: + # Iterate over each task specified in the config + for dataset, id_field, filter, kwargs in parse_config(): + if self.dataset != 'all' and dataset != self.dataset: + continue yield Sql2BatchTask(id_field=id_field, + filter=filter, + kwargs=kwargs, **kwarg_maker(dataset, routine_id), **default_kwargs) diff --git a/nesta/packages/geo_utils/country_iso_code.py b/nesta/packages/geo_utils/country_iso_code.py index d88b7e0a..e7dc4952 100644 --- a/nesta/packages/geo_utils/country_iso_code.py +++ b/nesta/packages/geo_utils/country_iso_code.py @@ -6,10 +6,24 @@ ''' import pycountry +from functools import lru_cache from nesta.packages.geo_utils.alpha2_to_continent import alpha2_to_continent_mapping +def _country_iso_code(country): + for name_type in ['name', 'common_name', 'official_name']: + query = {name_type: country} + try: + result = pycountry.countries.get(**query) + if result is not None: + return result + except KeyError: + pass + raise KeyError(f"{country} not found") + + +@lru_cache() def country_iso_code(country): ''' Look up the ISO 3166 codes for countries. 
@@ -22,17 +36,14 @@ def country_iso_code(country): Returns: Country object from the pycountry module ''' - country = str(country).title() - for name_type in ['name', 'common_name', 'official_name']: - query = {name_type: country} - try: - result = pycountry.countries.get(**query) - if result is not None: - return result - except KeyError: - pass - - raise KeyError(f"{country} not found") + country = str(country) + try: + # Note this will raise KeyError if fails + result = _country_iso_code(country) + except KeyError: + # Note this will raise KeyError if fails + result = _country_iso_code(country.title()) + return result def country_iso_code_dataframe(df, country='country'): @@ -78,7 +89,7 @@ def country_iso_code_to_name(code, iso2=False): Returns: str: name of the country or None if not valid """ - try: + try: if iso2: return pycountry.countries.get(alpha_2=code).name else: diff --git a/nesta/packages/geo_utils/geocode.py b/nesta/packages/geo_utils/geocode.py index ef9d3d63..7a143171 100644 --- a/nesta/packages/geo_utils/geocode.py +++ b/nesta/packages/geo_utils/geocode.py @@ -9,10 +9,12 @@ import pandas as pd import requests from retrying import retry +from functools import lru_cache from nesta.packages.decorators.ratelimit import ratelimit +@lru_cache() def geocode(**request_kwargs): ''' Geocoder using the Open Street Map Nominatim API. diff --git a/nesta/packages/geo_utils/lookup.py b/nesta/packages/geo_utils/lookup.py index 2f10fb71..616e4659 100644 --- a/nesta/packages/geo_utils/lookup.py +++ b/nesta/packages/geo_utils/lookup.py @@ -39,12 +39,37 @@ def get_continent_lookup(): return continent_lookup +@lru_cache() +def get_country_continent_lookup(): + """ + Retrieves continent lookups for all world countries, + by ISO2 code, from a static open URL. + + Returns: + data (dict): Values are country_name-continent pairs. + """ + r = requests.get(COUNTRY_CODES_URL) + r.raise_for_status() + with StringIO(r.text) as csv: + df = pd.read_csv(csv, usecols=['ISO3166-1-Alpha-2', + 'Continent'], + keep_default_na=False) + data = {row['ISO3166-1-Alpha-2']: row['Continent'] + for _, row in df.iterrows() + if not pd.isnull(row['ISO3166-1-Alpha-2'])} + # Kosovo, null + data['XK'] = 'EU' + data[None] = None + return data + + + @lru_cache() def get_country_region_lookup(): """ Retrieves subregions (around 18 in total) lookups for all world countries, by ISO2 code, - form a static open URL. + from a static open URL. Returns: data (dict): Values are country_name-region_name pairs. 
diff --git a/nesta/packages/geo_utils/tests/test_geotools.py b/nesta/packages/geo_utils/tests/test_geotools.py index ab262f82..79031b38 100644 --- a/nesta/packages/geo_utils/tests/test_geotools.py +++ b/nesta/packages/geo_utils/tests/test_geotools.py @@ -13,6 +13,7 @@ from nesta.packages.geo_utils.country_iso_code import country_iso_code_to_name from nesta.packages.geo_utils.lookup import get_continent_lookup from nesta.packages.geo_utils.lookup import get_country_region_lookup +from nesta.packages.geo_utils.lookup import get_country_continent_lookup REQUESTS = 'nesta.packages.geo_utils.geocode.requests.get' PYCOUNTRY = 'nesta.packages.geo_utils.country_iso_code.pycountry.countries.get' @@ -135,7 +136,7 @@ def test_underlying_geocoding_function_called_with_city_country(self, mocked_geo assert mocked_geocode.mock_calls == expected_calls @mock.patch(_GEOCODE) - def test_underlying_geocoding_function_called_with_query_fallback(self, mocked_geocode, + def test_underlying_geocoding_function_called_with_query_fallback(self, mocked_geocode, test_dataframe): mocked_geocode.side_effect = [None, None, None, 'dog', 'cat', 'squirrel'] geocoded_dataframe = geocode_dataframe(test_dataframe) @@ -319,6 +320,7 @@ def test_lookup_via_name(self, mocked_pycountry): assert country_iso_code('United Kingdom') == 'country_object' assert mocked_pycountry.mock_calls == expected_calls assert mocked_pycountry.call_count == 1 + country_iso_code.cache_clear() @mock.patch(PYCOUNTRY) def test_lookup_via_common_name(self, mocked_pycountry): @@ -330,6 +332,7 @@ def test_lookup_via_common_name(self, mocked_pycountry): assert country_iso_code('United Kingdom') == 'country_object' assert mocked_pycountry.mock_calls == expected_calls assert mocked_pycountry.call_count == 2 + country_iso_code.cache_clear() @mock.patch(PYCOUNTRY) def test_lookup_via_official_name(self, mocked_pycountry): @@ -342,25 +345,35 @@ def test_lookup_via_official_name(self, mocked_pycountry): assert country_iso_code('United Kingdom') == 'country_object' assert mocked_pycountry.mock_calls == expected_calls assert mocked_pycountry.call_count == 3 + country_iso_code.cache_clear() @mock.patch(PYCOUNTRY) def test_invalid_lookup_raises_keyerror(self, mocked_pycountry): - mocked_pycountry.side_effect = [KeyError(), KeyError(), KeyError()] + mocked_pycountry.side_effect = [KeyError(), KeyError(), KeyError()]*2 with pytest.raises(KeyError) as e: country_iso_code('Fake Country') assert 'Fake Country not found' in str(e.value) + country_iso_code.cache_clear() @mock.patch(PYCOUNTRY) def test_title_case_is_applied(self, mocked_pycountry): - expected_calls = [mock.call(name='United Kingdom'), - mock.call(name='United Kingdom'), - mock.call(name='United Kingdom')] - - country_iso_code('united kingdom') - country_iso_code('UNITED KINGDOM') - country_iso_code('United kingdom') + expected_calls = [] + names = ['united kingdom', 'UNITED KINGDOM', + 'United kingdom'] + mocked_pycountry.side_effect = [KeyError(), KeyError(), KeyError(), 'blah'] * len(names) + for name in names: + country_iso_code(name) # Find the iso codes + raw_call = mock.call(name=name) + common_call = mock.call(common_name=name) + official_call = mock.call(official_name=name) + title_call = mock.call(name='United Kingdom') + expected_calls.append(raw_call) # The initial call + expected_calls.append(common_call) # Tries common name call + expected_calls.append(official_call) # Tries official name + expected_calls.append(title_call) # The title case call assert mocked_pycountry.mock_calls == expected_calls 
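+        # country_iso_code is now lru_cached, so clear it to avoid
+        # leaking cached results into the other tests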
+ country_iso_code.cache_clear() class TestCountryIsoCodeDataframe(): @@ -460,3 +473,19 @@ def test_get_country_region_lookup(): assert all(len(v) == 2 for v in countries.values()) all_regions = {v[1] for v in countries.values()} assert len(all_regions) == 18 + + +def test_country_continent_lookup(): + lookup = get_country_continent_lookup() + non_nulls = {k: v for k, v in lookup.items() + if k is not None and k != ''} + # All iso2, so length == 2 + assert all(len(k) == 2 for k in non_nulls.items()) + assert all(len(v) == 2 for v in non_nulls.values()) + # Either strings or Nones + country_types = set(type(v) for v in lookup.values()) + assert country_types == {str, type(None)} + # Right ball-park of country and continent numbers + assert len(non_nulls) > 100 # num countries + assert len(non_nulls) < 1000 # num countries + assert len(set(non_nulls.values())) == 7 # num continents diff --git a/nesta/packages/nih/impute_base_id.py b/nesta/packages/nih/impute_base_id.py index dd46de8d..05ab4bce 100644 --- a/nesta/packages/nih/impute_base_id.py +++ b/nesta/packages/nih/impute_base_id.py @@ -88,7 +88,6 @@ def retrieve_id_ranges(database, chunksize=1000): def impute_base_id_thread(from_id, to_id, database): """Apply "impute_base_id" over this chunk of IDs""" - #from_id, to_id, database = args[0] # Unpack thread args engine = get_mysql_engine("MYSQLDB", "mysqldb", database) with db_session(engine) as session: impute_base_id(session, from_id, to_id) diff --git a/nesta/packages/nih/tests/test_impute_base_id.py b/nesta/packages/nih/tests/test_impute_base_id.py index b0513a91..145ce06d 100644 --- a/nesta/packages/nih/tests/test_impute_base_id.py +++ b/nesta/packages/nih/tests/test_impute_base_id.py @@ -32,12 +32,12 @@ def test_impute_base_id(): q.all.return_value = projects # Check that the base_project_num has not been imputed yet - assert all(p.base_core_project_num is None + assert all(p.base_core_project_num is None for p in projects) # Impute the ids impute_base_id(session, from_id=None, to_id=None) - + # Check that the base_project_num has been imputed imputed_values = [p.base_core_project_num for p in projects] expect_values = ["helloworld", "foobar", # <-- Regex passes @@ -51,13 +51,13 @@ def test_impute_base_id(): def test_retrieve_id_ranges(mocked_session_context, mocked_engine): session = mocked_session_context().__enter__() q = session.query().order_by() - q.all.return_value = [(0,), (1,), ("1",), (2,), (3,), + q.all.return_value = [(0,), (1,), ("1",), (2,), (3,), (5,), (8,), (13,), (21,)] id_ranges = retrieve_id_ranges(database="db_name", chunksize=3) assert id_ranges == [(0, 2), # 0 <= x <= 2 (2, 8), # 2 <= x <= 8 (8, 21)] # 8 <= x <= 21 - + @mock.patch(PATH.format("get_mysql_engine")) @mock.patch(PATH.format("db_session"))