Skip to content

Commit

Permalink
Refactor and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
ruchernchong committed Sep 26, 2024
1 parent 9971292 commit ec071ad
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 48 deletions.
86 changes: 48 additions & 38 deletions update_cars.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,48 @@
import asyncio
from typing import List
from typing import List, Dict, Any

from pymongo import ASCENDING, IndexModel, collection
from pymongo.collection import Collection

import updater
from db import MongoDBConnection


async def main():
collection_name: str = "cars"
zip_file_name: str = "Monthly New Registration of Cars by Make.zip"
zip_url: str = (
f"https://datamall.lta.gov.sg/content/dam/datamall/datasets/Facts_Figures/Vehicle Registration/{zip_file_name}"
)
key_fields: List[str] = ["month"]
async def create_indexes(collection: Collection) -> None:
    """Ensure the secondary indexes used by car-registration queries exist.

    Idempotent: MongoDB skips any index that is already present on the
    collection. Note the call itself is synchronous (pymongo), despite the
    async signature kept for the caller's await-based flow.
    """
    index_specs = [
        [("month", ASCENDING), ("make", ASCENDING)],
        [("month", ASCENDING)],
        [("make", ASCENDING)],
        [("fuel_type", ASCENDING)],
        [("make", ASCENDING), ("fuel_type", ASCENDING)],
        [("number", ASCENDING)],
    ]
    collection.create_indexes([IndexModel(spec) for spec in index_specs])

db = MongoDBConnection().database
collection = db[collection_name]

# Create indexes
collection.create_index({"month": 1, "make": 1})
collection.create_index({"month": 1})
collection.create_index({"make": 1})
collection.create_index({"fuel_type": 1})
collection.create_index({"make": 1, "fuel_type": 1})
collection.create_index({"number": 1})

message = await updater.main(collection_name, zip_file_name, zip_url, key_fields)

if message["inserted_count"] > 0:
print("Running aggregation...")

replace_empty_string_with_zero = [
async def run_aggregations(collection: Collection) -> None:
aggregations = [
[
{"$match": {"number": ""}},
{"$set": {"number": 0}},
{
"$merge": {
"into": collection_name,
"into": collection.name,
"on": "_id",
"whenMatched": "replace",
"whenNotMatched": "discard",
}
},
]

format_values = [
],
[
{
"$addFields": {
"make": {
"$replaceAll": {
"input": "$make",
"find": ".",
"replacement": "",
},
}
},
"vehicle_type": {
"$replaceAll": {
Expand All @@ -63,30 +55,48 @@ async def main():
},
{
"$merge": {
"into": collection_name,
"into": collection.name,
"on": "_id",
"whenMatched": "merge",
"whenNotMatched": "discard",
}
},
]

uppercase_make = [
],
[
{"$addFields": {"make": {"$toUpper": "$make"}}},
{
"$merge": {
"into": collection_name,
"into": collection.name,
"on": "_id",
"whenMatched": "merge",
"whenNotMatched": "discard",
}
},
]
],
]

for aggregation in aggregations:
collection.aggregate(aggregation)


collection.aggregate(replace_empty_string_with_zero)
collection.aggregate(format_values)
collection.aggregate(uppercase_make)
async def main() -> Dict[str, Any]:
collection_name: str = "cars"
zip_file_name: str = "Monthly New Registration of Cars by Make.zip"
zip_url: str = (
f"https://datamall.lta.gov.sg/content/dam/datamall/datasets/Facts_Figures/Vehicle Registration/{zip_file_name}"
)
key_fields: List[str] = ["month"]

db = MongoDBConnection().database
collection = db[collection_name]

await create_indexes(collection)

message = await updater.main(collection_name, zip_file_name, zip_url, key_fields)

if message["inserted_count"] > 0:
print("Running aggregation...")
await run_aggregations(collection)
print("Aggregation complete.")

db.client.close()
Expand Down
17 changes: 7 additions & 10 deletions updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import os
import tempfile
import time
from typing import List, Dict, Any
from typing import List, Dict, Any, Tuple

from dotenv import load_dotenv
from pymongo.results import InsertManyResult

from db import MongoDBConnection
from download_file import download_file
Expand All @@ -16,17 +17,13 @@


def read_csv_data(file_path: str) -> List[Dict[str, Any]]:
    """Read a CSV file and return its rows as dicts keyed by the header row.

    The file is opened as UTF-8 text; each row becomes one dict whose keys
    are the column names from the first line of the file.
    """
    with open(file_path, "r", encoding="utf-8") as csv_file:
        reader = csv.DictReader(csv_file)
        return [row for row in reader]


async def updater(
collection_name: str, zip_file_name: str, zip_url: str, key_fields: List[str]
) -> tuple[Any, str] | tuple[None, str]:
) -> Tuple[InsertManyResult | None, str]:
db = MongoDBConnection().database
collection = db[collection_name]

Expand Down Expand Up @@ -55,16 +52,16 @@ async def updater(
start = time.time()
result = collection.insert_many(new_data_to_insert)
end = time.time()
db.client.close()
message = f"{len(result.inserted_ids)} document(s) inserted in {round((end - start) * 1000)}ms"
return result, message
else:
message = "No new data to insert. The provided data matches the existing records."
return None, message

except Exception as error:
print(f"An error has occurred: {error}")
raise
raise Exception(f"An error has occurred: {error}")
finally:
db.client.close()


async def main(
Expand Down

0 comments on commit ec071ad

Please sign in to comment.