Skip to content

Commit

Permalink
Additional authentication (#33)
Browse files Browse the repository at this point in the history
* added tokens to export routes, you can now only export reports that you created or that have no owner

* fixed tests by adding v2 routes to server calls

* added user authentication to report list_all and list_by

* fixed report list routes
  • Loading branch information
Mr-Squared authored Apr 2, 2023
1 parent f925028 commit 744df63
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 32 deletions.
28 changes: 20 additions & 8 deletions src/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from src.error import *


def export_json_report_v1(report_id: int):
def export_json_report_v1(report_id: int, owner=None):
if report_id < 0:
raise InputError(detail="Report id cannot be less than 0")

Expand All @@ -18,9 +18,12 @@ def export_json_report_v1(report_id: int):
except DoesNotExist:
raise NotFoundError(detail=f"Report with id {report_id} not found")

if report.owner != None and report.owner != owner:
raise ForbiddenError("You do not have permission to view report")

return Report(**report.to_json())

def export_pdf_report_v1(report_id: int) -> bytes:
def export_pdf_report_v1(report_id: int, owner=None) -> bytes:
if report_id < 0:
raise InputError(detail="Report id cannot be less than 0")

Expand All @@ -29,6 +32,9 @@ def export_pdf_report_v1(report_id: int) -> bytes:
except DoesNotExist:
raise NotFoundError(detail=f"Report with id {report_id} not found")

if report.owner != None and report.owner != owner:
raise ForbiddenError("You do not have permission to view report")

html = export_html_report_v1(report_id)
pdf_bytes = HTML(string=html).write_pdf()

Expand Down Expand Up @@ -73,7 +79,7 @@ def add_violations(soup, violations, parent):
location_string = "Line " + str(violation["line"]) + ", Column " + str(violation["column"])
v.find("code", {"name": "location"}).string = location_string

def export_html_report_v1(report_id: int):
def export_html_report_v1(report_id: int, owner=None):
if report_id < 0:
raise InputError(detail="Report id cannot be less than 0")

Expand All @@ -82,6 +88,9 @@ def export_html_report_v1(report_id: int):
except DoesNotExist:
raise NotFoundError(detail=f"Report with id {report_id} not found")

if report.owner != None and report.owner != owner:
raise ForbiddenError("You do not have permission to view report")

report = report.to_json()

with open("src/report_template.html", "r") as file:
Expand Down Expand Up @@ -152,7 +161,7 @@ def write_violations(writer, violations):
]
writer.writerow(data)

def export_csv_report_v1(report_id: int):
def export_csv_report_v1(report_id: int, owner=None):
if report_id < 0:
raise InputError(detail="Report id cannot be less than 0")

Expand All @@ -161,6 +170,9 @@ def export_csv_report_v1(report_id: int):
except DoesNotExist:
raise NotFoundError(detail=f"Report with id {report_id} not found")

if report.owner != None and report.owner != owner:
raise ForbiddenError("You do not have permission to view report")

report = report.to_json()

csv_buffer = StringIO()
Expand All @@ -185,15 +197,15 @@ def export_csv_report_v1(report_id: int):

return csv_contents

def report_bulk_export_json_v1(report_ids) -> ReportList:
return ReportList(reports=[export_json_report_v1(report_id) for report_id in report_ids])
def report_bulk_export_json_v1(report_ids, owner=None) -> ReportList:
return ReportList(reports=[export_json_report_v1(report_id, owner) for report_id in report_ids])

def report_bulk_export_pdf_v1(report_ids) -> BytesIO:
def report_bulk_export_pdf_v1(report_ids, owner=None) -> BytesIO:
reports = BytesIO()

with ZipFile(reports, 'w', ZIP_DEFLATED) as f:
for report_id in report_ids:
f.writestr(f"invoice_validation_report_{report_id}.pdf", export_pdf_report_v1(report_id))
f.writestr(f"invoice_validation_report_{report_id}.pdf", export_pdf_report_v1(report_id, owner))

reports.seek(0)

Expand Down
55 changes: 53 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,18 @@ async def invoice_upload_url_v2(invoice: RemoteInvoice, token = Depends(get_toke
async def export_json_report(report_id: int) -> Report:
return export_json_report_v1(report_id)

@app.get("/export/json_report/v2", tags=["export"])
async def export_json_report_v2(report_id: int, token = Depends(get_token)) -> Report:
return export_json_report_v1(report_id, owner=Sessions.get(token=token).user)

@app.post("/export/bulk_json_reports/v1", tags=["export"])
async def report_bulk_export_json(report_ids: List[int]) -> ReportList:
return report_bulk_export_json_v1(report_ids)

@app.post("/export/bulk_json_reports/v2", tags=["export"])
async def report_bulk_export_json_v2(report_ids: List[int], token = Depends(get_token)) -> ReportList:
return report_bulk_export_json_v1(report_ids, owner=Sessions.get(token=token).user)

@app.get("/export/pdf_report/v1", tags=["export"])
async def export_pdf_report(report_id: int) -> StreamingResponse:
pdf_file = BytesIO(export_pdf_report_v1(report_id))
Expand All @@ -207,6 +215,17 @@ async def export_pdf_report(report_id: int) -> StreamingResponse:
}
return StreamingResponse(pdf_file, headers=headers)

@app.get("/export/pdf_report/v2", tags=["export"])
async def export_pdf_report_v2(report_id: int, token = Depends(get_token)) -> StreamingResponse:
pdf_file = BytesIO(export_pdf_report_v1(report_id, owner=Sessions.get(token=token).user))

# Return the PDF as a streaming response
headers = {
"Content-Disposition": f"attachment; filename=invoice_validation_report_{report_id}.pdf",
"Content-Type": "application/pdf",
}
return StreamingResponse(pdf_file, headers=headers)

@app.post("/export/bulk_pdf_reports/v1", tags=["export"])
async def report_bulk_export_pdf(report_ids: List[int]) -> StreamingResponse:
reports_zip = report_bulk_export_pdf_v1(report_ids)
Expand All @@ -217,11 +236,26 @@ async def report_bulk_export_pdf(report_ids: List[int]) -> StreamingResponse:
headers = { "Content-Disposition": f"attachment; filename=reports.zip"}
)

@app.post("/export/bulk_pdf_reports/v2", tags=["export"])
async def report_bulk_export_pdf_v2(report_ids: List[int], token = Depends(get_token)) -> StreamingResponse:
reports_zip = report_bulk_export_pdf_v1(report_ids, owner=Sessions.get(token=token).user)

return StreamingResponse(
reports_zip,
media_type="application/x-zip-compressed",
headers = { "Content-Disposition": f"attachment; filename=reports.zip"}
)

@app.get("/export/html_report/v1", response_class=HTMLResponse, tags=["export"])
async def export_html_report(report_id: int) -> HTMLResponse:
html_content = export_html_report_v1(report_id)
return HTMLResponse(content=html_content, status_code=200)

@app.get("/export/html_report/v2", response_class=HTMLResponse, tags=["export"])
async def export_html_report_v2(report_id: int, token = Depends(get_token)) -> HTMLResponse:
html_content = export_html_report_v1(report_id, owner=Sessions.get(token=token).user)
return HTMLResponse(content=html_content, status_code=200)

@app.get("/export/csv_report/v1", tags=["export"])
async def export_csv_report(report_id: int) -> HTMLResponse:
csv_contents = export_csv_report_v1(report_id)
Expand All @@ -231,6 +265,15 @@ async def export_csv_report(report_id: int) -> HTMLResponse:

return response

@app.get("/export/csv_report/v2", tags=["export"])
async def export_csv_report_v2(report_id: int, token = Depends(get_token)) -> HTMLResponse:
csv_contents = export_csv_report_v1(report_id, owner=Sessions.get(token=token).user)

response = HTMLResponse(content=csv_contents, media_type='text/csv')
response.headers['Content-Disposition'] = f'attachment; filename="invoice_validation_report_{report_id}.csv"'

return response

@app.post("/report/wellformedness/v1", tags=["report"])
async def report_wellformedness(file: UploadFile = File(...)) -> Evaluation:
invoice_text = await file.read()
Expand All @@ -255,10 +298,18 @@ async def report_peppol(file: UploadFile = File(...)) -> Evaluation:
async def report_list_all() -> ReportIDs:
return report_list_all_v1()

@app.get("/report/list_all/v2", tags=["report"])
async def report_list_all_v2(token = Depends(get_token)) -> ReportIDs:
return report_list_all_v1(owner=Sessions.get(token=token).user)

@app.get("/report/list_by/v1", tags=["report"])
async def report_list_by(order_by: OrderBy) -> ReportIDs:
return report_list_by_v1(order_by)

@app.get("/report/list_by/v2", tags=["report"])
async def report_list_by_v2(order_by: OrderBy, token = Depends(get_token)) -> ReportIDs:
return report_list_by_v1(order_by, owner=Sessions.get(token=token).user)

@app.get("/report/check_validity/v1", tags=["report"])
async def invoice_check_validity(report_id: int) -> CheckValidReturn:
return invoice_check_validity_v1(report_id)
Expand All @@ -267,8 +318,6 @@ async def invoice_check_validity(report_id: int) -> CheckValidReturn:
async def report_lint(invoice: TextInvoice) -> LintReport:
return report_lint_v1(invoice_text=invoice.text)

### Below to be replaced with proper authentication system ###

@app.put("/report/change_name/v2", tags=["report"])
async def report_change_name(report_id: int, new_name: str, token: str = Depends(get_token)) -> Dict[None, None]:
return report_change_name_v2(token, report_id, new_name)
Expand All @@ -285,6 +334,8 @@ async def auth_login(form_data: OAuth2PasswordRequestForm = Depends()):
async def auth_register(email: str, password: str) -> AuthReturnV2:
return auth_register_v2(email, password)

# Not in schema

@app.post("/invoice/generate_hash/v2", include_in_schema=False)
async def invoice_generate_hash(invoice_text: TextInvoice) -> str:
return invoice_generate_hash_v1(invoice_text)
Expand Down
22 changes: 18 additions & 4 deletions src/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,30 @@ def report_peppol_v1(invoice_text: str) -> Evaluation:

return Evaluation(**evaluation.to_json())

def report_list_all_v1() -> ReportIDs:
return ReportIDs(report_ids=[report.id for report in Reports.select()])
def report_list_all_v1(owner=None) -> ReportIDs:
report_ids = []
for report in Reports.select():
if owner == None and report.owner == None:
report_ids.append(report.id)
elif report.owner == owner:
report_ids.append(report.id)

return ReportIDs(report_ids=report_ids)

def report_list_by_v1(order_by: OrderBy) -> ReportIDs:
def report_list_by_v1(order_by: OrderBy, owner=None) -> ReportIDs:
if order_by.is_ascending:
order = getattr(Reports, order_by.table).asc()
else:
order = getattr(Reports, order_by.table).desc()

report_ids = []
for report in Reports.select().order_by(order):
if owner == None and report.owner == None:
report_ids.append(report.id)
elif report.owner == owner:
report_ids.append(rereport.idport)

return ReportIDs(report_ids=[report.id for report in Reports.select().order_by(order)])
return ReportIDs(report_ids=report_ids)

def report_change_name_v2(token: str, report_id: int, new_name: str) -> Dict[None, None]:
if report_id < 0:
Expand Down
12 changes: 6 additions & 6 deletions tests/report/change_name_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src.type_structure import *
from tests.server_calls import report_change_name_v2, export_json_report_v1, invoice_upload_text_v2, auth_register_v2, clear_v1
from tests.server_calls import report_change_name_v2, export_json_report_v2, invoice_upload_text_v2, auth_register_v2, clear_v1
from tests.constants import VALID_INVOICE_TEXT

"""
Expand All @@ -14,13 +14,13 @@ def test_change_name_valid():
invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT)
report_id = invoice_upload_text_v2(token, invoice.name, invoice.text)["report_id"]

report = Report(**export_json_report_v1(report_id))
report = Report(**export_json_report_v2(token, report_id))

# Checking for the old name of the invoice
assert report.invoice_name == "My Invoice"

report_change_name_v2(token, report_id, "New Name")
report = Report(**export_json_report_v1(report_id))
report = Report(**export_json_report_v2(token, report_id))

# Checking for the new name of the invoice
assert report.invoice_name == "New Name"
Expand All @@ -46,7 +46,7 @@ def test_change_name_long_invalid():
invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT)
report_id = invoice_upload_text_v2(token, invoice.name, invoice.text)["report_id"]

report = Report(**export_json_report_v1(report_id))
report = Report(**export_json_report_v2(token, report_id))

# Checking for the old name of the invoice
assert report.invoice_name == "My Invoice"
Expand All @@ -60,7 +60,7 @@ def test_change_name_invalid_report_id_negative():
invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT)
report_id = invoice_upload_text_v2(token, invoice.name, invoice.text)["report_id"]

report = Report(**export_json_report_v1(report_id))
report = Report(**export_json_report_v2(token, report_id))

# Checking for the old name of the invoice
assert report.invoice_name == "My Invoice"
Expand All @@ -73,7 +73,7 @@ def test_change_name_invalid_report_id_not_found():
invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT)
report_id = invoice_upload_text_v2(token, invoice.name, invoice.text)["report_id"]

report = Report(**export_json_report_v1(report_id))
report = Report(**export_json_report_v2(token, report_id))

# Checking for the old name of the invoice
assert report.invoice_name == "My Invoice"
Expand Down
Loading

0 comments on commit 744df63

Please sign in to comment.