diff --git a/app/lib/ppom/ppom-util.test.ts b/app/lib/ppom/ppom-util.test.ts index 5179288d0a9..1046d63bf93 100644 --- a/app/lib/ppom/ppom-util.test.ts +++ b/app/lib/ppom/ppom-util.test.ts @@ -4,10 +4,8 @@ import * as TransactionActions from '../../actions/transaction'; // eslint-disab import * as NetworkControllerSelectors from '../../selectors/networkController'; // eslint-disable-line import/no-namespace import Engine from '../../core/Engine'; import PPOMUtil from './ppom-util'; -import { - isSecurityAlertsAPIEnabled, - validateWithSecurityAlertsAPI, -} from './security-alerts-api'; +// eslint-disable-next-line import/no-namespace +import * as securityAlertAPI from './security-alerts-api'; const CHAIN_ID_MOCK = '0x1'; @@ -90,10 +88,17 @@ const mockSignatureRequest = { describe('PPOM Utils', () => { const validateWithSecurityAlertsAPIMock = jest.mocked( - validateWithSecurityAlertsAPI, + securityAlertAPI.validateWithSecurityAlertsAPI, ); - const isSecurityAlertsEnabledMock = jest.mocked(isSecurityAlertsAPIEnabled); + const isSecurityAlertsEnabledMock = jest.mocked( + securityAlertAPI.isSecurityAlertsAPIEnabled, + ); + + const getSupportedChainIdsMock = jest.spyOn( + securityAlertAPI, + 'getSecurityAlertsAPISupportedChainIds', + ); const normalizeTransactionParamsMock = jest.mocked( normalizeTransactionParams, @@ -275,6 +280,7 @@ describe('PPOM Utils', () => { it('uses security alerts API if enabled', async () => { isSecurityAlertsEnabledMock.mockReturnValue(true); + getSupportedChainIdsMock.mockResolvedValue([CHAIN_ID_MOCK]); await PPOMUtil.validateRequest(mockRequest, CHAIN_ID_MOCK); @@ -289,6 +295,7 @@ describe('PPOM Utils', () => { it('uses controller if security alerts API throws', async () => { isSecurityAlertsEnabledMock.mockReturnValue(true); + getSupportedChainIdsMock.mockResolvedValue([CHAIN_ID_MOCK]); validateWithSecurityAlertsAPIMock.mockRejectedValue( new Error('Test Error'), @@ -306,5 +313,17 @@ describe('PPOM Utils', () => { mockRequest, ); }); + + it('validates correctly if security alerts API throws', async () => { + const spy = jest.spyOn( + TransactionActions, + 'setTransactionSecurityAlertResponse', + ); + jest + .spyOn(securityAlertAPI, 'getSecurityAlertsAPISupportedChainIds') + .mockRejectedValue(new Error('Test Error')); + await PPOMUtil.validateRequest(mockRequest, CHAIN_ID_MOCK); + expect(spy).toBeCalledTimes(2); + }); }); }); diff --git a/app/lib/ppom/ppom-util.ts b/app/lib/ppom/ppom-util.ts index 5b35790797b..8415622c3c2 100644 --- a/app/lib/ppom/ppom-util.ts +++ b/app/lib/ppom/ppom-util.ts @@ -18,10 +18,13 @@ import { import { WALLET_CONNECT_ORIGIN } from '../../util/walletconnect'; import AppConstants from '../../core/AppConstants'; import { + getSecurityAlertsAPISupportedChainIds, isSecurityAlertsAPIEnabled, validateWithSecurityAlertsAPI, } from './security-alerts-api'; import { PPOMController } from '@metamask/ppom-validator'; +import { Hex } from '@metamask/utils'; +import { BLOCKAID_SUPPORTED_CHAIN_IDS } from '../../util/networks'; export interface PPOMRequest { method: string; @@ -64,8 +67,14 @@ async function validateRequest(req: PPOMRequest, transactionId?: string) { const chainId = NetworkController.state.providerConfig.chainId; const isConfirmationMethod = CONFIRMATION_METHODS.includes(req.method); - - if (!ppomController || !isBlockaidFeatureEnabled() || !isConfirmationMethod) { + const isSupportedChain = await isChainSupported(chainId); + + if ( + !ppomController || + !isBlockaidFeatureEnabled() || + !isConfirmationMethod || + !isSupportedChain + ) { return; } @@ -124,6 +133,22 @@ async function validateRequest(req: PPOMRequest, transactionId?: string) { } } +async function isChainSupported(chainId: Hex): Promise { + let supportedChainIds = BLOCKAID_SUPPORTED_CHAIN_IDS; + + try { + if (isSecurityAlertsAPIEnabled()) { + supportedChainIds = await getSecurityAlertsAPISupportedChainIds(); + } + } catch (e) { + Logger.log( + `Error fetching supported chains from security alerts API: ${e}`, + ); + } + + return supportedChainIds.includes(chainId); +} + async function validateWithController( ppomController: PPOMController, request: PPOMRequest, diff --git a/app/lib/ppom/security-alerts-api.test.ts b/app/lib/ppom/security-alerts-api.test.ts index 2c0dc92b825..2ca15d22052 100644 --- a/app/lib/ppom/security-alerts-api.test.ts +++ b/app/lib/ppom/security-alerts-api.test.ts @@ -2,7 +2,10 @@ import { Reason, ResultType, } from '../../components/Views/confirmations/components/BlockaidBanner/BlockaidBanner.types'; -import { validateWithSecurityAlertsAPI } from './security-alerts-api'; +import { + getSecurityAlertsAPISupportedChainIds, + validateWithSecurityAlertsAPI, +} from './security-alerts-api'; const CHAIN_ID_MOCK = '0x1'; @@ -66,4 +69,31 @@ describe('Security Alerts API', () => { ); }); }); + + describe('getSecurityAlertsAPISupportedChainIds', () => { + it('sends GET request', async () => { + const SUPPORTED_CHAIN_IDS_MOCK = ['0x1', '0x2']; + fetchMock.mockResolvedValue({ + ok: true, + json: async () => SUPPORTED_CHAIN_IDS_MOCK, + }); + const response = await getSecurityAlertsAPISupportedChainIds(); + + expect(response).toEqual(SUPPORTED_CHAIN_IDS_MOCK); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(fetchMock).toHaveBeenCalledWith( + `https://example.com/supportedChains`, + undefined, + ); + }); + + it('throws an error if response is not ok', async () => { + fetchMock.mockResolvedValue({ ok: false, status: 404 }); + + await expect(getSecurityAlertsAPISupportedChainIds()).rejects.toThrow( + 'Security alerts API request failed with status: 404', + ); + }); + }); }); diff --git a/app/lib/ppom/security-alerts-api.ts b/app/lib/ppom/security-alerts-api.ts index fb06d928cb0..94737e28b8d 100644 --- a/app/lib/ppom/security-alerts-api.ts +++ b/app/lib/ppom/security-alerts-api.ts @@ -1,6 +1,8 @@ +import { Hex } from '@metamask/utils'; import { SecurityAlertResponse } from '../../components/Views/confirmations/components/BlockaidBanner/BlockaidBanner.types'; const ENDPOINT_VALIDATE = 'validate'; +const ENDPOINT_SUPPORTED_CHAINS = 'supportedChains'; export interface SecurityAlertsAPIRequest { method: string; @@ -13,22 +15,26 @@ export function isSecurityAlertsAPIEnabled() { export async function validateWithSecurityAlertsAPI( chainId: string, - request: SecurityAlertsAPIRequest, + body: SecurityAlertsAPIRequest, ): Promise { const endpoint = `${ENDPOINT_VALIDATE}/${chainId}`; - return postRequest(endpoint, request); -} - -async function postRequest(endpoint: string, body: unknown) { - const url = getUrl(endpoint); - - const response = await fetch(url, { + return request(endpoint, { method: 'POST', body: JSON.stringify(body), headers: { 'Content-Type': 'application/json', }, }); +} + +export async function getSecurityAlertsAPISupportedChainIds(): Promise { + return request(ENDPOINT_SUPPORTED_CHAINS); +} + +async function request(endpoint: string, options?: RequestInit) { + const url = getUrl(endpoint); + + const response = await fetch(url, options); if (!response.ok) { throw new Error(