diff --git a/src/authentication.py b/src/authentication.py index 7e64a3b..173b8be 100644 --- a/src/authentication.py +++ b/src/authentication.py @@ -1,4 +1,5 @@ -from src.database import Users, IntegrityError, DoesNotExist +from datetime import datetime, timedelta +from src.database import Users, Sessions, IntegrityError, DoesNotExist from src.helpers import string_in_range from src.error import InputError from src.type_structure import * @@ -9,7 +10,7 @@ Auth login and register functions. ''' -def auth_login_v1(email, password): +def auth_login_v1(email, password) -> AuthReturnV1: ''' This function uses user-inputted login data to log a user into the system. @@ -29,18 +30,15 @@ def auth_login_v1(email, password): try: user = Users.get(email=email) except DoesNotExist: - raise InputError(detail="Invalid input: No user with email " + email + ".") + raise InputError(detail="Invalid input: Incorrect email or password.") if user.password_hash != hashlib.sha256(password.encode("utf-8")).hexdigest(): - raise InputError(detail="Invalid input: Incorrect password.") + raise InputError(detail="Invalid input: Incorrect email or password.") + return AuthReturnV1(auth_user_id=user.id) - # data_store.start_token_session(data_store.encode_user_jwt(user_id)) - return {'auth_user_id' : user.id} - - -def auth_register_v1(email, password): +def auth_register_v1(email, password) -> AuthReturnV1: ''' This function registers a user into the system by getting their details. @@ -71,7 +69,6 @@ def auth_register_v1(email, password): raise InputError(detail="Invalid input: Password is too short.") # Generate password hash - # salt = data_store.gen_salt() password_hash = hashlib.sha256(password.encode("utf-8")).hexdigest() try: @@ -81,6 +78,20 @@ def auth_register_v1(email, password): raise InputError(detail="Invalid input: Email " + email + " is already taken.") # Return id once register is successful - return { - 'auth_user_id': user.id, - } + return AuthReturnV1(auth_user_id=user.id) + + +def auth_login_v2(email, password) -> AuthReturnV2: + id = auth_login_v1(email, password).auth_user_id + now = datetime.now() + token = hashlib.sha256(id.to_bytes(8, 'big') + now.strftime("%s").encode("utf-8")).hexdigest() + Sessions.create(user=id, token=token, date_created=now, date_expires=now + timedelta(days=1)) + return AuthReturnV2(token=token) + + +def auth_register_v2(email, password) -> AuthReturnV2: + id = auth_register_v1(email, password).auth_user_id + now = datetime.now() + token = hashlib.sha256(id.to_bytes(8, 'big') + now.strftime("%s").encode("utf-8")).hexdigest() + Sessions.create(user=id, token=token, date_created=now, date_expires=now + timedelta(days=1)) + return AuthReturnV2(token=token) \ No newline at end of file diff --git a/src/database.py b/src/database.py index 35d863c..ccffdf3 100644 --- a/src/database.py +++ b/src/database.py @@ -55,6 +55,7 @@ class Reports(BaseModel): schema = ForeignKeyField(Evaluations, backref='schema', null=True, default=None) syntax = ForeignKeyField(Evaluations, backref='syntax', null=True, default=None) peppol = ForeignKeyField(Evaluations, backref='peppol', null=True, default=None) + owner = ForeignKeyField(Users, backref='users', null=True) def to_json(self): return { diff --git a/src/error.py b/src/error.py index 7f8762f..cd442ce 100644 --- a/src/error.py +++ b/src/error.py @@ -5,9 +5,14 @@ def __init__(self, detail: str): self.status_code = 400 self.detail = detail -class TokenError(Exception): +class UnauthorisedError(Exception): def __init__(self, detail: str): - self.status_code = 402 + self.status_code = 401 + self.detail = detail + +class ForbiddenError(Exception): + def __init__(self, detail: str): + self.status_code = 403 self.detail = detail class NotFoundError(Exception): diff --git a/src/export.py b/src/export.py index af1c073..110c0ae 100644 --- a/src/export.py +++ b/src/export.py @@ -9,7 +9,7 @@ from src.error import * -def export_json_report_v1(report_id: int): +def export_json_report_v1(report_id: int, owner=None): if report_id < 0: raise InputError(detail="Report id cannot be less than 0") @@ -18,9 +18,12 @@ def export_json_report_v1(report_id: int): except DoesNotExist: raise NotFoundError(detail=f"Report with id {report_id} not found") + if report.owner != None and report.owner != owner: + raise ForbiddenError("You do not have permission to view report") + return Report(**report.to_json()) -def export_pdf_report_v1(report_id: int) -> bytes: +def export_pdf_report_v1(report_id: int, owner=None) -> bytes: if report_id < 0: raise InputError(detail="Report id cannot be less than 0") @@ -29,6 +32,9 @@ def export_pdf_report_v1(report_id: int) -> bytes: except DoesNotExist: raise NotFoundError(detail=f"Report with id {report_id} not found") + if report.owner != None and report.owner != owner: + raise ForbiddenError("You do not have permission to view report") + html = export_html_report_v1(report_id) pdf_bytes = HTML(string=html).write_pdf() @@ -73,7 +79,7 @@ def add_violations(soup, violations, parent): location_string = "Line " + str(violation["line"]) + ", Column " + str(violation["column"]) v.find("code", {"name": "location"}).string = location_string -def export_html_report_v1(report_id: int): +def export_html_report_v1(report_id: int, owner=None): if report_id < 0: raise InputError(detail="Report id cannot be less than 0") @@ -82,6 +88,9 @@ def export_html_report_v1(report_id: int): except DoesNotExist: raise NotFoundError(detail=f"Report with id {report_id} not found") + if report.owner != None and report.owner != owner: + raise ForbiddenError("You do not have permission to view report") + report = report.to_json() with open("src/report_template.html", "r") as file: @@ -152,7 +161,7 @@ def write_violations(writer, violations): ] writer.writerow(data) -def export_csv_report_v1(report_id: int): +def export_csv_report_v1(report_id: int, owner=None): if report_id < 0: raise InputError(detail="Report id cannot be less than 0") @@ -161,6 +170,9 @@ def export_csv_report_v1(report_id: int): except DoesNotExist: raise NotFoundError(detail=f"Report with id {report_id} not found") + if report.owner != None and report.owner != owner: + raise ForbiddenError("You do not have permission to view report") + report = report.to_json() csv_buffer = StringIO() @@ -185,17 +197,15 @@ def export_csv_report_v1(report_id: int): return csv_contents -def report_bulk_export_json_v1(report_ids) -> List: - return { - "reports": [export_json_report_v1(report_id) for report_id in report_ids] - } +def report_bulk_export_json_v1(report_ids, owner=None) -> ReportList: + return ReportList(reports=[export_json_report_v1(report_id, owner) for report_id in report_ids]) -def report_bulk_export_pdf_v1(report_ids) -> BytesIO: +def report_bulk_export_pdf_v1(report_ids, owner=None) -> BytesIO: reports = BytesIO() with ZipFile(reports, 'w', ZIP_DEFLATED) as f: for report_id in report_ids: - f.writestr(f"invoice_validation_report_{report_id}.pdf", export_pdf_report_v1(report_id)) + f.writestr(f"invoice_validation_report_{report_id}.pdf", export_pdf_report_v1(report_id, owner)) reports.seek(0) diff --git a/src/generation.py b/src/generation.py index c61bea2..d4f6953 100644 --- a/src/generation.py +++ b/src/generation.py @@ -18,14 +18,14 @@ def generate_schema_evaluation(invoice_text: str) -> Evaluations: def generate_parser_evaluation(violations) -> Evaluations: evaluation = Evaluations.create( - is_valid=True if len(violations) == 0 else False, + is_valid=len(violations) == 0, num_warnings=0, num_errors=len(violations), num_rules_failed=len(violations) ) for violation in violations: - violation.evaluation = evaluation.id + violation.evaluation = evaluation.id #type: ignore violation.save() return evaluation @@ -52,7 +52,7 @@ def generate_xslt_evaluation(executable, invoice_text) -> Evaluations: return evaluation -def generate_report(invoice_name: str, invoice_text: str) -> int: +def generate_report(invoice_name: str, invoice_text: str, owner) -> int: wellformedness_evaluation = None schema_evaluation = None syntax_evaluation = None @@ -89,12 +89,13 @@ def generate_report(invoice_name: str, invoice_text: str) -> int: wellformedness=wellformedness_evaluation.id if wellformedness_evaluation else None, schema=schema_evaluation.id if schema_evaluation else None, syntax=syntax_evaluation.id if syntax_evaluation else None, - peppol=peppol_evaluation.id if peppol_evaluation else None + peppol=peppol_evaluation.id if peppol_evaluation else None, + owner=owner ) return report.id -def generate_diagnostic_list(invoice_text: str) -> int: +def generate_diagnostic_list(invoice_text: str) -> List[LintDiagnostic]: report = [] wellformedness_violations = get_wellformedness_violations(invoice_text) @@ -106,6 +107,7 @@ def generate_diagnostic_list(invoice_text: str) -> int: column=violation.column, xpath=violation.xpath, message=violation.message, + suggestion=None, severity="error" if violation.is_fatal else "warning" )) @@ -121,6 +123,7 @@ def generate_diagnostic_list(invoice_text: str) -> int: column=violation.column, xpath=violation.xpath, message=violation.message, + suggestion=None, severity="error" if violation.is_fatal else "warning" )) diff --git a/src/invoice.py b/src/invoice.py index 0e19999..0089676 100644 --- a/src/invoice.py +++ b/src/invoice.py @@ -5,16 +5,14 @@ from src.generation import generate_report -def invoice_upload_text_v1(invoice_name: str, invoice_text: str): +def invoice_upload_text_v1(invoice_name: str, invoice_text: str, owner = None) -> ReportID: if len(invoice_name) > 100: raise InputError(detail="Name cannot be longer than 100 characters") - return { - "report_id": generate_report(invoice_name, invoice_text) - } + return ReportID(report_id=generate_report(invoice_name, invoice_text, owner)) -def invoice_upload_url_v1(invoice_name: str, invoice_url: str): +def invoice_upload_url_v1(invoice_name: str, invoice_url: str, owner = None): if len(invoice_name) > 100: raise InputError(detail="Name cannot be longer than 100 characters") @@ -27,21 +25,15 @@ def invoice_upload_url_v1(invoice_name: str, invoice_url: str): raise InputError(detail="URL does not point to plain text or XML data") invoice_text = response.text - - report_id = generate_report(invoice_name, invoice_text) - return { - "report_id": report_id - } + return ReportID(report_id=generate_report(invoice_name, invoice_text, owner)) -def invoice_upload_file_v1(invoice_name: str, invoice_text: str): +def invoice_upload_file_v1(invoice_name: str, invoice_text: str, owner = None) -> ReportID: if not invoice_name.endswith('.xml'): raise InputError(detail="Invoice file type is not XML") - return { - "report_id": generate_report(invoice_name, invoice_text) - } + return ReportID(report_id=generate_report(invoice_name, invoice_text, owner)) def invoice_check_validity_v1(report_id: int) -> CheckValidReturn: if report_id < 0: @@ -55,7 +47,7 @@ def invoice_check_validity_v1(report_id: int) -> CheckValidReturn: return CheckValidReturn(is_valid=report.is_valid) def invoice_generate_hash_v1(invoice: TextInvoice) -> str: - return {} + return "" -def invoice_upload_bulk_text_v1(invoices: List[TextInvoice]) -> ReportIDs: - return ReportIDs(report_ids=[generate_report(invoice.name, invoice.text) for invoice in invoices]) +def invoice_upload_bulk_text_v1(invoices: List[TextInvoice], owner = None) -> ReportIDs: + return ReportIDs(report_ids=[generate_report(invoice.name, invoice.text, owner) for invoice in invoices]) diff --git a/src/main.py b/src/main.py index 562581c..7adb591 100644 --- a/src/main.py +++ b/src/main.py @@ -1,4 +1,3 @@ -import signal from src.config import base_url, port from src.health_check import health_check_v1 from src.report import * @@ -8,8 +7,9 @@ from src.authentication import * from src.type_structure import * from src.database import clear_v1 -from fastapi import FastAPI, Request, HTTPException, UploadFile, File -from fastapi.responses import Response, JSONResponse, HTMLResponse, StreamingResponse +from fastapi import Depends, FastAPI, Request,UploadFile, File +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse from fastapi.middleware.cors import CORSMiddleware from io import BytesIO import uvicorn @@ -35,6 +35,7 @@ }, ] + app = FastAPI(title="CHURROS VALIDATION API", description=description, version="0.0.1", @@ -61,19 +62,30 @@ async def input_error_exception_handler(request: Request, exc: InputError): }, ) -@app.exception_handler(TokenError) -async def input_error_exception_handler(request: Request, exc: TokenError): +@app.exception_handler(UnauthorisedError) +async def authorization_error_exception_handler(request: Request, exc: UnauthorisedError): + return JSONResponse( + status_code=401, + content={ + "code": 401, + "name": "Unauthorised Error", + "detail": exc.detail + }, + ) + +@app.exception_handler(ForbiddenError) +async def forbidden_error_exception_handler(request: Request, exc: ForbiddenError): return JSONResponse( - status_code=402, + status_code=403, content={ - "code": 402, - "name": "Token Error", + "code": 403, + "name": "Forbidden Error", "detail": exc.detail }, ) @app.exception_handler(NotFoundError) -async def input_error_exception_handler(request: Request, exc: NotFoundError): +async def not_found_error_exception_handler(request: Request, exc: NotFoundError): return JSONResponse( status_code=404, content={ @@ -94,6 +106,24 @@ async def validation_exception_handler(request: Request, exc: InternalServerErro }, ) +# token validation below + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth_login/v2") + +async def get_token(token: str = Depends(oauth2_scheme)) -> str: + if token == ADMIN_TOKEN: + return token + + try: + session = Sessions.get(token=token) + except DoesNotExist: + raise UnauthorisedError("Invalid token, please login/register") + + if session.date_expires < datetime.now(): + raise UnauthorisedError("Expired token, please login again") + + return session.token + # ENDPOINTS BELOW @app.get("/") @@ -107,7 +137,12 @@ async def health_check(): @app.post("/invoice/upload_file/v1", tags=["invoice"]) async def invoice_upload_file(file: UploadFile = File(...)) -> ReportID: invoice_text = await file.read() - return invoice_upload_file_v1(invoice_name=file.filename, invoice_text=invoice_text.decode("utf-8")) + return invoice_upload_file_v1(invoice_name=file.filename, invoice_text=invoice_text.decode("utf-8")) #type: ignore + +@app.post("/invoice/upload_file/v2", tags=["invoice"]) +async def invoice_upload_file_v2(file: UploadFile = File(...), token = Depends(get_token)) -> ReportID: + invoice_text = await file.read() + return invoice_upload_file_v1(invoice_name=file.filename, invoice_text=invoice_text.decode("utf-8"), owner=Sessions.get(token=token).user) #type: ignore @app.post("/invoice/bulk_upload_file/v1", tags=["invoice"]) async def invoice_bulk_upload_file(files: List[UploadFile] = File(...)) -> ReportIDs: @@ -115,31 +150,62 @@ async def invoice_bulk_upload_file(files: List[UploadFile] = File(...)) -> Repor for file in files: invoice_text = await file.read() - invoice = TextInvoice(name=file.filename, text=invoice_text.decode("utf-8")) + invoice = TextInvoice(name=file.filename, text=invoice_text.decode("utf-8")) #type: ignore invoices.append(invoice) return invoice_upload_bulk_text_v1(invoices) +@app.post("/invoice/bulk_upload_file/v2", tags=["invoice"]) +async def invoice_bulk_upload_file_v2(files: List[UploadFile] = File(...), token = Depends(get_token)) -> ReportIDs: + invoices = [] + + for file in files: + invoice_text = await file.read() + invoice = TextInvoice(name=file.filename, text=invoice_text.decode("utf-8")) #type: ignore + invoices.append(invoice) + + return invoice_upload_bulk_text_v1(invoices, owner=Sessions.get(token=token).user) + @app.post("/invoice/upload_text/v1", tags=["invoice"]) async def invoice_upload_text(invoice: TextInvoice) -> ReportID: return invoice_upload_text_v1(invoice_name=invoice.name, invoice_text=invoice.text) +@app.post("/invoice/upload_text/v2", tags=["invoice"]) +async def invoice_upload_text_v2(invoice: TextInvoice, token = Depends(get_token)) -> ReportID: + return invoice_upload_text_v1(invoice_name=invoice.name, invoice_text=invoice.text, owner=Sessions.get(token=token).user) + @app.post("/invoice/bulk_upload_text/v1", tags=["invoice"]) async def invoice_upload_bulk_text(invoices: List[TextInvoice]) -> ReportIDs: return invoice_upload_bulk_text_v1(invoices) +@app.post("/invoice/bulk_upload_text/v2", tags=["invoice"]) +async def invoice_upload_bulk_text_v2(invoices: List[TextInvoice], token = Depends(get_token)) -> ReportIDs: + return invoice_upload_bulk_text_v1(invoices, owner=Sessions.get(token=token).user) + @app.post("/invoice/upload_url/v1", tags=["invoice"]) async def invoice_upload_url(invoice: RemoteInvoice) -> ReportID: return invoice_upload_url_v1(invoice_name=invoice.name, invoice_url=invoice.url) +@app.post("/invoice/upload_url/v2", tags=["invoice"]) +async def invoice_upload_url_v2(invoice: RemoteInvoice, token = Depends(get_token)) -> ReportID: + return invoice_upload_url_v1(invoice_name=invoice.name, invoice_url=invoice.url, owner=Sessions.get(token=token).user) + @app.get("/export/json_report/v1", tags=["export"]) async def export_json_report(report_id: int) -> Report: return export_json_report_v1(report_id) +@app.get("/export/json_report/v2", tags=["export"]) +async def export_json_report_v2(report_id: int, token = Depends(get_token)) -> Report: + return export_json_report_v1(report_id, owner=Sessions.get(token=token).user) + @app.post("/export/bulk_json_reports/v1", tags=["export"]) async def report_bulk_export_json(report_ids: List[int]) -> ReportList: return report_bulk_export_json_v1(report_ids) +@app.post("/export/bulk_json_reports/v2", tags=["export"]) +async def report_bulk_export_json_v2(report_ids: List[int], token = Depends(get_token)) -> ReportList: + return report_bulk_export_json_v1(report_ids, owner=Sessions.get(token=token).user) + @app.get("/export/pdf_report/v1", tags=["export"]) async def export_pdf_report(report_id: int) -> StreamingResponse: pdf_file = BytesIO(export_pdf_report_v1(report_id)) @@ -151,6 +217,17 @@ async def export_pdf_report(report_id: int) -> StreamingResponse: } return StreamingResponse(pdf_file, headers=headers) +@app.get("/export/pdf_report/v2", tags=["export"]) +async def export_pdf_report_v2(report_id: int, token = Depends(get_token)) -> StreamingResponse: + pdf_file = BytesIO(export_pdf_report_v1(report_id, owner=Sessions.get(token=token).user)) + + # Return the PDF as a streaming response + headers = { + "Content-Disposition": f"attachment; filename=invoice_validation_report_{report_id}.pdf", + "Content-Type": "application/pdf", + } + return StreamingResponse(pdf_file, headers=headers) + @app.post("/export/bulk_pdf_reports/v1", tags=["export"]) async def report_bulk_export_pdf(report_ids: List[int]) -> StreamingResponse: reports_zip = report_bulk_export_pdf_v1(report_ids) @@ -161,11 +238,26 @@ async def report_bulk_export_pdf(report_ids: List[int]) -> StreamingResponse: headers = { "Content-Disposition": f"attachment; filename=reports.zip"} ) +@app.post("/export/bulk_pdf_reports/v2", tags=["export"]) +async def report_bulk_export_pdf_v2(report_ids: List[int], token = Depends(get_token)) -> StreamingResponse: + reports_zip = report_bulk_export_pdf_v1(report_ids, owner=Sessions.get(token=token).user) + + return StreamingResponse( + reports_zip, + media_type="application/x-zip-compressed", + headers = { "Content-Disposition": f"attachment; filename=reports.zip"} + ) + @app.get("/export/html_report/v1", response_class=HTMLResponse, tags=["export"]) async def export_html_report(report_id: int) -> HTMLResponse: html_content = export_html_report_v1(report_id) return HTMLResponse(content=html_content, status_code=200) +@app.get("/export/html_report/v2", response_class=HTMLResponse, tags=["export"]) +async def export_html_report_v2(report_id: int, token = Depends(get_token)) -> HTMLResponse: + html_content = export_html_report_v1(report_id, owner=Sessions.get(token=token).user) + return HTMLResponse(content=html_content, status_code=200) + @app.get("/export/csv_report/v1", tags=["export"]) async def export_csv_report(report_id: int) -> HTMLResponse: csv_contents = export_csv_report_v1(report_id) @@ -175,6 +267,7 @@ async def export_csv_report(report_id: int) -> HTMLResponse: return response + @app.post("/report/send_email/v2", tags=["report"]) async def send_email_report(email, report_id): pdf_file = BytesIO(export_pdf_report_v1(report_id)).read() @@ -188,6 +281,15 @@ async def send_email_report(email, report_id): }, ) +@app.get("/export/csv_report/v2", tags=["export"]) +async def export_csv_report_v2(report_id: int, token = Depends(get_token)) -> HTMLResponse: + csv_contents = export_csv_report_v1(report_id, owner=Sessions.get(token=token).user) + + response = HTMLResponse(content=csv_contents, media_type='text/csv') + response.headers['Content-Disposition'] = f'attachment; filename="invoice_validation_report_{report_id}.csv"' + + return response + @app.post("/report/wellformedness/v1", tags=["report"]) async def report_wellformedness(file: UploadFile = File(...)) -> Evaluation: invoice_text = await file.read() @@ -212,10 +314,18 @@ async def report_peppol(file: UploadFile = File(...)) -> Evaluation: async def report_list_all() -> ReportIDs: return report_list_all_v1() +@app.get("/report/list_all/v2", tags=["report"]) +async def report_list_all_v2(token = Depends(get_token)) -> ReportIDs: + return report_list_all_v1(owner=Sessions.get(token=token).user) + @app.get("/report/list_by/v1", tags=["report"]) async def report_list_by(order_by: OrderBy) -> ReportIDs: return report_list_by_v1(order_by) +@app.get("/report/list_by/v2", tags=["report"]) +async def report_list_by_v2(order_by: OrderBy, token = Depends(get_token)) -> ReportIDs: + return report_list_by_v1(order_by, owner=Sessions.get(token=token).user) + @app.get("/report/check_validity/v1", tags=["report"]) async def invoice_check_validity(report_id: int) -> CheckValidReturn: return invoice_check_validity_v1(report_id) @@ -224,26 +334,26 @@ async def invoice_check_validity(report_id: int) -> CheckValidReturn: async def report_lint(invoice: TextInvoice) -> LintReport: return report_lint_v1(invoice_text=invoice.text) -### Below to be replaced with proper authentication system ### +@app.put("/report/change_name/v2", tags=["report"]) +async def report_change_name(report_id: int, new_name: str, token: str = Depends(get_token)) -> Dict[None, None]: + return report_change_name_v2(token, report_id, new_name) -@app.put("/report/change_name/v2", include_in_schema=False) -async def report_change_name(token: str, report_id: int, new_name: str) -> Dict[None, None]: - return report_change_name_v1(token, report_id, new_name) +@app.delete("/report/delete/v2", tags=["report"]) +async def report_delete(report_id: int, token: str = Depends(get_token)) -> Dict[None, None]: + return report_delete_v2(token, report_id) -@app.delete("/report/delete/v2", include_in_schema=False) -async def report_delete(token: str, report_id: int) -> Dict[None, None]: - return report_delete_v1(token, report_id) +@app.post("/auth_login/v2", tags=["auth"]) +async def auth_login(form_data: OAuth2PasswordRequestForm = Depends()): + return Token(access_token=auth_login_v2(form_data.username, form_data.password).token, token_type="bearer") -@app.get("/auth_login/v2", include_in_schema=False) -async def auth_login(email: str, password: str): - return auth_login_v1(email, password) +@app.post("/auth_register/v2", tags=["auth"]) +async def auth_register(email: str, password: str) -> AuthReturnV2: + return auth_register_v2(email, password) -@app.get("/auth_register/v2", include_in_schema=False) -async def auth_register(email: str, password: str): - return auth_register_v1(email, password) +# Not in schema @app.post("/invoice/generate_hash/v2", include_in_schema=False) -async def invoice_generate_hash(invoice_text: str) -> str: +async def invoice_generate_hash(invoice_text: TextInvoice) -> str: return invoice_generate_hash_v1(invoice_text) @app.delete("/clear/v1", include_in_schema=False) diff --git a/src/report.py b/src/report.py index 5314cb2..d1f2f04 100644 --- a/src/report.py +++ b/src/report.py @@ -1,6 +1,6 @@ -from src.type_structure import * from typing import Dict -from src.database import Reports +from src.type_structure import * +from src.database import Reports, Sessions from src.generation import generate_xslt_evaluation, generate_schema_evaluation, generate_wellformedness_evaluation, generate_diagnostic_list from peewee import DoesNotExist from src.constants import ADMIN_TOKEN, PEPPOL_EXECUTABLE, SYNTAX_EXECUTABLE @@ -36,24 +36,32 @@ def report_peppol_v1(invoice_text: str) -> Evaluation: return Evaluation(**evaluation.to_json()) -def report_list_all_v1() -> List[int]: - return ReportIDs(report_ids=[report.id for report in Reports.select()]) +def report_list_all_v1(owner=None) -> ReportIDs: + report_ids = [] + for report in Reports.select(): + if owner == None and report.owner == None: + report_ids.append(report.id) + elif report.owner == owner: + report_ids.append(report.id) + + return ReportIDs(report_ids=report_ids) -def report_list_by_v1(order_by: OrderBy) -> List[int]: +def report_list_by_v1(order_by: OrderBy, owner=None) -> ReportIDs: if order_by.is_ascending: order = getattr(Reports, order_by.table).asc() else: order = getattr(Reports, order_by.table).desc() - - return ReportIDs(report_ids=[report.id for report in Reports.select().order_by(order)]) + + report_ids = [] + for report in Reports.select().order_by(order): + if owner == None and report.owner == None: + report_ids.append(report.id) + elif report.owner == owner: + report_ids.append(rereport.idport) + + return ReportIDs(report_ids=report_ids) -def report_change_name_v1(token: str, report_id: int, new_name: str) -> Dict[None, None]: - if len(new_name) > 100: - raise InputError(detail="New name is longer than 100 characters") - - if not token == ADMIN_TOKEN: - raise InputError(detail="Only admins can change the names of reports at the moment") - +def report_change_name_v2(token: str, report_id: int, new_name: str) -> Dict[None, None]: if report_id < 0: raise InputError(detail="Report id cannot be less than 0") @@ -62,22 +70,39 @@ def report_change_name_v1(token: str, report_id: int, new_name: str) -> Dict[Non except DoesNotExist: raise NotFoundError(detail=f"Report with id {report_id} not found") + if not token == ADMIN_TOKEN: + try: + session = Sessions.get(token=token) + except DoesNotExist: + raise UnauthorisedError("Invalid token") + if not report.owner == session.user: + raise ForbiddenError("You do not have permission to rename this report") + + if len(new_name) > 100: + raise InputError(detail="New name is longer than 100 characters") + report.invoice_name = new_name report.save() return {} -def report_delete_v1(token: str, report_id: int) -> Dict[None, None]: +def report_delete_v2(token: str, report_id: int) -> Dict[None, None]: if report_id < 0: raise InputError(detail="Report id cannot be less than 0") - if not token == ADMIN_TOKEN: - raise InputError(detail="Only admins can change the names of reports at the moment") - try: report = Reports.get_by_id(report_id) except DoesNotExist: raise NotFoundError(detail=f"Report with id {report_id} not found") + + if not token == ADMIN_TOKEN: + try: + session = Sessions.get(token=token) + except DoesNotExist: + raise UnauthorisedError("Invalid token") + + if not report.owner == session.user: + raise ForbiddenError("You do not have permission to delete this report") report.delete_instance() diff --git a/src/type_structure.py b/src/type_structure.py index be0bf46..43651fc 100644 --- a/src/type_structure.py +++ b/src/type_structure.py @@ -51,6 +51,10 @@ class Report(BaseModel): syntax_evaluation: Union[Evaluation, None] peppol_evaluation: Union[Evaluation, None] +class Token(BaseModel): + access_token: str + token_type: str + class ReportList(BaseModel): reports: List[Report] @@ -78,3 +82,9 @@ class LintDiagnostic(BaseModel): class LintReport(BaseModel): report: List[LintDiagnostic] + +class AuthReturnV1(BaseModel): + auth_user_id: int + +class AuthReturnV2(BaseModel): + token: str diff --git a/tests/authentication/login_test.py b/tests/authentication/login_test.py index a764556..5f87d24 100644 --- a/tests/authentication/login_test.py +++ b/tests/authentication/login_test.py @@ -1,5 +1,5 @@ from tests.server_calls import auth_login_v2, auth_register_v2, clear_v1 - +from time import sleep """ ============================================================== AUTH_LOGIN_V1 TESTS @@ -9,19 +9,21 @@ # Succesful login def test_login_success(): clear_v1() - # Register and login functions should return same id for same user + # Register and login functions should return different tokens reg_return_value = auth_register_v2("test@test.com", "password") + # 1 second sleep to allow for a time difference between generating tokens + sleep(1) login_return_value = auth_login_v2("test@test.com", "password") print(login_return_value) - assert reg_return_value["auth_user_id"] == login_return_value["auth_user_id"] + assert reg_return_value["token"] != login_return_value["access_token"] def test_login_multiple_success(): clear_v1() - reg_return_value_1 = auth_register_v2("test@test.com", "password")["auth_user_id"] + reg_return_value_1 = auth_register_v2("test@test.com", "password")["token"] # First user registered and logged in assert reg_return_value_1 - reg_return_value_2 = auth_register_v2("test1@test.com", "password")["auth_user_id"] + reg_return_value_2 = auth_register_v2("test1@test.com", "password")["token"] # Second user registered and logged in assert reg_return_value_2 @@ -31,19 +33,19 @@ def test_login_multiple_success(): def test_login_incorrect_email(): clear_v1() auth_register_v2("test@test.com", "password") - assert auth_login_v2("test2@test.com", "password")['detail'] == "Invalid input: No user with email test2@test.com." + assert auth_login_v2("test2@test.com", "password")['detail'] == "Invalid input: Incorrect email or password." def test_login_incorrect_email_and_password(): clear_v1() auth_register_v2("test@test.com", "password") - assert auth_login_v2("test@test.com", "efef")['detail'] == "Invalid input: Incorrect password." + assert auth_login_v2("test@test.com", "efef")['detail'] == "Invalid input: Incorrect email or password." def test_login_incorrect_password(): clear_v1() # Password is incorrect auth_register_v2("test@test.com", "password") - assert auth_login_v2("test@test.com", "eeffef")['detail'] == "Invalid input: Incorrect password." + assert auth_login_v2("test@test.com", "eeffef")['detail'] == "Invalid input: Incorrect email or password." # Password is incorrect (and empty) auth_register_v2("test1@test.com", "password") - assert auth_login_v2("test1@test.com", "")['detail'] == "Invalid input: Incorrect password." \ No newline at end of file + assert auth_login_v2("test1@test.com", "")['detail'][0]['msg'] == "field required" \ No newline at end of file diff --git a/tests/authentication/register_test.py b/tests/authentication/register_test.py index 99b7d4a..b5b431c 100644 --- a/tests/authentication/register_test.py +++ b/tests/authentication/register_test.py @@ -9,21 +9,21 @@ # Test single registers with valid emails def test_register_unique_id_valid(): clear_v1() - auth_user1 = auth_register_v2("test@test.com", "luciddreams14") - auth_user2 = auth_register_v2("test1@test.com", "luciddreams14") - print(auth_user1) - # Testing if user ID is unique - assert auth_user1["auth_user_id"] != auth_user2["auth_user_id"] - assert len(auth_user1) == 1 + token1 = auth_register_v2("test@test.com", "luciddreams14") + token2 = auth_register_v2("test1@test.com", "luciddreams14") + print(token1) + # Testing if tokens are unique + assert token1["token"] != token2["token"] + assert len(token1) == 1 # Test multiple registers def test_register_multiple_success(): clear_v1() - auth_user1 =auth_register_v2("test@test.com", "www.www")["auth_user_id"] - auth_user2 =auth_register_v2("test1@test.com", "lisbon2424")["auth_user_id"] - auth_user3 =auth_register_v2("test2@test.com", "janedoe")["auth_user_id"] - auth_user4 =auth_register_v2("test3@test.com", "knittingislife")["auth_user_id"] - assert auth_user1 != auth_user2 != auth_user3 != auth_user4 + token1 =auth_register_v2("test@test.com", "www.www")["token"] + token2 =auth_register_v2("test1@test.com", "lisbon2424")["token"] + token3 =auth_register_v2("test2@test.com", "janedoe")["token"] + token4 =auth_register_v2("test3@test.com", "knittingislife")["token"] + assert token1 != token2 != token3 != token4 # Test Input errors for invalid email - failing regex match def test_register_invalid_email(): diff --git a/tests/bulk/bulk_export_test.py b/tests/bulk/bulk_export_test.py index ce30432..c0274ab 100644 --- a/tests/bulk/bulk_export_test.py +++ b/tests/bulk/bulk_export_test.py @@ -11,21 +11,21 @@ def test_bulk_export_valid(): data = VALID_INVOICE_TEXT - invoice_valid = TextInvoice(name="My Invoice", source="text", text=data) + invoice_valid = TextInvoice(name="My Invoice", text=data) data = invalidate_invoice(VALID_INVOICE_TEXT, "tag", "cac:BillingReference", "", "cac:BillingReferencee", 1) data = invalidate_invoice(data, "tag", "cac:BillingReference", "", "cac:BillingReferencee", 1) - invoice_schema = TextInvoice(name="My Invoice", source="text", text=data) + invoice_schema = TextInvoice(name="My Invoice", text=data) data = invalidate_invoice(VALID_INVOICE_TEXT, 'content', 'cbc:EndpointID', '', 'Not an ABN', 1) - invoice_peppol = TextInvoice(name="My Invoice", source="text", text=data) + invoice_peppol = TextInvoice(name="My Invoice", text=data) data = invalidate_invoice(VALID_INVOICE_TEXT, 'attrib', 'cbc:Amount', 'currencyID', 'TEST', 1) - invoice_syntax = TextInvoice(name="My Invoice", source="text", text=data) + invoice_syntax = TextInvoice(name="My Invoice", text=data) data = replace_part_of_string(VALID_INVOICE_TEXT, 2025, 2027, "id") - invoice_wellformedness = TextInvoice(name="My Invoice", source="text", text=data) + invoice_wellformedness = TextInvoice(name="My Invoice", text=data) report_ids = [] for invoice in [invoice_valid, invoice_schema, invoice_peppol, invoice_syntax, invoice_wellformedness]: diff --git a/tests/bulk/bulk_upload_test.py b/tests/bulk/bulk_upload_test.py index 84bb0e5..de0cded 100644 --- a/tests/bulk/bulk_upload_test.py +++ b/tests/bulk/bulk_upload_test.py @@ -11,21 +11,21 @@ # def test_bulk_upload_valid(): # data = VALID_INVOICE_TEXT -# invoice_valid = TextInvoice(name="My Invoice", source="text", text=data) +# invoice_valid = TextInvoice(name="My Invoice", text=data) # data = invalidate_invoice(VALID_INVOICE_TEXT, "tag", "cac:BillingReference", "", "cac:BillingReferencee", 1) # data = invalidate_invoice(data, "tag", "cac:BillingReference", "", "cac:BillingReferencee", 1) -# invoice_schema = TextInvoice(name="My Invoice", source="text", text=data) +# invoice_schema = TextInvoice(name="My Invoice", text=data) # data = invalidate_invoice(VALID_INVOICE_TEXT, 'content', 'cbc:EndpointID', '', 'Not an ABN', 1) -# invoice_peppol = TextInvoice(name="My Invoice", source="text", text=data) +# invoice_peppol = TextInvoice(name="My Invoice", text=data) # data = invalidate_invoice(VALID_INVOICE_TEXT, 'attrib', 'cbc:Amount', 'currencyID', 'TEST', 1) -# invoice_syntax = TextInvoice(name="My Invoice", source="text", text=data) +# invoice_syntax = TextInvoice(name="My Invoice", text=data) # data = replace_part_of_string(VALID_INVOICE_TEXT, 2025, 2027, "id") -# invoice_wellformedness = TextInvoice(name="My Invoice", source="text", text=data) +# invoice_wellformedness = TextInvoice(name="My Invoice", text=data) # invoices = [invoice_valid, invoice_schema, invoice_peppol, invoice_syntax, invoice_wellformedness] diff --git a/tests/export/html_report_test.py b/tests/export/html_report_test.py index e9397ac..ec1ab47 100644 --- a/tests/export/html_report_test.py +++ b/tests/export/html_report_test.py @@ -12,7 +12,7 @@ # Testing that the report was generated properly and matches input data def test_html_valid_invoice(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) + invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT) report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] report_bytes = export_html_report_v1(report_id) @@ -22,7 +22,7 @@ def test_html_valid_invoice(): def test_html_text_invalid_peppol_invoice(): data = invalidate_invoice(VALID_INVOICE_TEXT, 'content', 'cbc:EndpointID', '', 'Not an ABN', 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] report_bytes = export_html_report_v1(report_id) @@ -32,7 +32,7 @@ def test_html_text_invalid_peppol_invoice(): def test_html_text_invalid_wellformedness_invoice(): data = replace_part_of_string(VALID_INVOICE_TEXT, 2025, 2027, "id") - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] report_bytes = export_html_report_v1(report_id) diff --git a/tests/export/json_report_test.py b/tests/export/json_report_test.py index 9ddb6f8..128ffa8 100644 --- a/tests/export/json_report_test.py +++ b/tests/export/json_report_test.py @@ -13,7 +13,7 @@ def test_json_valid_invoice(): data = VALID_INVOICE_TEXT - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] report = export_json_report_v1(report_id) @@ -70,8 +70,8 @@ def test_json_unique_id(): data = VALID_INVOICE_TEXT # Creating 2 invoices - invoice1 = TextInvoice(name="Invoice01", source="text", text=data) - invoice2 = TextInvoice(name="Invoice02", source="text", text=data) + invoice1 = TextInvoice(name="Invoice01", text=data) + invoice2 = TextInvoice(name="Invoice02", text=data) # Creating 2 reports report_id1 = invoice_upload_text_v1(invoice1.name, invoice1.text)["report_id"] @@ -98,7 +98,7 @@ def test_json_single_violation(): # Invalidating the currency code data = invalidate_invoice(data, "attrib", "cbc:Amount", "currencyID", "TEST", 1) - invoice = TextInvoice(name="Invoice Test", source="text", text=data) + invoice = TextInvoice(name="Invoice Test", text=data) report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] report = export_json_report_v1(report_id) @@ -150,7 +150,7 @@ def test_json_multiple_violations_same_rule(): data = invalidate_invoice(data, "content", "cbc:EndpointID", "", "Not an ABN 1", 1) data = invalidate_invoice(data, "content", "cbc:EndpointID", "", "Not an ABN 2", 2) - invoice = TextInvoice(name="Invoice Test", source="text", text=data) + invoice = TextInvoice(name="Invoice Test", text=data) report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] report = export_json_report_v1(report_id) @@ -202,7 +202,7 @@ def test_json_multiple_violations_different_rules(): data = invalidate_invoice(data, 'content', 'cbc:IdentificationCode', '', 'TEST', 1) data = invalidate_invoice(data, 'content', 'cbc:IdentificationCode', '', 'TEST', 2) - invoice = TextInvoice(name="Invoice Test", source="text", text=data) + invoice = TextInvoice(name="Invoice Test", text=data) report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] report = export_json_report_v1(report_id) @@ -250,7 +250,7 @@ def test_json_invalid_wellformedness(): # Removing a closing tag data = remove_part_of_string(data, 11530, 11540) - invoice = TextInvoice(name="Invoice Test", source="text", text=data) + invoice = TextInvoice(name="Invoice Test", text=data) report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] report = export_json_report_v1(report_id) diff --git a/tests/export/pdf_report_test.py b/tests/export/pdf_report_test.py index 095a082..919173e 100644 --- a/tests/export/pdf_report_test.py +++ b/tests/export/pdf_report_test.py @@ -10,7 +10,7 @@ """ # Testing that the report was generated properly and matches input data def test_pdf_valid_invoice(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) + invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT) report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] report_bytes = export_pdf_report_v1(report_id) diff --git a/tests/invoice/upload_text_test.py b/tests/invoice/upload_text_test.py index da899dc..52429a2 100644 --- a/tests/invoice/upload_text_test.py +++ b/tests/invoice/upload_text_test.py @@ -10,7 +10,7 @@ """ def test_upload_text_valid_invoice(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) + invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT) response = invoice_upload_text_v1(invoice.name, invoice.text) assert response['report_id'] >= 0 @@ -19,7 +19,7 @@ def test_upload_text_invalid_invoice(): # Invalidating the ABN, changing the content of the ABN data = invalidate_invoice(data, 'content', 'cbc:EndpointID', '', 'Not an ABN', 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) response = invoice_upload_text_v1(invoice.name, invoice.text) assert response['report_id'] >= 0 diff --git a/tests/report/change_name_test.py b/tests/report/change_name_test.py index da0c100..889b744 100644 --- a/tests/report/change_name_test.py +++ b/tests/report/change_name_test.py @@ -1,71 +1,81 @@ from src.type_structure import * -from tests.server_calls import report_change_name_invalid_token_v2, report_change_name_v2, export_json_report_v1, invoice_upload_text_v1 +from tests.server_calls import report_change_name_v2, export_json_report_v2, invoice_upload_text_v2, auth_register_v2, clear_v1 from tests.constants import VALID_INVOICE_TEXT -from tests.helpers import invalidate_invoice, remove_part_of_string """ ===================================== /report/change_name/v1 TESTS ===================================== """ - # Testing that the report was generated properly and matches input data -def test_change_name(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) - report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] +def test_change_name_valid(): + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT) + report_id = invoice_upload_text_v2(token, invoice.name, invoice.text)["report_id"] - report = Report(**export_json_report_v1(report_id)) + report = Report(**export_json_report_v2(token, report_id)) # Checking for the old name of the invoice assert report.invoice_name == "My Invoice" - report_change_name_v2(report_id, "New Name") - report = Report(**export_json_report_v1(report_id)) + report_change_name_v2(token, report_id, "New Name") + report = Report(**export_json_report_v2(token, report_id)) # Checking for the new name of the invoice assert report.invoice_name == "New Name" -def test_change_name_long_invalid(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) - report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] - - report = Report(**export_json_report_v1(report_id)) +def test_change_name_valid_upload_invalid_token(): + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + report_id = invoice_upload_text_v2(token, "invoice", VALID_INVOICE_TEXT)["report_id"] - # Checking for the old name of the invoice - assert report.invoice_name == "My Invoice" + assert report_change_name_v2("invalid", report_id, "New Name")['detail'] == "Invalid token, please login/register" + +def test_change_name_not_owner(): + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + token2 = AuthReturnV2(**auth_register_v2("test1@gmail.com", "abc123")).token + report_id = invoice_upload_text_v2(token, "invoice", VALID_INVOICE_TEXT)["report_id"] - new_name = "hellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohello" - assert report_change_name_v2(report_id, new_name)['detail'] == "New name is longer than 100 characters" + assert report_change_name_v2(token2, report_id, "New Name")['detail'] == "You do not have permission to rename this report" -def test_change_name_invalid_token(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) - report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] +def test_change_name_long_invalid(): + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT) + report_id = invoice_upload_text_v2(token, invoice.name, invoice.text)["report_id"] - report = Report(**export_json_report_v1(report_id)) + report = Report(**export_json_report_v2(token, report_id)) # Checking for the old name of the invoice assert report.invoice_name == "My Invoice" - - assert report_change_name_invalid_token_v2(report_id, "New Name")['detail'] == "Only admins can change the names of reports at the moment" + + new_name = "hellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohello" + assert report_change_name_v2(token, report_id, new_name)['detail'] == "New name is longer than 100 characters" def test_change_name_invalid_report_id_negative(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) - report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT) + report_id = invoice_upload_text_v2(token, invoice.name, invoice.text)["report_id"] - report = Report(**export_json_report_v1(report_id)) + report = Report(**export_json_report_v2(token, report_id)) # Checking for the old name of the invoice assert report.invoice_name == "My Invoice" - assert report_change_name_v2(-1, "New Name")['detail'] == "Report id cannot be less than 0" + assert report_change_name_v2(token, -1, "New Name")['detail'] == "Report id cannot be less than 0" def test_change_name_invalid_report_id_not_found(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) - report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT) + report_id = invoice_upload_text_v2(token, invoice.name, invoice.text)["report_id"] - report = Report(**export_json_report_v1(report_id)) + report = Report(**export_json_report_v2(token, report_id)) # Checking for the old name of the invoice assert report.invoice_name == "My Invoice" - assert report_change_name_v2(2937293, "New Name")['detail'] == "Report with id 2937293 not found" + assert report_change_name_v2(token, 2937293, "New Name")['detail'] == "Report with id 2937293 not found" diff --git a/tests/report/delete_test.py b/tests/report/delete_test.py index ef2866e..92632b7 100644 --- a/tests/report/delete_test.py +++ b/tests/report/delete_test.py @@ -1,7 +1,6 @@ from src.type_structure import * -from tests.server_calls import report_delete_invalid_token_v2, report_delete_v2, export_json_report_v1, invoice_upload_text_v1 +from tests.server_calls import report_delete_v2, export_json_report_v1, invoice_upload_text_v2, auth_register_v2, clear_v1 from tests.constants import VALID_INVOICE_TEXT -from tests.helpers import invalidate_invoice, remove_part_of_string """ ===================================== @@ -9,21 +8,40 @@ ===================================== """ -def test_delete(): - report_id = invoice_upload_text_v1("invoice", VALID_INVOICE_TEXT)["report_id"] +def test_delete_valid(): + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + report_id = invoice_upload_text_v2(token, "invoice", VALID_INVOICE_TEXT)["report_id"] - report_delete_v2(report_id) + report_delete_v2(token, report_id) assert export_json_report_v1(report_id)["detail"] == f"Report with id {report_id} not found" def test_delete_invalid_report_id_negative(): - assert report_delete_v2(-1)['detail'] == "Report id cannot be less than 0" + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + assert report_delete_v2(token, -1)['detail'] == "Report id cannot be less than 0" def test_delete_invalid_report_id_not_found(): - assert report_delete_v2(2937293)['detail'] == "Report with id 2937293 not found" + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + assert report_delete_v2(token, 2937293)['detail'] == "Report with id 2937293 not found" -def test_delete_invalid_token(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) - report_id = invoice_upload_text_v1(invoice.name, invoice.text)["report_id"] +def test_delete_invalid_upload(): + clear_v1() + assert invoice_upload_text_v2("INVALID", "invoice", VALID_INVOICE_TEXT)['detail'] == "Invalid token, please login/register" - assert report_delete_invalid_token_v2(report_id)['detail'] == "Only admins can change the names of reports at the moment" +def test_delete_valid_upload_invalid_token(): + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + report_id = invoice_upload_text_v2(token, "invoice", VALID_INVOICE_TEXT)["report_id"] + + assert report_delete_v2("invalid", report_id)['detail'] == "Invalid token, please login/register" + +def test_delete_invalid_not_owner(): + clear_v1() + token = AuthReturnV2(**auth_register_v2("test@gmail.com", "abc123")).token + token2 = AuthReturnV2(**auth_register_v2("test1@gmail.com", "abc123")).token + report_id = invoice_upload_text_v2(token, "invoice", VALID_INVOICE_TEXT)["report_id"] + + assert report_delete_v2(token2, report_id)['detail'] == "You do not have permission to delete this report" diff --git a/tests/report/list_all_test.py b/tests/report/list_all_test.py index 7b50c3c..12a4db5 100644 --- a/tests/report/list_all_test.py +++ b/tests/report/list_all_test.py @@ -10,7 +10,7 @@ """ def test_list_all_one_report(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) + invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT) invoice_upload_text_v1(invoice.name, invoice.text) report_ids = report_list_all_v1()["report_ids"] @@ -23,7 +23,7 @@ def test_list_all_one_report(): def test_list_all_many_reports(): - invoice = TextInvoice(name="My Invoice", source="text", text=VALID_INVOICE_TEXT) + invoice = TextInvoice(name="My Invoice", text=VALID_INVOICE_TEXT) invoice_upload_text_v1(invoice.name, invoice.text) invoice_upload_text_v1(invoice.name, invoice.text) invoice_upload_text_v1(invoice.name, invoice.text) diff --git a/tests/report/peppol_test.py b/tests/report/peppol_test.py index a80ef5d..aefc89c 100644 --- a/tests/report/peppol_test.py +++ b/tests/report/peppol_test.py @@ -13,7 +13,7 @@ def test_peppol_valid_invoice(): data = VALID_INVOICE_TEXT - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) peppol_evaluation = report_peppol_v1(invoice) peppol_evaluation = Evaluation(**peppol_evaluation) @@ -35,7 +35,7 @@ def test_peppol_single_violation(): # Invalidating the ABN, changing the content of the ABN data = invalidate_invoice(data, 'content', 'cbc:EndpointID', '', 'Not an ABN', 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) peppol_evaluation = report_peppol_v1(invoice) peppol_evaluation = Evaluation(**peppol_evaluation) @@ -76,7 +76,7 @@ def test_peppol_multiple_violations_same_rule(): data = invalidate_invoice(data, 'content', 'cbc:EndpointID', '', 'Not an ABN 1', 1) data = invalidate_invoice(data, 'content', 'cbc:EndpointID', '', 'Not an ABN 2', 2) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) peppol_evaluation = report_peppol_v1(invoice) peppol_evaluation = Evaluation(**peppol_evaluation) @@ -111,7 +111,7 @@ def test_peppol_multiple_violations_different_rules(): data = invalidate_invoice(data, 'content', 'cbc:IssueDate', '', 'bad-date', 1) data = invalidate_invoice(data, 'content', 'cbc:IssueDate', '', 'bad-date', 2) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) peppol_evaluation = report_peppol_v1(invoice) peppol_evaluation = Evaluation(**peppol_evaluation) @@ -149,7 +149,7 @@ def test_peppol_warning_doesnt_invalidate_report(): # Invalidating the ABN data = invalidate_invoice(data, 'content', 'cbc:EndpointID', '', 'Not an ABN 1', 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) peppol_evaluation = report_peppol_v1(invoice) peppol_evaluation = Evaluation(**peppol_evaluation) @@ -168,7 +168,7 @@ def test_peppol_fatal_error_invalidates_report(): # Changing the start date year to 2029 data = invalidate_invoice(data, 'content', 'cbc:IssueDate', '', 'bad-date', 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) peppol_evaluation = report_peppol_v1(invoice) peppol_evaluation = Evaluation(**peppol_evaluation) diff --git a/tests/report/schema_test.py b/tests/report/schema_test.py index 0785ea7..cfc6012 100644 --- a/tests/report/schema_test.py +++ b/tests/report/schema_test.py @@ -12,7 +12,7 @@ def test_schema_valid(): # Replacing the tags but making sure they are valid data = VALID_INVOICE_TEXT - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) schema_evaluation = report_schema_v1(invoice) schema_evaluation = Evaluation(**schema_evaluation) @@ -31,7 +31,7 @@ def test_schema_tag_name_invalid(): data = invalidate_invoice(VALID_INVOICE_TEXT, "tag", "cac:BillingReference", "", "cac:BillingReferencee", 1) data = invalidate_invoice(data, "tag", "cac:BillingReference", "", "cac:BillingReferencee", 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) schema_evaluation = report_schema_v1(invoice) schema_evaluation = Evaluation(**schema_evaluation) @@ -61,7 +61,7 @@ def test_schema_tag_order_invalid(): # Invalidating the date data = invalidate_invoice(VALID_INVOICE_TEXT, "tag", "cbc:IssueDate", "", "cbc:DueDate", 1) data = invalidate_invoice(data, "tag", "cbc:DueDate", "", "cbc:IssueDate", 2) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) schema_evaluation = report_schema_v1(invoice) schema_evaluation = Evaluation(**schema_evaluation) @@ -90,7 +90,7 @@ def test_schema_tag_order_invalid(): def test_schema_date_type_invalid(): # Invalidating the date data = invalidate_invoice(VALID_INVOICE_TEXT, "content", "cbc:IssueDate", "", "totallyADate", 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) schema_evaluation = report_schema_v1(invoice) schema_evaluation = Evaluation(**schema_evaluation) @@ -122,7 +122,7 @@ def test_schema_tags_revalid(): data = invalidate_invoice(data, "content", "cbc:CopyIndicator", "", "true", 1) data = invalidate_invoice(data, "tag", "cbc:DueDate", "", "cbc:IssueDate", 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) schema_evaluation = report_schema_v1(invoice) schema_evaluation = Evaluation(**schema_evaluation) @@ -141,7 +141,7 @@ def test_schema_tags_multiple_errors_invalid(): # Also expects the following tag to be different data = invalidate_invoice(VALID_INVOICE_TEXT, "tag", "cbc:IssueDate", "", "cbc:CopyIndicator", 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) schema_evaluation = report_schema_v1(invoice) schema_evaluation = Evaluation(**schema_evaluation) diff --git a/tests/report/syntax_test.py b/tests/report/syntax_test.py index de5d36d..f53d087 100644 --- a/tests/report/syntax_test.py +++ b/tests/report/syntax_test.py @@ -13,7 +13,7 @@ def test_syntax_valid_invoice(): data = VALID_INVOICE_TEXT - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) syntax_evaluation = report_syntax_v1(invoice) syntax_evaluation = Evaluation(**syntax_evaluation) @@ -34,7 +34,7 @@ def test_syntax_single_violation(): # Invalidating the currency code data = invalidate_invoice(data, 'attrib', 'cbc:Amount', 'currencyID', 'TEST', 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) syntax_evaluation = report_syntax_v1(invoice) syntax_evaluation = Evaluation(**syntax_evaluation) @@ -73,7 +73,7 @@ def test_syntax_multiple_violations_same_rule(): data = invalidate_invoice(data, 'attrib', 'cbc:Amount', 'currencyID', 'TEST', 1) data = invalidate_invoice(data, 'attrib', 'cbc:Amount', 'currencyID', 'TEST', 2) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) syntax_evaluation = report_syntax_v1(invoice) syntax_evaluation = Evaluation(**syntax_evaluation) @@ -108,7 +108,7 @@ def test_syntax_multiple_violations_different_rules(): data = invalidate_invoice(data, 'content', 'cbc:IdentificationCode', '', 'TEST', 1) data = invalidate_invoice(data, 'content', 'cbc:IdentificationCode', '', 'TEST', 2) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) syntax_evaluation = report_syntax_v1(invoice) syntax_evaluation = Evaluation(**syntax_evaluation) @@ -146,7 +146,7 @@ def test_syntax_warning_doesnt_invalidate_report(): # Violates [UBL-CR-003]-A UBL invoice should not include the ProfileExecutionID data = invalidate_invoice(data, 'tag', 'cbc:Note', '', 'cbc:ProfileExecutionID', 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) syntax_evaluation = report_syntax_v1(invoice) syntax_evaluation = Evaluation(**syntax_evaluation) @@ -165,7 +165,7 @@ def test_syntax_fatal_error_invalidates_report(): # Adding "D" to the currency code to make it invalid data = invalidate_invoice(data, 'attrib', 'cbc:Amount', 'currencyID', 'TEST', 1) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) syntax_evaluation = report_syntax_v1(invoice) syntax_evaluation = Evaluation(**syntax_evaluation) diff --git a/tests/report/wellformedness_test.py b/tests/report/wellformedness_test.py index 2408f83..b5e87d6 100644 --- a/tests/report/wellformedness_test.py +++ b/tests/report/wellformedness_test.py @@ -11,7 +11,7 @@ def test_wellformedness_valid_invoice(): data = VALID_INVOICE_TEXT - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) wellformed_evaluation = report_wellformedness_v1(invoice) wellformed_evaluation = Evaluation(**wellformed_evaluation) @@ -30,7 +30,7 @@ def test_wellformedness_case_sensitive_tags_invalid(): # Invalidating the tags so that only one of the tags is capitalised data = replace_part_of_string(VALID_INVOICE_TEXT, 2025, 2027, "id") - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) wellformed_evaluation = report_wellformedness_v1(invoice) wellformed_evaluation = Evaluation(**wellformed_evaluation) @@ -60,7 +60,7 @@ def test_wellformedness_case_sensitive_tags_valid(): data = replace_part_of_string(VALID_INVOICE_TEXT, 2025, 2027, "id") data = replace_part_of_string(data, 2045, 2047, "id") - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) wellformed_evaluation = report_wellformedness_v1(invoice) wellformed_evaluation = Evaluation(**wellformed_evaluation) @@ -76,7 +76,7 @@ def test_wellformedness_case_sensitive_tags_valid(): def test_wellformedness_two_root_elements_invalid(): data = VALID_INVOICE_TEXT data = append_to_string(data, """Second root at the end""") - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) wellformed_evaluation = report_wellformedness_v1(invoice) wellformed_evaluation = Evaluation(**wellformed_evaluation) @@ -101,7 +101,7 @@ def test_wellformedness_two_root_elements_invalid(): def test_wellformedness_no_closing_tag_invalid(): data = VALID_INVOICE_TEXT data = remove_part_of_string(data, 11530, 11540) - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) wellformed_evaluation = report_wellformedness_v1(invoice) wellformed_evaluation = Evaluation(**wellformed_evaluation) @@ -128,7 +128,7 @@ def test_wellformedness_wrong_nesting_invalid(): data = VALID_INVOICE_TEXT data = remove_part_of_string(data, 11512, 11530) data = append_to_string(data, """""") - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) wellformed_evaluation = report_wellformedness_v1(invoice) wellformed_evaluation = Evaluation(**wellformed_evaluation) @@ -154,7 +154,7 @@ def test_wellformedness_wrong_nesting_invalid(): def test_wellformedness_valid_version_number_error(): data = VALID_INVOICE_TEXT - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) wellformedness_evaluation = report_wellformedness_v1(invoice) wellformedness_evaluation = Evaluation(**wellformedness_evaluation) @@ -176,7 +176,7 @@ def test_wellformedness_invalid_version_number_error(): data = VALID_INVOICE_TEXT data = replace_part_of_string(data, 15, 16, '5') - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) wellformedness_evaluation = report_wellformedness_v1(invoice) wellformedness_evaluation = Evaluation(**wellformedness_evaluation) @@ -204,7 +204,7 @@ def test_wellformedness_invalid_version_number_error(): def test_wellformedness_no_escape_for_special_char_invalid(): data = VALID_INVOICE_TEXT data = replace_part_of_string(data, 694, 695, "<") - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) wellformed_evaluation = report_wellformedness_v1(invoice) wellformed_evaluation = Evaluation(**wellformed_evaluation) @@ -234,7 +234,7 @@ def test_wellformedness_no_escape_for_special_char_invalid(): def test_wellformedness_escape_for_special_char_valid(): data = VALID_INVOICE_TEXT data = replace_part_of_string(data, 694, 695, "<") - invoice = TextInvoice(name="My Invoice", source="text", text=data) + invoice = TextInvoice(name="My Invoice", text=data) wellformed_evaluation = report_wellformedness_v1(invoice) wellformed_evaluation = Evaluation(**wellformed_evaluation) diff --git a/tests/server_calls.py b/tests/server_calls.py index fbcea67..0365225 100644 --- a/tests/server_calls.py +++ b/tests/server_calls.py @@ -28,6 +28,15 @@ def invoice_upload_text_v1(invoice_name: str, invoice_text: str) -> Server_call_ return json.loads(response.text) +def invoice_upload_text_v2(token: str, invoice_name: str, invoice_text: str) -> Server_call_return: + payload = TextInvoice(name=invoice_name, text=invoice_text).dict() + headers = { + "Authorization": "bearer " + token + } + response = requests.post(full_url + 'invoice/upload_text/v2', json=payload, headers=headers) + + return json.loads(response.text) + def invoice_bulk_upload_text_v1(invoices: List[TextInvoice]) -> Server_call_return: payload = { "invoices": [invoice.dict() for invoice in invoices] @@ -52,12 +61,32 @@ def export_json_report_v1(report_id: int) -> Server_call_return: return json.loads(response.text) +def export_json_report_v2(token: str, report_id: int) -> Server_call_return: + payload = { + "report_id": report_id + } + headers = { + "Authorization": "bearer " + token + } + response = requests.get(full_url + 'export/json_report/v2', params=payload, headers=headers) + + return json.loads(response.text) + def export_bulk_json_reports_v1(report_ids) -> Server_call_return: payload = report_ids response = requests.post(full_url + 'export/bulk_json_reports/v1', json=payload) return json.loads(response.text) +def export_bulk_json_reports_v2(token: str, report_ids) -> Server_call_return: + payload = report_ids + headers = { + "Authorization": "bearer " + token + } + response = requests.post(full_url + 'export/bulk_json_reports/v2', json=payload, headers=headers) + + return json.loads(response.text) + def export_pdf_report_v1(report_id: int): payload = { "report_id": report_id @@ -66,6 +95,17 @@ def export_pdf_report_v1(report_id: int): return response.content +def export_pdf_report_v2(token: str, report_id: int): + payload = { + "report_id": report_id + } + headers = { + "Authorization": "bearer " + token + } + response = requests.get(full_url + 'export/pdf_report/v2', params=payload, headers=headers) + + return response.content + def export_bulk_pdf_reports_v1(report_ids) -> Server_call_return: payload = { "report_ids": report_ids @@ -74,6 +114,17 @@ def export_bulk_pdf_reports_v1(report_ids) -> Server_call_return: return json.loads(response.text) +def export_bulk_pdf_reports_v2(token: str, report_ids) -> Server_call_return: + payload = { + "report_ids": report_ids + } + headers = { + "Authorization": "bearer " + token + } + response = requests.get(full_url + 'export/bulk_pdf_reports/v2', params=payload, headers=headers) + + return json.loads(response.text) + def export_html_report_v1(report_id: int): payload = { "report_id": report_id @@ -82,6 +133,17 @@ def export_html_report_v1(report_id: int): return response.content +def export_html_report_v2(token: str, report_id: int): + payload = { + "report_id": report_id + } + headers = { + "Authorization": "bearer " + token + } + response = requests.get(full_url + 'export/html_report/v2', params=payload, headers=headers) + + return response.content + def export_csv_report_v1(report_id: int): payload = { "report_id": report_id @@ -90,6 +152,17 @@ def export_csv_report_v1(report_id: int): return response.content +def export_csv_report_v2(token: str, report_id: int): + payload = { + "report_id": report_id + } + headers = { + "Authorization": "bearer " + token + } + response = requests.get(full_url + 'export/csv_report/v2', params=payload, headers=headers) + + return response.content + # Report Endpoints def report_wellformedness_v1(invoice: TextInvoice) -> Server_call_return: @@ -144,41 +217,28 @@ def report_check_validity_v1(report_id: int) -> Server_call_return: ### Other Endpoints -def report_delete_v2(report_id: int) -> Server_call_return: +def report_delete_v2(token: str, report_id: int) -> Server_call_return: payload = { "token": ADMIN_TOKEN, "report_id": report_id } - response = requests.delete(full_url + 'report/delete/v2', params=payload) - - return json.loads(response.text) - -def report_delete_invalid_token_v2(report_id: int) -> Server_call_return: - payload = { - "token": "invalidtoken", - "report_id": report_id + headers = { + "Authorization": "bearer " + token } - response = requests.delete(full_url + 'report/delete/v2', params=payload) + response = requests.delete(full_url + 'report/delete/v2', params=payload, headers=headers) return json.loads(response.text) -def report_change_name_v2(report_id: int, new_name: str) -> Server_call_return: +def report_change_name_v2(token: str, report_id: int, new_name: str) -> Server_call_return: payload = { "token": ADMIN_TOKEN, "report_id": report_id, "new_name": new_name } - response = requests.put(full_url + 'report/change_name/v2', params=payload) - - return json.loads(response.text) - -def report_change_name_invalid_token_v2(report_id: int, new_name: str) -> Server_call_return: - payload = { - "token": "invalidtoken", - "report_id": report_id, - "new_name": new_name + headers = { + "Authorization": "bearer " + token } - response = requests.put(full_url + 'report/change_name/v2', params=payload) + response = requests.put(full_url + 'report/change_name/v2', params=payload, headers=headers) return json.loads(response.text) @@ -197,17 +257,17 @@ def auth_register_v2(email: str, password: str) -> Server_call_return: "email": email, "password": password } - response = requests.get(full_url + 'auth_register/v2', params=payload) + response = requests.post(full_url + 'auth_register/v2', params=payload) return json.loads(response.text) def auth_login_v2(email: str, password: str) -> Server_call_return: payload = { - "email": email, + "username": email, "password": password } - response = requests.get(full_url + 'auth_login/v2', params=payload) + response = requests.post(full_url + 'auth_login/v2', data=payload) return json.loads(response.text)