diff --git a/update_cars.py b/update_cars.py index cb979f8..53ce8aa 100644 --- a/update_cars.py +++ b/update_cars.py @@ -1,48 +1,40 @@ import asyncio -from typing import List +from typing import List, Dict, Any + +from pymongo import ASCENDING, IndexModel, collection +from pymongo.collection import Collection import updater from db import MongoDBConnection -async def main(): - collection_name: str = "cars" - zip_file_name: str = "Monthly New Registration of Cars by Make.zip" - zip_url: str = ( - f"https://datamall.lta.gov.sg/content/dam/datamall/datasets/Facts_Figures/Vehicle Registration/{zip_file_name}" - ) - key_fields: List[str] = ["month"] +async def create_indexes(collection: Collection) -> None: + indexes = [ + IndexModel([("month", ASCENDING), ("make", ASCENDING)]), + IndexModel([("month", ASCENDING)]), + IndexModel([("make", ASCENDING)]), + IndexModel([("fuel_type", ASCENDING)]), + IndexModel([("make", ASCENDING), ("fuel_type", ASCENDING)]), + IndexModel([("number", ASCENDING)]), + ] + collection.create_indexes(indexes) - db = MongoDBConnection().database - collection = db[collection_name] - # Create indexes - collection.create_index({"month": 1, "make": 1}) - collection.create_index({"month": 1}) - collection.create_index({"make": 1}) - collection.create_index({"fuel_type": 1}) - collection.create_index({"make": 1, "fuel_type": 1}) - collection.create_index({"number": 1}) - - message = await updater.main(collection_name, zip_file_name, zip_url, key_fields) - - if message["inserted_count"] > 0: - print("Running aggregation...") - - replace_empty_string_with_zero = [ +async def run_aggregations(collection: Collection) -> None: + aggregations = [ + [ {"$match": {"number": ""}}, {"$set": {"number": 0}}, { "$merge": { - "into": collection_name, + "into": collection.name, "on": "_id", "whenMatched": "replace", "whenNotMatched": "discard", } }, - ] - - format_values = [ + ], + [ { "$addFields": { "make": { @@ -50,7 +42,7 @@ async def main(): "input": "$make", "find": ".", "replacement": "", - }, + } }, "vehicle_type": { "$replaceAll": { @@ -63,30 +55,48 @@ async def main(): }, { "$merge": { - "into": collection_name, + "into": collection.name, "on": "_id", "whenMatched": "merge", "whenNotMatched": "discard", } }, - ] - - uppercase_make = [ + ], + [ {"$addFields": {"make": {"$toUpper": "$make"}}}, { "$merge": { - "into": collection_name, + "into": collection.name, "on": "_id", "whenMatched": "merge", "whenNotMatched": "discard", } }, - ] + ], + ] + + for aggregation in aggregations: + collection.aggregate(aggregation) + - collection.aggregate(replace_empty_string_with_zero) - collection.aggregate(format_values) - collection.aggregate(uppercase_make) +async def main() -> Dict[str, Any]: + collection_name: str = "cars" + zip_file_name: str = "Monthly New Registration of Cars by Make.zip" + zip_url: str = ( + f"https://datamall.lta.gov.sg/content/dam/datamall/datasets/Facts_Figures/Vehicle Registration/{zip_file_name}" + ) + key_fields: List[str] = ["month"] + db = MongoDBConnection().database + collection = db[collection_name] + + await create_indexes(collection) + + message = await updater.main(collection_name, zip_file_name, zip_url, key_fields) + + if message["inserted_count"] > 0: + print("Running aggregation...") + await run_aggregations(collection) print("Aggregation complete.") db.client.close() diff --git a/updater.py b/updater.py index d1b01c4..5e4434d 100644 --- a/updater.py +++ b/updater.py @@ -3,9 +3,10 @@ import os import tempfile import time -from typing import List, Dict, Any +from typing import List, Dict, Any, Tuple from dotenv import load_dotenv +from pymongo.results import InsertManyResult from db import MongoDBConnection from download_file import download_file @@ -16,17 +17,13 @@ def read_csv_data(file_path: str) -> List[Dict[str, Any]]: - csv_data = [] with open(file_path, "r", encoding="utf-8") as csv_file: - csv_reader = csv.DictReader(csv_file) - for row in csv_reader: - csv_data.append(row) - return csv_data + return list(csv.DictReader(csv_file)) async def updater( collection_name: str, zip_file_name: str, zip_url: str, key_fields: List[str] -) -> tuple[Any, str] | tuple[None, str]: +) -> Tuple[InsertManyResult | None, str]: db = MongoDBConnection().database collection = db[collection_name] @@ -55,7 +52,6 @@ async def updater( start = time.time() result = collection.insert_many(new_data_to_insert) end = time.time() - db.client.close() message = f"{len(result.inserted_ids)} document(s) inserted in {round((end - start) * 1000)}ms" return result, message else: @@ -63,8 +59,9 @@ async def updater( return None, message except Exception as error: - print(f"An error has occurred: {error}") - raise + raise Exception(f"An error has occurred: {error}") + finally: + db.client.close() async def main(