Skip to content

Commit

Permalink
Protect endpoints from unauthorized use.
Browse files Browse the repository at this point in the history
Update pytests
  • Loading branch information
areyeslo committed Jan 16, 2025
1 parent 0ee0a75 commit 202caac
Show file tree
Hide file tree
Showing 19 changed files with 442 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,19 +235,20 @@ async def test_get_compliance_report_by_id_success(
compliance_report_base_schema,
set_mock_user,
):
with patch(
with patch(
"lcfs.web.api.compliance_report.views.ComplianceReportServices.get_compliance_report_by_id"
) as mock_get_compliance_report_by_id, patch(
"lcfs.web.api.compliance_report.views.ComplianceReportValidation.validate_organization_access"
) as mock_validate_organization_access:
) as mock_validate_organization_access, patch(
"lcfs.web.api.compliance_report.views.ComplianceReportValidation.validate_compliance_report_access"
) as mock_validate_compliance_report_access:
set_mock_user(fastapi_app, [RoleEnum.GOVERNMENT])

mock_compliance_report = ChainedComplianceReportSchema(
report=compliance_report_base_schema(), chain=[]
)

mock_compliance_report = compliance_report_base_schema()

mock_get_compliance_report_by_id.return_value = mock_compliance_report
mock_validate_organization_access.return_value = None
mock_validate_organization_access.return_value = True
mock_validate_compliance_report_access.return_value = True

url = fastapi_app.url_path_for("get_compliance_report_by_id", report_id=1)

Expand Down
45 changes: 45 additions & 0 deletions backend/lcfs/tests/fuel_export/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime
import pytest
from unittest.mock import AsyncMock, MagicMock
from lcfs.web.api.compliance_report.schema import CompliancePeriodSchema, ComplianceReportHistorySchema, ComplianceReportOrganizationSchema, ComplianceReportStatusSchema, ComplianceReportUserSchema, SummarySchema
from lcfs.web.api.fuel_export.repo import FuelExportRepository
from lcfs.web.api.fuel_code.repo import FuelCodeRepository
from lcfs.web.api.fuel_export.services import FuelExportServices
Expand Down Expand Up @@ -45,6 +47,49 @@ def mock_compliance_report_repo():
repo = AsyncMock(spec=ComplianceReportRepository)
return repo

@pytest.fixture
def compliance_period_schema():
return CompliancePeriodSchema(
compliance_period_id=1,
description="2024",
effective_date=datetime(2024, 1, 1),
expiration_date=datetime(2024, 3, 31),
display_order=1,
)

@pytest.fixture
def compliance_report_organization_schema():
return ComplianceReportOrganizationSchema(
organization_id=1, name="Acme Corporation"
)

@pytest.fixture
def summary_schema():
return SummarySchema(summary_id=1, is_locked=False)

@pytest.fixture
def compliance_report_status_schema():
return ComplianceReportStatusSchema(compliance_report_status_id=1, status="Draft")

@pytest.fixture
def compliance_report_user_schema(compliance_report_organization_schema):
return ComplianceReportUserSchema(
first_name="John",
last_name="Doe",
organization=compliance_report_organization_schema,
)

@pytest.fixture
def compliance_report_history_schema(
compliance_report_status_schema, compliance_report_user_schema
):
return ComplianceReportHistorySchema(
compliance_report_history_id=1,
compliance_report_id=1,
status=compliance_report_status_schema,
user_profile=compliance_report_user_schema,
create_date=datetime(2024, 4, 1, 12, 0, 0),
)

@pytest.fixture
def mock_fuel_code_repo():
Expand Down
23 changes: 19 additions & 4 deletions backend/lcfs/tests/fuel_export/test_fuel_exports_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder

from lcfs.tests.compliance_report.conftest import compliance_report_base_schema
from lcfs.web.api.compliance_report.schema import ChainedComplianceReportSchema
from lcfs.web.api.fuel_export.schema import (
FuelExportSchema,
FuelExportCreateUpdateSchema,
Expand Down Expand Up @@ -68,18 +70,24 @@ async def test_get_fuel_exports_invalid_payload(

@pytest.mark.anyio
async def test_get_fuel_exports_paginated_success(
client: AsyncClient, fastapi_app: FastAPI, set_mock_user
client: AsyncClient, fastapi_app: FastAPI, set_mock_user, compliance_report_base_schema
):
with patch(
"lcfs.web.api.fuel_export.views.FuelExportServices.get_fuel_exports_paginated"
) as mock_get_fuel_exports_paginated, patch(
"lcfs.web.api.fuel_export.views.ComplianceReportValidation.validate_organization_access"
) as mock_validate_organization_access:
) as mock_validate_organization_access, patch(
"lcfs.web.api.fuel_export.views.FuelExportServices.get_compliance_report_by_id"
) as mock_get_compliance_report_by_id:

mock_get_fuel_exports_paginated.return_value = FuelExportsSchema(
fuel_exports=[]
)
mock_validate_organization_access.return_value = True

mock_compliance_report = compliance_report_base_schema()

mock_get_compliance_report_by_id.return_value = mock_compliance_report
set_mock_user(fastapi_app, [RoleEnum.ANALYST])

url = fastapi_app.url_path_for("get_fuel_exports")
Expand All @@ -98,16 +106,23 @@ async def test_get_fuel_exports_paginated_success(

@pytest.mark.anyio
async def test_get_fuel_exports_list_success(
client: AsyncClient, fastapi_app: FastAPI, set_mock_user
client: AsyncClient, fastapi_app: FastAPI, set_mock_user, compliance_report_base_schema
):
with patch(
"lcfs.web.api.fuel_export.views.FuelExportServices.get_fuel_export_list"
) as mock_get_fuel_export_list, patch(
"lcfs.web.api.fuel_export.views.ComplianceReportValidation.validate_organization_access"
) as mock_validate_organization_access:
) as mock_validate_organization_access, patch(
"lcfs.web.api.fuel_export.views.FuelExportServices.get_compliance_report_by_id"
) as mock_get_compliance_report_by_id:

mock_get_fuel_export_list.return_value = FuelExportsSchema(fuel_exports=[])
mock_validate_organization_access.return_value = True

mock_compliance_report = compliance_report_base_schema()

mock_get_compliance_report_by_id.return_value = mock_compliance_report

set_mock_user(fastapi_app, [RoleEnum.ANALYST])

url = fastapi_app.url_path_for("get_fuel_exports")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from lcfs.db.base import UserTypeEnum, ActionTypeEnum
from lcfs.db.models.user.Role import RoleEnum
from lcfs.tests.compliance_report.conftest import compliance_report_base_schema
from lcfs.web.api.base import ComplianceReportRequestSchema
from lcfs.web.api.notional_transfer.schema import (
PaginatedNotionalTransferRequestSchema,
Expand Down Expand Up @@ -70,12 +71,20 @@ async def test_get_notional_transfers(
):
with patch(
"lcfs.web.api.notional_transfer.views.ComplianceReportValidation.validate_organization_access"
) as mock_validate_organization_access:
) as mock_validate_organization_access,patch(
"lcfs.web.api.notional_transfer.views.ComplianceReportValidation.validate_compliance_report_access"
) as mock_validate_compliance_report_access, patch(
"lcfs.web.api.notional_transfer.views.NotionalTransferServices.get_compliance_report_by_id"
) as mock_get_compliance_report_by_id:
set_mock_user(fastapi_app, [RoleEnum.SUPPLIER])
url = fastapi_app.url_path_for("get_notional_transfers")
payload = ComplianceReportRequestSchema(compliance_report_id=1).model_dump()

mock_validate_organization_access.return_value = True

mock_get_compliance_report_by_id.return_value = compliance_report_base_schema
mock_validate_compliance_report_access.return_value = True

mock_notional_transfer_service.get_notional_transfers.return_value = {
"notionalTransfers": []
}
Expand Down
12 changes: 11 additions & 1 deletion backend/lcfs/tests/other_uses/test_other_uses_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from lcfs.db.base import UserTypeEnum, ActionTypeEnum
from lcfs.db.models.user.Role import RoleEnum
from lcfs.tests.compliance_report.conftest import compliance_report_base_schema
from lcfs.web.api.base import ComplianceReportRequestSchema
from lcfs.web.api.other_uses.schema import (
PaginatedOtherUsesRequestSchema,
Expand Down Expand Up @@ -69,12 +70,21 @@ async def test_get_other_uses(
):
with patch(
"lcfs.web.api.other_uses.views.ComplianceReportValidation.validate_organization_access"
) as mock_validate_organization_access:
) as mock_validate_organization_access, patch(
"lcfs.web.api.notional_transfer.views.ComplianceReportValidation.validate_compliance_report_access"
) as mock_validate_compliance_report_access, patch(
"lcfs.web.api.notional_transfer.views.NotionalTransferServices.get_compliance_report_by_id"
) as mock_get_compliance_report_by_id:

set_mock_user(fastapi_app, [RoleEnum.SUPPLIER])
url = fastapi_app.url_path_for("get_other_uses")
payload = ComplianceReportRequestSchema(compliance_report_id=1).model_dump()

mock_validate_organization_access.return_value = True

mock_get_compliance_report_by_id.return_value = compliance_report_base_schema
mock_validate_compliance_report_access.return_value = True

mock_other_uses_service.get_other_uses.return_value = {"otherUses": []}

fastapi_app.dependency_overrides[OtherUsesServices] = (
Expand Down
20 changes: 19 additions & 1 deletion backend/lcfs/web/api/allocation_agreement/services.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import math
import structlog
from typing import List
from fastapi import Depends
from fastapi import Depends, HTTPException, status
from datetime import datetime

from lcfs.web.api.allocation_agreement.repo import AllocationAgreementRepository
from lcfs.web.api.compliance_report.repo import ComplianceReportRepository
from lcfs.web.core.decorators import service_handler
from lcfs.db.models.compliance.AllocationAgreement import AllocationAgreement
from lcfs.web.api.base import PaginationRequestSchema, PaginationResponseSchema
Expand Down Expand Up @@ -34,9 +35,11 @@ def __init__(
self,
repo: AllocationAgreementRepository = Depends(AllocationAgreementRepository),
fuel_repo: FuelCodeRepository = Depends(),
compliance_report_repo: ComplianceReportRepository = Depends(),
) -> None:
self.repo = repo
self.fuel_repo = fuel_repo
self.compliance_report_repo = compliance_report_repo

async def convert_to_model(
self, allocation_agreement: AllocationAgreementCreateSchema
Expand Down Expand Up @@ -350,3 +353,18 @@ async def create_allocation_agreement(
async def delete_allocation_agreement(self, allocation_agreement_id: int) -> str:
"""Delete an Allocation agreement"""
return await self.repo.delete_allocation_agreement(allocation_agreement_id)

@service_handler
async def get_compliance_report_by_id(self, compliance_report_id: int):
"""Get compliance report by period with status"""
compliance_report = await self.compliance_report_repo.get_compliance_report_by_id(
compliance_report_id,
)

if not compliance_report:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Compliance report not found for this period"
)

return compliance_report
30 changes: 26 additions & 4 deletions backend/lcfs/web/api/allocation_agreement/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi import (
APIRouter,
Body,
HTTPException,
status,
Request,
Response,
Expand Down Expand Up @@ -69,10 +70,31 @@ async def get_allocation_agreements(
report_validate: ComplianceReportValidation = Depends(),
):
"""Endpoint to get list of allocation agreements for a compliance report"""
await report_validate.validate_organization_access(
request_data.compliance_report_id
)
return await service.get_allocation_agreements(request_data.compliance_report_id)
try:
compliance_report_id = request_data.compliance_report_id

compliance_report = await service.get_compliance_report_by_id(compliance_report_id)
if not compliance_report:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Compliance report not found"
)

await report_validate.validate_compliance_report_access(compliance_report)
await report_validate.validate_organization_access(
request_data.compliance_report_id
)
return await service.get_allocation_agreements(request_data.compliance_report_id)
except HTTPException as http_ex:
# Re-raise HTTP exceptions to preserve status code and message
raise http_ex
except Exception as e:
# Log and handle unexpected errors
logger.exception("Error occurred", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred while processing your request"
)


@router.post(
Expand Down
15 changes: 15 additions & 0 deletions backend/lcfs/web/api/compliance_report/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import Depends, HTTPException, Request
from lcfs.db.models.user.Role import RoleEnum
from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum
from lcfs.web.api.compliance_report.repo import ComplianceReportRepository
from fastapi import status
from lcfs.web.api.role.schema import user_has_roles
Expand Down Expand Up @@ -41,3 +42,17 @@ async def validate_organization_access(self, compliance_report_id: int):
)

return compliance_report

async def validate_compliance_report_access(self, compliance_report):
"""Validates government user access to draft reports"""
is_government = user_has_roles(self.request.user, [RoleEnum.GOVERNMENT])

if compliance_report:
status_enum = ComplianceReportStatusEnum(compliance_report.current_status.status)
is_draft = status_enum == ComplianceReportStatusEnum.Draft

if is_government and is_draft:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Government users cannot access draft compliance reports"
)
3 changes: 2 additions & 1 deletion backend/lcfs/web/api/compliance_report/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ async def get_compliance_report_by_id(
service: ComplianceReportServices = Depends(),
validate: ComplianceReportValidation = Depends(),
) -> ChainedComplianceReportSchema:
await validate.validate_organization_access(report_id)
compliance_report = await validate.validate_organization_access(report_id)
await validate.validate_compliance_report_access(compliance_report)

mask_statuses = not user_has_roles(request.user, [RoleEnum.GOVERNMENT])

Expand Down
Loading

0 comments on commit 202caac

Please sign in to comment.