diff --git a/nesta/core/orms/arxiv_orm.py b/nesta/core/orms/arxiv_orm.py index 43c54b5e..7d36a954 100644 --- a/nesta/core/orms/arxiv_orm.py +++ b/nesta/core/orms/arxiv_orm.py @@ -1,9 +1,9 @@ -''' +""" Arxiv ===== -''' +""" from sqlalchemy import Table, Column, ForeignKey -from sqlalchemy.dialects.mysql import VARCHAR, TEXT +from sqlalchemy.dialects.mysql import VARCHAR, TEXT, MEDIUMTEXT from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship from sqlalchemy.types import JSON, DATE, INTEGER, BIGINT, FLOAT, BOOLEAN @@ -20,34 +20,35 @@ """Association table for Arxiv articles and their categories.""" -article_categories = Table('arxiv_article_categories', Base.metadata, - Column('article_id', - VARCHAR(40), - ForeignKey('arxiv_articles.id'), - primary_key=True), - Column('category_id', - VARCHAR(40), - ForeignKey('arxiv_categories.id'), - primary_key=True)) +article_categories = Table( + "arxiv_article_categories", + Base.metadata, + Column( + "article_id", VARCHAR(40), ForeignKey("arxiv_articles.id"), primary_key=True + ), + Column( + "category_id", VARCHAR(40), ForeignKey("arxiv_categories.id"), primary_key=True + ), +) """Association table to Microsoft Academic Graph fields of study.""" -article_fields_of_study = Table('arxiv_article_fields_of_study', Base.metadata, - Column('article_id', - VARCHAR(40), - ForeignKey('arxiv_articles.id'), - primary_key=True), - Column('fos_id', - BIGINT, - ForeignKey(FieldOfStudy.id), - primary_key=True)) +article_fields_of_study = Table( + "arxiv_article_fields_of_study", + Base.metadata, + Column( + "article_id", VARCHAR(40), ForeignKey("arxiv_articles.id"), primary_key=True + ), + Column("fos_id", BIGINT, ForeignKey(FieldOfStudy.id), primary_key=True), +) class ArticleInstitute(Base): """Association table to GRID institutes.""" - __tablename__ = 'arxiv_article_institutes' - article_id = Column(VARCHAR(40), ForeignKey('arxiv_articles.id'), primary_key=True) + __tablename__ = "arxiv_article_institutes" + + article_id = Column(VARCHAR(40), ForeignKey("arxiv_articles.id"), primary_key=True) institute_id = Column(VARCHAR(20), ForeignKey(Institute.id), primary_key=True) is_multinational = Column(BOOLEAN) matching_score = Column(FLOAT) @@ -56,16 +57,17 @@ class ArticleInstitute(Base): class Article(Base): """Arxiv articles and metadata.""" - __tablename__ = 'arxiv_articles' + + __tablename__ = "arxiv_articles" id = Column(VARCHAR(40), primary_key=True, autoincrement=False) datestamp = Column(DATE) created = Column(DATE) updated = Column(DATE) - title = Column(TEXT) - journal_ref = Column(TEXT) + title = Column(TEXT(collation="utf8mb4_unicode_ci")) + journal_ref = Column(TEXT(collation="utf8mb4_unicode_ci")) doi = Column(VARCHAR(200)) - abstract = Column(TEXT) + abstract = Column(MEDIUMTEXT(collation="utf8mb4_unicode_ci")) authors = Column(JSON) mag_authors = Column(JSON) mag_id = Column(BIGINT) @@ -74,19 +76,17 @@ class Article(Base): citation_count_updated = Column(DATE) msc_class = Column(VARCHAR(200)) institute_match_attempted = Column(BOOLEAN, default=False) - categories = relationship('Category', - secondary=article_categories) - fields_of_study = relationship(FieldOfStudy, - secondary=article_fields_of_study) - institutes = relationship('ArticleInstitute') - corex_topics = relationship('CorExTopic', - secondary='arxiv_article_corex_topics') + categories = relationship("Category", secondary=article_categories) + fields_of_study = relationship(FieldOfStudy, secondary=article_fields_of_study) + institutes = 
relationship("ArticleInstitute") + corex_topics = relationship("CorExTopic", secondary="arxiv_article_corex_topics") article_source = Column(VARCHAR(7), index=True, default=None) class Category(Base): """Lookup table for Arxiv category descriptions.""" - __tablename__ = 'arxiv_categories' + + __tablename__ = "arxiv_categories" id = Column(VARCHAR(40), primary_key=True) description = Column(VARCHAR(100)) @@ -107,39 +107,37 @@ class Category(Base): # id = Column(VARCHAR(40), ForeignKey('arxiv_article_msc.msc_id'), primary_key=True) # description = Column(VARCHAR(100)) + class CorExTopic(Base): """CorEx topics derived from arXiv data""" - __tablename__ = 'arxiv_corex_topics' + + __tablename__ = "arxiv_corex_topics" id = Column(INTEGER, primary_key=True, autoincrement=False) terms = Column(JSON) class ArticleTopic(Base): """Association table to CorEx topics.""" - __tablename__ = 'arxiv_article_corex_topics' - article_id = Column(VARCHAR(40), - ForeignKey(Article.id), - primary_key=True) - topic_id = Column(INTEGER, - ForeignKey(CorExTopic.id), - primary_key=True, - autoincrement=False) + + __tablename__ = "arxiv_article_corex_topics" + article_id = Column(VARCHAR(40), ForeignKey(Article.id), primary_key=True) + topic_id = Column( + INTEGER, ForeignKey(CorExTopic.id), primary_key=True, autoincrement=False + ) topic_weight = Column(FLOAT) class ArticleVector(Base): """Document vectors for articles.""" - __tablename__ = 'arxiv_vector' - article_id = Column(VARCHAR(40), - ForeignKey(Article.id), - primary_key=True) + + __tablename__ = "arxiv_vector" + article_id = Column(VARCHAR(40), ForeignKey(Article.id), primary_key=True) vector = Column(JSON) class ArticleCluster(Base): """Document clusters for articles.""" - __tablename__ = 'arxiv_cluster' - article_id = Column(VARCHAR(40), - ForeignKey(Article.id), - primary_key=True) + + __tablename__ = "arxiv_cluster" + article_id = Column(VARCHAR(40), ForeignKey(Article.id), primary_key=True) clusters = Column(JSON) diff --git a/nesta/core/routines/arxiv/arxiv_grid_task.py b/nesta/core/routines/arxiv/arxiv_grid_task.py index be7cf7e8..b5d5e0d7 100644 --- a/nesta/core/routines/arxiv/arxiv_grid_task.py +++ b/nesta/core/routines/arxiv/arxiv_grid_task.py @@ -14,7 +14,11 @@ from datetime import datetime from nesta.core.routines.arxiv.arxiv_mag_sparql_task import MagSparqlTask -from nesta.packages.arxiv.collect_arxiv import add_article_institutes, create_article_institute_links, update_existing_articles +from nesta.packages.arxiv.collect_arxiv import ( + add_article_institutes, + create_article_institute_links, + update_existing_articles, +) from nesta.packages.grid.grid import ComboFuzzer, grid_name_lookup from nesta.packages.misc_utils.batches import BatchWriter from nesta.core.orms.arxiv_orm import Base, Article @@ -39,6 +43,7 @@ class GridTask(luigi.Task): retrieved. 
Must be in YYYY-MM-DD format (not used in this task but passed down to others) """ + date = luigi.DateParameter() _routine_id = luigi.Parameter() test = luigi.BoolParameter(default=True) @@ -47,41 +52,46 @@ class GridTask(luigi.Task): mag_config_path = luigi.Parameter() insert_batch_size = luigi.IntParameter(default=500) articles_from_date = luigi.Parameter() + article_source = luigi.Parameter(default=None) def output(self): - '''Points to the output database engine''' + """Points to the output database engine""" db_config = misctools.get_config(self.db_config_path, "mysqldb") - db_config["database"] = 'dev' if self.test else 'production' + db_config["database"] = "dev" if self.test else "production" db_config["table"] = "arXlive " # Note, not a real table update_id = "ArxivGrid_{}".format(self.date) return MySqlTarget(update_id=update_id, **db_config) def requires(self): - yield MagSparqlTask(date=self.date, - _routine_id=self._routine_id, - db_config_path=self.db_config_path, - db_config_env=self.db_config_env, - mag_config_path=self.mag_config_path, - test=self.test, - articles_from_date=self.articles_from_date, - insert_batch_size=self.insert_batch_size) + yield MagSparqlTask( + date=self.date, + _routine_id=self._routine_id, + db_config_path=self.db_config_path, + db_config_env=self.db_config_env, + mag_config_path=self.mag_config_path, + test=self.test, + articles_from_date=self.articles_from_date, + insert_batch_size=self.insert_batch_size, + article_source=self.article_source, + ) def run(self): # database setup - database = 'dev' if self.test else 'production' + database = "dev" if self.test else "production" logging.info(f"Using {database} database") - self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) + self.engine = get_mysql_engine(self.db_config_env, "mysqldb", database) Base.metadata.create_all(self.engine) - article_institute_batcher = BatchWriter(self.insert_batch_size, - add_article_institutes, - self.engine) - match_attempted_batcher = BatchWriter(self.insert_batch_size, - update_existing_articles, - self.engine) + article_institute_batcher = BatchWriter( + self.insert_batch_size, add_article_institutes, self.engine + ) + match_attempted_batcher = BatchWriter( + self.insert_batch_size, update_existing_articles, self.engine + ) - fuzzer = ComboFuzzer([fuzz.token_sort_ratio, fuzz.partial_ratio], - store_history=True) + fuzzer = ComboFuzzer( + [fuzz.token_sort_ratio, fuzz.partial_ratio], store_history=True + ) # extract lookup of GRID institute names to ids - seems to be OK to hold in memory institute_name_id_lookup = grid_name_lookup(self.engine) @@ -91,13 +101,15 @@ def run(self): all_grid_ids = {i.id for i in session.query(Institute.id).all()} logging.info(f"{len(all_grid_ids)} institutes in GRID") - article_query = (session - .query(Article.id, Article.mag_authors) - .filter(Article.institute_match_attempted.is_(False) - & ~Article.institutes.any() - & Article.mag_authors.isnot(None))) + article_query = session.query(Article.id, Article.mag_authors).filter( + Article.institute_match_attempted.is_(False) + & ~Article.institutes.any() + & Article.mag_authors.isnot(None) + ) total = article_query.count() - logging.info(f"Total articles with authors and no institutes links: {total}") + logging.info( + f"Total articles with authors and no institutes links: {total}" + ) logging.debug("Starting the matching process") articles = article_query.all() @@ -106,28 +118,33 @@ def run(self): article_institute_links = [] for author in article.mag_authors: # prevent 
duplicates when a mixture of institute aliases are used in the same article - existing_article_institute_ids = {link['institute_id'] - for link in article_institute_links} + existing_article_institute_ids = { + link["institute_id"] for link in article_institute_links + } # extract and validate grid_id try: - extracted_grid_id = author['affiliation_grid_id'] + extracted_grid_id = author["affiliation_grid_id"] except KeyError: pass else: # check grid id is valid - if (extracted_grid_id in all_grid_ids - and extracted_grid_id not in existing_article_institute_ids): - links = create_article_institute_links(article_id=article.id, - institute_ids=[extracted_grid_id], - score=1) + if ( + extracted_grid_id in all_grid_ids + and extracted_grid_id not in existing_article_institute_ids + ): + links = create_article_institute_links( + article_id=article.id, + institute_ids=[extracted_grid_id], + score=1, + ) article_institute_links.extend(links) logging.debug(f"Used grid_id: {extracted_grid_id}") continue # extract author affiliation try: - affiliation = author['author_affiliation'] + affiliation = author["author_affiliation"] except KeyError: # no grid id or affiliation for this author logging.debug(f"No affiliation found in: {author}") @@ -140,37 +157,43 @@ def run(self): pass else: institute_ids = set(institute_ids) - existing_article_institute_ids - links = create_article_institute_links(article_id=article.id, - institute_ids=institute_ids, - score=1) + links = create_article_institute_links( + article_id=article.id, institute_ids=institute_ids, score=1 + ) article_institute_links.extend(links) logging.debug(f"Found an exact match for: {affiliation}") continue # fuzzy matching try: - match, score = fuzzer.fuzzy_match_one(affiliation, - institute_name_id_lookup.keys()) + match, score = fuzzer.fuzzy_match_one( + affiliation, institute_name_id_lookup.keys() + ) except KeyError: # failed fuzzy match logging.debug(f"Failed fuzzy match: {affiliation}") else: institute_ids = institute_name_id_lookup[match] institute_ids = set(institute_ids) - existing_article_institute_ids - links = create_article_institute_links(article_id=article.id, - institute_ids=institute_ids, - score=score) + links = create_article_institute_links( + article_id=article.id, institute_ids=institute_ids, score=score + ) article_institute_links.extend(links) - logging.debug(f"Found a fuzzy match: {affiliation} {score} {match}") + logging.debug( + f"Found a fuzzy match: {affiliation} {score} {match}" + ) # add links for this article to the batch queue article_institute_batcher.extend(article_institute_links) # mark that matching has been attempted for this article - match_attempted_batcher.append(dict(id=article.id, - institute_match_attempted=True)) + match_attempted_batcher.append( + dict(id=article.id, institute_match_attempted=True) + ) if not count % 100: - logging.info(f"{count} processed articles from {total} : {(count / total) * 100:.1f}%") + logging.info( + f"{count} processed articles from {total} : {(count / total) * 100:.1f}%" + ) if self.test and count == 5000: logging.warning("Exiting after 5000 articles in test mode") @@ -184,8 +207,12 @@ def run(self): match_attempted_batcher.write() logging.info("All articles processed") - logging.info(f"Total successful fuzzy matches for institute names: {len(fuzzer.successful_fuzzy_matches)}") - logging.info(f"Total failed fuzzy matches for institute names{len(fuzzer.failed_fuzzy_matches): }") + logging.info( + f"Total successful fuzzy matches for institute names: 
{len(fuzzer.successful_fuzzy_matches)}" ) + logging.info( + f"Total failed fuzzy matches for institute names: {len(fuzzer.failed_fuzzy_matches)}" + ) # mark as done logging.info("Task complete") @@ -200,20 +227,22 @@ class GridRootTask(luigi.WrapperTask): articles_from_date = luigi.Parameter(default=None) insert_batch_size = luigi.IntParameter(default=500) debug = luigi.BoolParameter(default=False) + article_source = luigi.Parameter(default=None) def requires(self): - '''Collects the database configurations - and executes the central task.''' + """Collects the database configurations + and executes the central task.""" logging.getLogger().setLevel(logging.INFO) _routine_id = "{}-{}".format(self.date, self.production) grid_task_kwargs = { - '_routine_id':_routine_id, - 'db_config_path':self.db_config_path, - 'db_config_env':'MYSQLDB', - 'mag_config_path':'mag.config', - 'test':not self.production, - 'insert_batch_size':self.insert_batch_size, - 'articles_from_date':self.articles_from_date, - 'date':self.date, + "_routine_id": _routine_id, + "db_config_path": self.db_config_path, + "db_config_env": "MYSQLDB", + "mag_config_path": "mag.config", + "test": not self.production, + "insert_batch_size": self.insert_batch_size, + "articles_from_date": self.articles_from_date, + "date": self.date, + "article_source": self.article_source, } yield GridTask(**grid_task_kwargs) diff --git a/nesta/core/routines/arxiv/arxiv_mag_sparql_task.py b/nesta/core/routines/arxiv/arxiv_mag_sparql_task.py index 3e2e03f8..db713f97 100644 --- a/nesta/core/routines/arxiv/arxiv_mag_sparql_task.py +++ b/nesta/core/routines/arxiv/arxiv_mag_sparql_task.py @@ -12,7 +12,12 @@ from nesta.core.routines.arxiv.arxiv_mag_task import QueryMagTask from nesta.packages.arxiv.collect_arxiv import update_existing_articles -from nesta.packages.mag.query_mag_sparql import update_field_of_study_ids_sparql, extract_entity_id, query_articles_by_doi, query_authors +from nesta.packages.mag.query_mag_sparql import ( + update_field_of_study_ids_sparql, + extract_entity_id, + query_articles_by_doi, + query_authors, +) from nesta.packages.misc_utils.batches import BatchWriter from nesta.core.orms.arxiv_orm import Base, Article from nesta.core.orms.mag_orm import FieldOfStudy @@ -37,6 +42,7 @@ class MagSparqlTask(luigi.Task): retrieved.
Must be in YYYY-MM-DD format (not used in this task but passed down to others) """ + date = luigi.DateParameter() _routine_id = luigi.Parameter() test = luigi.BoolParameter(default=True) @@ -45,54 +51,68 @@ class MagSparqlTask(luigi.Task): mag_config_path = luigi.Parameter() insert_batch_size = luigi.IntParameter(default=500) articles_from_date = luigi.Parameter() + article_source = luigi.Parameter(default=None) def output(self): - '''Points to the output database engine''' + """Points to the output database engine""" db_config = misctools.get_config(self.db_config_path, "mysqldb") - db_config["database"] = 'dev' if self.test else 'production' + db_config["database"] = "dev" if self.test else "production" db_config["table"] = "arXlive " # Note, not a real table update_id = "ArxivMagSparql_{}".format(self.date) return MySqlTarget(update_id=update_id, **db_config) def requires(self): - yield QueryMagTask(date=self.date, - _routine_id=self._routine_id, - db_config_path=self.db_config_path, - db_config_env=self.db_config_env, - mag_config_path=self.mag_config_path, - test=self.test, - articles_from_date=self.articles_from_date, - insert_batch_size=self.insert_batch_size) + yield QueryMagTask( + date=self.date, + _routine_id=self._routine_id, + db_config_path=self.db_config_path, + db_config_env=self.db_config_env, + mag_config_path=self.mag_config_path, + test=self.test, + articles_from_date=self.articles_from_date, + insert_batch_size=self.insert_batch_size, + article_source=self.article_source, + ) def run(self): # database setup - database = 'dev' if self.test else 'production' + database = "dev" if self.test else "production" logging.warning(f"Using {database} database") - self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) + self.engine = get_mysql_engine(self.db_config_env, "mysqldb", database) Base.metadata.create_all(self.engine) with db_session(self.engine) as session: - field_mapping = {'paper': 'mag_id', - 'paperTitle': 'title', - 'fieldsOfStudy': 'fields_of_study', - 'citationCount': 'citation_count'} - - logging.info("Querying database for articles without fields of study and with doi") - articles_to_process = [dict(id=a.id, doi=a.doi, title=a.title) for a in - (session - .query(Article) - .filter((Article.mag_authors.is_(None) | ~Article.fields_of_study.any()) - & Article.doi.isnot(None)) - .all())] + field_mapping = { + "paper": "mag_id", + "paperTitle": "title", + "fieldsOfStudy": "fields_of_study", + "citationCount": "citation_count", + } + + logging.info( + "Querying database for articles without fields of study and with doi" + ) + articles_to_process = [ + dict(id=a.id, doi=a.doi, title=a.title) + for a in ( + session.query(Article) + .filter( + (Article.mag_authors.is_(None) | ~Article.fields_of_study.any()) + & Article.doi.isnot(None) + ) + .all() + ) + ] total_arxiv_ids_to_process = len(articles_to_process) logging.info(f"{total_arxiv_ids_to_process} articles to process") - all_articles_to_update = BatchWriter(self.insert_batch_size, - update_existing_articles, - self.engine) + all_articles_to_update = BatchWriter( + self.insert_batch_size, update_existing_articles, self.engine + ) - for count, row in enumerate(query_articles_by_doi(articles_to_process), - start=1): + for count, row in enumerate( + query_articles_by_doi(articles_to_process), start=1 + ): # renaming and reformatting for code, description in field_mapping.items(): try: @@ -100,66 +120,77 @@ def run(self): except KeyError: pass - if row.get('citation_count', None) is not None: - 
row['citation_count_updated'] = date.today() + if row.get("citation_count", None) is not None: + row["citation_count_updated"] = date.today() # reformat fos_ids out of entity urls try: - fos = row.pop('fields_of_study') - row['fields_of_study'] = {extract_entity_id(f) for f in fos.split(',')} + fos = row.pop("fields_of_study") + row["fields_of_study"] = { + extract_entity_id(f) for f in fos.split(",") + } except KeyError: # missing fields of study - row['fields_of_study'] = [] + row["fields_of_study"] = [] except (AttributeError, TypeError): # either of these could occur when the same doi is present in 2 # articles in the same batch logging.debug("Already processed") - row['fields_of_study'] = fos + row["fields_of_study"] = fos # reformat mag_id out of entity url try: - row['mag_id'] = extract_entity_id(row['mag_id']) + row["mag_id"] = extract_entity_id(row["mag_id"]) except TypeError: # id has already been extracted pass # query for author and affiliation details try: - author_ids = {extract_entity_id(a) for a in row.pop('authors').split(',')} - row['mag_authors'] = list(query_authors(author_ids)) + author_ids = { + extract_entity_id(a) for a in row.pop("authors").split(",") + } + row["mag_authors"] = list(query_authors(author_ids)) except KeyError: pass # drop unnecessary fields - for f in ['score', 'title']: + for f in ["score", "title"]: try: del row[f] except KeyError: pass # check fields of study exist in the database - logging.debug('Checking fields of study exist in db') - found_fos_ids = {fos.id for fos in (session - .query(FieldOfStudy) - .filter(FieldOfStudy.id.in_(row['fields_of_study'])) - .all())} - - missing_fos_ids = row['fields_of_study'] - found_fos_ids + logging.debug("Checking fields of study exist in db") + found_fos_ids = { + fos.id + for fos in ( + session.query(FieldOfStudy) + .filter(FieldOfStudy.id.in_(row["fields_of_study"])) + .all() + ) + } + + missing_fos_ids = row["fields_of_study"] - found_fos_ids if missing_fos_ids: logging.info(f"Missing field of study ids: {missing_fos_ids}") - fos_not_found = update_field_of_study_ids_sparql(self.engine, - missing_fos_ids) + fos_not_found = update_field_of_study_ids_sparql( + self.engine, missing_fos_ids + ) # any fos not found in mag are removed to prevent foreign key # constraint errors when building the link table for fos in fos_not_found: - row['fields_of_study'].remove(fos) + row["fields_of_study"].remove(fos) # add this row to the queue logging.debug(row) all_articles_to_update.append(row) if not count % 1000: - logging.info(f"{count} done. {total_arxiv_ids_to_process - count} articles left to process") + logging.info( + f"{count} done. 
{total_arxiv_ids_to_process - count} articles left to process" + ) if self.test and count == 3000: logging.warning("Exiting after 3000 rows in test mode") break diff --git a/nesta/core/routines/arxiv/arxiv_mag_task.py b/nesta/core/routines/arxiv/arxiv_mag_task.py index d528c0fa..2b7fecce 100644 --- a/nesta/core/routines/arxiv/arxiv_mag_task.py +++ b/nesta/core/routines/arxiv/arxiv_mag_task.py @@ -12,11 +12,17 @@ import logging import pprint +from nesta.core.routines.arxiv.cord_collect_task import CollectCordTask from nesta.core.routines.arxiv.arxiv_iterative_date_task import DateTask from nesta.core.routines.arxiv.magrxiv_collect_iterative_task import CollectMagrxivTask from nesta.packages.arxiv.collect_arxiv import BatchedTitles, update_existing_articles from nesta.packages.misc_utils.batches import BatchWriter -from nesta.packages.mag.query_mag_api import build_expr, query_mag_api, dedupe_entities, update_field_of_study_ids +from nesta.packages.mag.query_mag_api import ( + build_expr, + query_mag_api, + dedupe_entities, + update_field_of_study_ids, +) from nesta.core.orms.arxiv_orm import Base, Article from nesta.core.orms.mag_orm import FieldOfStudy from nesta.core.orms.orm_utils import get_mysql_engine, db_session @@ -40,6 +46,7 @@ class QueryMagTask(luigi.Task): retrieved. Must be in YYYY-MM-DD format (not used in this task but passed down to others) """ + date = luigi.DateParameter() _routine_id = luigi.Parameter() test = luigi.BoolParameter(default=True) @@ -48,89 +55,117 @@ class QueryMagTask(luigi.Task): mag_config_path = luigi.Parameter() insert_batch_size = luigi.IntParameter(default=500) articles_from_date = luigi.Parameter() + article_source = luigi.Parameter(default=None) def output(self): - '''Points to the output database engine''' + """Points to the output database engine""" db_config = misctools.get_config(self.db_config_path, "mysqldb") - db_config["database"] = 'dev' if self.test else 'production' + db_config["database"] = "dev" if self.test else "production" db_config["table"] = "arXlive " # Note, not a real table update_id = "ArxivQueryMag_{}".format(self.date) return MySqlTarget(update_id=update_id, **db_config) def requires(self): - yield DateTask(date=self.date, - _routine_id=self._routine_id, - db_config_path=self.db_config_path, - db_config_env=self.db_config_env, - test=self.test, - articles_from_date=self.articles_from_date, - insert_batch_size=self.insert_batch_size) - # Start collection from Jan 2010 unless in test mode - articles_from_date = '1 January 2010' - if self.test: # 11 days ago for test - articles_from_date = dt.strftime(dt.now() - timedelta(days=11), '%d %B %Y') - yield CollectMagrxivTask(date=self.date, - routine_id=self._routine_id, - db_config_path=self.db_config_path, - db_config_env=self.db_config_env, - test=self.test, - articles_from_date=articles_from_date, - insert_batch_size=self.insert_batch_size) - + kwargs = dict( + date=self.date, + db_config_path=self.db_config_path, + db_config_env=self.db_config_env, + test=self.test, + ) + article_source = self.article_source + if (article_source is None) or (article_source == "arxiv"): + yield DateTask( + _routine_id=self._routine_id, + articles_from_date=self.articles_from_date, + insert_batch_size=self.insert_batch_size, + **kwargs, + ) + if (article_source is None) or (article_source == "cord"): + yield CollectCordTask(routine_id=self._routine_id, **kwargs) + if article_source is None: + # Start collection from Jan 2010 unless in test mode + articles_from_date = "1 January 2010" + if 
self.test: # 11 days ago for test + articles_from_date = dt.strftime( + dt.now() - timedelta(days=11), "%d %B %Y" + ) + yield CollectMagrxivTask( + routine_id=self._routine_id, + articles_from_date=articles_from_date, + insert_batch_size=self.insert_batch_size, + **kwargs, + ) def run(self): pp = pprint.PrettyPrinter(indent=4, width=100) - mag_config = misctools.get_config(self.mag_config_path, 'mag') - mag_subscription_key = mag_config['subscription_key'] + mag_config = misctools.get_config(self.mag_config_path, "mag") + mag_subscription_key = mag_config["subscription_key"] # database setup - database = 'dev' if self.test else 'production' + database = "dev" if self.test else "production" logging.warning(f"Using {database} database") - self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database) + self.engine = get_mysql_engine(self.db_config_env, "mysqldb", database) Base.metadata.create_all(self.engine) with db_session(self.engine) as session: - paper_fields = ["Id", "Ti", "F.FId", "CC", - "AA.AuN", "AA.AuId", "AA.AfN", "AA.AfId", "AA.S"] - - author_mapping = {'AuN': 'author_name', - 'AuId': 'author_id', - 'AfN': 'author_affiliation', - 'AfId': 'author_affiliation_id', - 'S': 'author_order'} - - field_mapping = {'Id': 'mag_id', - 'Ti': 'title', - 'F': 'fields_of_study', - 'AA': 'mag_authors', - 'CC': 'citation_count', - 'logprob': 'mag_match_prob'} + paper_fields = [ + "Id", + "Ti", + "F.FId", + "CC", + "AA.AuN", + "AA.AuId", + "AA.AfN", + "AA.AfId", + "AA.S", + ] + + author_mapping = { + "AuN": "author_name", + "AuId": "author_id", + "AfN": "author_affiliation", + "AfId": "author_affiliation_id", + "S": "author_order", + } + + field_mapping = { + "Id": "mag_id", + "Ti": "title", + "F": "fields_of_study", + "AA": "mag_authors", + "CC": "citation_count", + "logprob": "mag_match_prob", + } logging.info("Querying database for articles without fields of study") - arxiv_ids_to_process = {a.id for a in (session. 
- query(Article) - .filter(~Article.fields_of_study.any()) - .all())} + arxiv_ids_to_process = { + a.id + for a in ( + session.query(Article).filter(~Article.fields_of_study.any()).all() + ) + } total_arxiv_ids_to_process = len(arxiv_ids_to_process) logging.info(f"{total_arxiv_ids_to_process} articles to process") - all_articles_to_update = BatchWriter(self.insert_batch_size, - update_existing_articles, - self.engine) + all_articles_to_update = BatchWriter( + self.insert_batch_size, update_existing_articles, self.engine + ) batched_titles = BatchedTitles(arxiv_ids_to_process, 10000, session) batch_field_of_study_ids = set() - for count, expr in enumerate(build_expr(batched_titles, 'Ti'), 1): + for count, expr in enumerate(build_expr(batched_titles, "Ti"), 1): logging.debug(pp.pformat(expr)) - expr_length = len(expr.split(',')) + expr_length = len(expr.split(",")) logging.info(f"Querying MAG for {expr_length} titles") total_arxiv_ids_to_process -= expr_length batch_data = query_mag_api(expr, paper_fields, mag_subscription_key) logging.debug(pp.pformat(batch_data)) - returned_entities = batch_data['entities'] - logging.info(f"{len(returned_entities)} entities returned from MAG (potentially including duplicates)") + returned_entities = batch_data["entities"] + logging.info( + f"{len(returned_entities)} entities returned from MAG (potentially including duplicates)" + ) # dedupe response keeping the entity with the highest logprob deduped_mag_ids = dedupe_entities(returned_entities) @@ -144,7 +179,7 @@ def run(self): for row in returned_entities: # exclude duplicate titles - if row['Id'] not in deduped_mag_ids: + if row["Id"] not in deduped_mag_ids: continue # renaming and reformatting @@ -154,56 +189,71 @@ def run(self): except KeyError: pass - for author in row.get('mag_authors', []): + for author in row.get("mag_authors", []): for code, description in author_mapping.items(): try: author[description] = author.pop(code) except KeyError: pass - if row.get('citation_count', None) is not None: - row['citation_count_updated'] = date.today() + if row.get("citation_count", None) is not None: + row["citation_count_updated"] = date.today() # reformat fos_ids out of dictionaries try: - row['fields_of_study'] = {f['FId'] for f in row.pop('fields_of_study')} + row["fields_of_study"] = { + f["FId"] for f in row.pop("fields_of_study") + } except KeyError: - row['fields_of_study'] = [] - batch_field_of_study_ids.update(row['fields_of_study']) + row["fields_of_study"] = [] + batch_field_of_study_ids.update(row["fields_of_study"]) # get list of ids which share the same title try: - matching_articles = batched_titles[row['title']] + matching_articles = batched_titles[row["title"]] except KeyError: - logging.warning(f"Returned title not found in original data: {row['title']}") + logging.warning( + f"Returned title not found in original data: {row['title']}" + ) continue # drop unnecessary fields - for f in ['prob', 'title']: + for f in ["prob", "title"]: del row[f] # add each matching article for this title to the batch for article_id in matching_articles: - batch_article_data.append({**row, 'id': article_id}) + batch_article_data.append({**row, "id": article_id}) # check fields of study are in database - batch_field_of_study_ids = {fos_id for article in batch_article_data - for fos_id in article['fields_of_study']} - logging.debug('Checking fields of study exist in db') - found_fos_ids = {fos.id for fos in (session - .query(FieldOfStudy) - .filter(FieldOfStudy.id.in_(batch_field_of_study_ids)) - .all())} + 
batch_field_of_study_ids = { + fos_id + for article in batch_article_data + for fos_id in article["fields_of_study"] + } + logging.debug("Checking fields of study exist in db") + found_fos_ids = { + fos.id + for fos in ( + session.query(FieldOfStudy) + .filter(FieldOfStudy.id.in_(batch_field_of_study_ids)) + .all() + ) + } missing_fos_ids = batch_field_of_study_ids - found_fos_ids if missing_fos_ids: # query mag for details if not found - update_field_of_study_ids(mag_subscription_key, session, missing_fos_ids) + update_field_of_study_ids( + mag_subscription_key, session, missing_fos_ids + ) # add this batch to the queue all_articles_to_update.extend(batch_article_data) - logging.info(f"Batch {count} done. {total_arxiv_ids_to_process} articles left to process") + logging.info( + f"Batch {count} done. {total_arxiv_ids_to_process} articles left to process" + ) if self.test and count == 10: logging.warning("Exiting after 10 batches in test mode") break diff --git a/nesta/core/routines/arxiv/cord_collect_task.py b/nesta/core/routines/arxiv/cord_collect_task.py new file mode 100644 index 00000000..90e74640 --- /dev/null +++ b/nesta/core/routines/arxiv/cord_collect_task.py @@ -0,0 +1,91 @@ +""" +Cord Collect +============ + +Luigi routine to collect the latest data from CORD +""" +from datetime import datetime as dt +from datetime import timedelta +import luigi +import logging +import pandas as pd + +from nesta.packages.cord.cord import cord_data, to_arxiv_format +from nesta.core.orms.arxiv_orm import Article, Base +from nesta.core.orms.orm_utils import get_mysql_engine, db_session, insert_data +from nesta.core.luigihacks import misctools +from nesta.core.luigihacks.mysqldb import MySqlTarget + + +class CollectCordTask(luigi.Task): + """Collect CORD articles + + Args: + date (datetime): Datetime used to label the outputs + routine_id (str): String used to label this routine + test (bool): If True, write to the dev database and use the fixed 2020-05-31 CORD release + dont_recollect (bool): Not currently used by this task + db_config_env (str): environmental variable pointing to the db config file + db_config_path (str): The output database configuration + insert_batch_size (int): Number of records to insert into the database at once + """ + + date = luigi.DateParameter() + dont_recollect = luigi.BoolParameter(default=False) + routine_id = luigi.Parameter() + test = luigi.BoolParameter(default=True) + db_config_env = luigi.Parameter() + db_config_path = luigi.Parameter() + insert_batch_size = luigi.IntParameter(default=500) + + def output(self): + """Points to the output database engine""" + db_config = misctools.get_config(self.db_config_path, "mysqldb") + db_config["database"] = "dev" if self.test else "production" + db_config["table"] = "arXlive " # Note, not a real table + update_id = "CordCollect_{}".format(self.date) + return MySqlTarget(update_id=update_id, **db_config) + + def run(self): + date = '2020-05-31' if self.test else None + articles = list(map(to_arxiv_format, cord_data(date=date))) + db = "dev" if self.test else "production" + engine = get_mysql_engine(self.db_config_env, "mysqldb", db) + with db_session(engine) as session: + insert_data( + self.db_config_env, + "mysqldb", + db, + Base, + Article, + articles, + low_memory=True, + ) + # Mark as done + self.output().touch() + + +class CordRootTask(luigi.WrapperTask): + """A dummy root task, which collects the database configurations + and executes the central task.
+ + Args: + date (datetime): Date used to label the outputs + db_config_path (str): Path to the MySQL database configuration + production (bool): Flag indicating whether running in testing + mode (False, default), or production mode (True). + + """ + + date = luigi.DateParameter(default=dt.today()) + db_config_path = luigi.Parameter(default="mysqldb.config") + production = luigi.BoolParameter(default=False) + + def requires(self): + routine_id = "{}-{}".format(self.date, self.production) + logging.getLogger().setLevel(logging.INFO) + yield CollectCordTask( + date=self.date, + routine_id=routine_id, + test=not self.production, + db_config_env="MYSQLDB", + db_config_path=self.db_config_path, + ) diff --git a/nesta/packages/cord/cord.py b/nesta/packages/cord/cord.py new file mode 100644 index 00000000..3e6018aa --- /dev/null +++ b/nesta/packages/cord/cord.py @@ -0,0 +1,119 @@ +from tempfile import TemporaryFile +import os.path +import requests +import shutil +import tarfile +import csv +import re +from io import StringIO +import unicodedata + + +LATEST_RE = re.compile(r"(\d{4})-(\d{2})-(\d{2})") +CSV_RE = "{date}/(.*)metadata(.*).csv" # Name changes over time +BASE_URL = ( + "https://ai2-semanticscholar-cord-19." + "s3-us-west-2.amazonaws.com/historical_releases{}" +) +HTML_URL = BASE_URL.format(".html") +DATA_URL = BASE_URL.format("/cord-19_{date}.tar.gz") +AWS_TMP_DIR = "/dev/shm/" +CORD_TO_ARXIV_LOOKUP = { + "datestamp": "publish_time", + "created": "publish_time", + "updated": "publish_time", + "title": "title", + "journal_ref": "journal", + "doi": "doi", + "abstract": "abstract", + "authors": "authors" +} + + +def stream_to_file(url, fileobj): + """ + Stream a large file from a URL to a file object in a memory-efficient + rate-efficient manner. + """ + with requests.get(url, stream=True) as r: + r.raise_for_status() + shutil.copyfileobj(r.raw, fileobj) + fileobj.seek(0) # Reset the file pointer, ready for reading + + +def cord_csv(date=None): + """Load the CORD19 metadata CSV file for the given date string""" + # Prepare variables + if date is None: + date = most_recent_date() + + url = DATA_URL.format(date=date) + csv_re = re.compile(CSV_RE.format(date=date)) + tmp_dir = AWS_TMP_DIR if os.path.isdir(AWS_TMP_DIR) else None + # Stream the huge tarball into the local tempfile + with TemporaryFile(dir=tmp_dir, suffix=".tar.gz") as fileobj: + stream_to_file(url, fileobj) + # Filter out the CSV file based on regex + tf = tarfile.open(fileobj=fileobj) + (match,) = filter(None, map(csv_re.match, tf.getnames())) + # Load the CSV file data into memory + csv = tf.extractfile(match.group()) + return StringIO(csv.read().decode("utf-8")) # Return wipes the cache + + +def cord_data(date=None): + """ + Yields lines from the CORD19 metadata CSV file + for the given date string. 
Note: takes ~20 mins over a wired internet connection, as the tarball is ~10GB """ with cord_csv(date) as f: for line in csv.DictReader(f): yield line + + +def most_recent_date(): + """Determine the most recent date of CORD data from their website""" + response = requests.get(HTML_URL) + response.raise_for_status() + return LATEST_RE.search(response.text).group() + + +def remove_private_chars(text): + """Remove private unicode characters""" + return "".join([char for char in text if unicodedata.category(char) != 'Co']) + +def convert_date(text): + """Standardise one of the three expected date formats""" + # Mapping of text length to processing lambda + date_converter = {0: lambda text: None, + 4: lambda text: f'{text}-01-01', + 10: lambda text: text} + # Convert the datestring according to the number of chars + try: + return date_converter[len(text)](text) + except KeyError: + raise ValueError(f'Unrecognised date format: {text}') + +def to_arxiv_format(cord_row): + """Convert a row of CORD data ready for ingestion in the arxiv MySQL table""" + # Remove private unicode + for field in ['abstract', 'title']: + cord_row[field] = remove_private_chars(cord_row[field]) + cord_row['publish_time'] = convert_date(cord_row['publish_time']) + # Empty to null + cord_row = {k: (v if (v != '' or k == 'authors') else None) + for k, v in cord_row.items()} + # Format authors + authors = cord_row["authors"].split(";") + cord_row["authors"] = list(map(str.strip, authors)) + # Transpose field names + arxiv_row = { + arxiv_key: cord_row[cord_key] + for arxiv_key, cord_key in CORD_TO_ARXIV_LOOKUP.items() + } + # Hard-coded fields + arxiv_row["id"] = f"cord-{cord_row['cord_uid']}" + arxiv_row["article_source"] = "cord" + return arxiv_row diff --git a/nesta/packages/cord/tests/test_cord.py b/nesta/packages/cord/tests/test_cord.py new file mode 100644 index 00000000..cb30079a --- /dev/null +++ b/nesta/packages/cord/tests/test_cord.py @@ -0,0 +1,88 @@ +from unittest import mock +from io import BytesIO +from nesta.packages.cord.cord import ( + TemporaryFile, + stream_to_file, + cord_data, + most_recent_date, + remove_private_chars, + convert_date, + to_arxiv_format, ) + +PATH = "nesta.packages.cord.cord.{}" + +COMPACT_FILE = ( + "https://nesta-open-data.s3.eu-west-2.amazonaws.com/" + "unit_tests/cord-19_2020-03-13.tar.gz" ) + + +@mock.patch(PATH.format("requests")) +def test_stream_to_file(mocked_requests): + # Setup the mock stream + response_text = b"this is some data" + mocked_requests.get().__enter__().raw = BytesIO(response_text) + # Stream the response to file + with TemporaryFile() as tf: + stream_to_file("dummy.com", tf) + assert tf.read() == response_text + + +@mock.patch(PATH.format("DATA_URL")) +def test_cord_data(mocked_data_url): + mocked_data_url.format.return_value = COMPACT_FILE + data = list(cord_data(date="2020-03-13")) + assert len(data) > 1000 # Lots of data + first = data[1000] # Not the first row, but one that contains good data + assert len(first) > 5 # Lots of columns + # All columns have the same length + assert all(len(row) == len(first) for row in data) + assert type(first["abstract"]) is str # Contains text data + assert len(first["abstract"]) > 100 # Contains text data + + +def test_most_recent_date(): + year, month, day = most_recent_date().split("-") + assert int(year) >= 2021 + assert int(month) > 0 + assert int(day) > 0 + + +def test_remove_private_chars(): + before = "\uf0b7test \uf0c7string 😃" + after = remove_private_chars(before) + assert after == "test string 😃" + + +def
test_convert_date(): + assert convert_date("1991") == "1991-01-01" + assert convert_date("1991-03-31") == "1991-03-31" + assert convert_date("") is None + + +def test_to_arxiv_format(): + cord_row = { + "abstract": "An abstract", + "title": "a Title", + "authors": "name, surname; other name, other surname", + "cord_uid": 123, + "a bonus field": "drop me", # will be dropped + "publish_time": "2020-02-12", + "journal": "JOURNAL", + "doi": "", # converted to None + "another bonus field": None, # will be dropped + } + arxiv_row = { + "id": "cord-123", + "abstract": "An abstract", + "title": "a Title", + "authors": ["name, surname", "other name, other surname"], + "created": "2020-02-12", + "datestamp": "2020-02-12", + "updated": "2020-02-12", + "journal_ref": "JOURNAL", + "doi": None, + "article_source": "cord", # hard-coded + } + assert to_arxiv_format(cord_row) == arxiv_row
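
For illustration, a minimal sketch of how the new CORD collection and the article_source parameter introduced above might be driven end to end. It assumes a local luigi scheduler; the config path and dates below are illustrative assumptions, not values taken from this changeset.

# A sketch only: the config path, dates and scheduler choice are assumptions.
import datetime
import luigi

from nesta.core.routines.arxiv.arxiv_grid_task import GridRootTask
from nesta.core.routines.arxiv.cord_collect_task import CordRootTask

luigi.build(
    [
        # Collect the latest CORD metadata into the arxiv_articles table
        CordRootTask(
            date=datetime.date.today(),
            db_config_path="mysqldb.config",  # assumed config location
            production=False,  # write to the 'dev' database
        ),
        # Run the MAG/GRID enrichment restricted to CORD articles
        GridRootTask(
            date=datetime.date.today(),
            db_config_path="mysqldb.config",  # assumed config location
            production=False,
            articles_from_date="2020-01-01",  # illustrative date
            article_source="cord",  # the new parameter threaded through the tasks above
        ),
    ],
    local_scheduler=True,
)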