Skip to content

Commit

Permalink
Cleaning up API to look attractive to other groups (#24)
Browse files Browse the repository at this point in the history
* Making tests comply to interface changes + changing Invoice to TextInvoice for testing

* Changing url back to localhost

* Combining helpers into one file

* Adding documentation to main.py + improvements to typing / file IO

* Fixing implentation to conform to modified main.py interface and types
  • Loading branch information
jeremytraini authored Mar 15, 2023
1 parent b7a61c7 commit add3172
Show file tree
Hide file tree
Showing 31 changed files with 421 additions and 353 deletions.
2 changes: 1 addition & 1 deletion src/authentication.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src.database import Users, IntegrityError, DoesNotExist
from src.helper import string_in_range
from src.helpers import string_in_range
from src.error import InputError
from src.type_structure import *
import hashlib
Expand Down
2 changes: 1 addition & 1 deletion src/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
port = 8000

base_url = f"0.0.0.0"
base_url = f"localhost"

full_url = f"http://{base_url}:{port}/"
5 changes: 5 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@
xsltproc = proc.new_xslt30_processor()
SYNTAX_EXECUTABLE = xsltproc.compile_stylesheet(stylesheet_file="src/validation_artefacts/AUNZ-UBL-validation.xslt")
PEPPOL_EXECUTABLE = xsltproc.compile_stylesheet(stylesheet_file="src/validation_artefacts/AUNZ-PEPPOL-validation.xslt")

# The token below is a temporary token for Sprint 1.
# It will be replaced with a token that is generated by Auth endpoints.
ADMIN_TOKEN = "UG#&*GFUBIFBIUEB#&*FUB"

8 changes: 6 additions & 2 deletions src/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from peewee import *
import os
from src.constants import ADMIN_TOKEN

db = None
if 'RDS_DB_NAME' in os.environ:
Expand Down Expand Up @@ -58,7 +59,7 @@ class Reports(BaseModel):
def to_json(self):
return {
"report_id": self.id, # type: ignore
"date_generated": self.date_generated,
"date_generated": str(self.date_generated),
"invoice_name": self.invoice_name,
"invoice_text": self.invoice_text,
"invoice_hash": self.invoice_hash,
Expand Down Expand Up @@ -109,7 +110,10 @@ def create_tables():
with db:
db.create_tables(tables)

def clear_v1():
def clear_v1(token: str):
if not token == ADMIN_TOKEN:
raise Exception("Only admins can clear the database")

with db:
db.drop_tables(tables)
db.create_tables(tables)
Expand Down
33 changes: 19 additions & 14 deletions src/export.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from src.type_structure import *
from src.database import Reports
from bs4 import BeautifulSoup
from copy import copy, deepcopy
import json
from html import escape
from peewee import DoesNotExist
from weasyprint import HTML
from io import BytesIO, StringIO
from io import StringIO, BytesIO
import csv
from zipfile import ZipFile, ZIP_DEFLATED


def export_json_report_v1(report_id: int):
Expand All @@ -16,9 +15,9 @@ def export_json_report_v1(report_id: int):
except DoesNotExist:
raise Exception(f"Report with id {report_id} not found")

return report.to_json()
return Report(**report.to_json())

def export_pdf_report_v1(report_id: int):
def export_pdf_report_v1(report_id: int) -> bytes:
html = export_html_report_v1(report_id)
pdf_bytes = HTML(string=html).write_pdf()

Expand Down Expand Up @@ -161,12 +160,18 @@ def export_csv_report_v1(report_id: int):

return csv_contents

def report_bulk_export_v1(report_ids, report_format) -> List:
report_format = report_format.lower()
print("Exporting reports")
if report_format == "json":
return [export_json_report_v1(report_id) for report_id in report_ids]
elif report_format == "html":
return [export_html_report_v1(report_id) for report_id in report_ids]
else:
raise Exception("Unknown report format")
def report_bulk_export_json_v1(report_ids) -> List:
return {
"reports": [export_json_report_v1(report_id) for report_id in report_ids]
}

def report_bulk_export_pdf_v1(report_ids) -> BytesIO:
reports = BytesIO()

with ZipFile(reports, 'w', ZIP_DEFLATED) as f:
for report_id in report_ids:
f.writestr(f"invoice_validation_report_{report_id}.pdf", export_pdf_report_v1(report_id))

reports.seek(0)

return reports
20 changes: 0 additions & 20 deletions src/helper.py

This file was deleted.

24 changes: 12 additions & 12 deletions src/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from tempfile import NamedTemporaryFile
from src.type_structure import *
import requests

def create_temp_file(invoice_text: str) -> str:
tmp = NamedTemporaryFile(mode='w', delete=False)
Expand All @@ -9,16 +8,17 @@ def create_temp_file(invoice_text: str) -> str:

return tmp.name

def extract_text_from_invoice(invoice: Invoice) -> str:
if invoice.source == "url":
response = requests.get(invoice.data)
if response.status_code != 200:
raise Exception("Could not retrieve file from url")
def string_in_range(min_len:int, max_len:int, input_str:str) -> bool:
'''
This function checks if a string is within the ranges of min and max length.
data = response.text
elif invoice.source == "text":
data = invoice.data
else:
raise Exception("Invalid source, please enter url or text")
Arguments:
min_len (int) - Minimum length of string
max_len (int) - Maximum length of string
input_str (str) - Input string to check length
return data
Return Value:
Returns boolean to whether string is within range or not
'''

return len(input_str) >= min_len and len(input_str) <= max_len
19 changes: 7 additions & 12 deletions src/invoice.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from typing import Dict
from src.type_structure import *
from src.database import Reports, DoesNotExist
import requests
from src.generation import generate_report


def invoice_upload_text_v1(invoice_name: str, invoice_text: str):
report_id = generate_report(invoice_name, invoice_text)

return {
"report_id": report_id
"report_id": generate_report(invoice_name, invoice_text)
}


Expand All @@ -27,11 +24,9 @@ def invoice_upload_url_v1(invoice_name: str, invoice_url: str):
}


def invoice_upload_file_v1(invoice_name: str, invoice_file):
report_id = generate_report(invoice_name, invoice_file.decode("utf-8"))

def invoice_upload_file_v1(invoice_name: str, invoice_text: str):
return {
"report_id": report_id
"report_id": generate_report(invoice_name, invoice_text)
}

def invoice_check_validity_v1(report_id: int) -> CheckValidReturn:
Expand All @@ -42,8 +37,8 @@ def invoice_check_validity_v1(report_id: int) -> CheckValidReturn:

return CheckValidReturn(is_valid=report.is_valid)

def invoice_generate_hash_v1(invoice: Invoice) -> str:
return "hash"
def invoice_generate_hash_v1(invoice: TextInvoice) -> str:
return {}

def invoice_file_upload_bulk_v1(invoices: List[Invoice]) -> List[int]:
return [generate_report(invoice.name, invoice.data) for invoice in invoices]
def invoice_upload_bulk_text_v1(invoices: List[TextInvoice]) -> ReportIDs:
return ReportIDs(report_ids=[generate_report(invoice.name, invoice.text) for invoice in invoices])
Loading

0 comments on commit add3172

Please sign in to comment.