Skip to content

Commit

Permalink
feat: make normalizeRoute receive provider via Options
Browse files Browse the repository at this point in the history
  • Loading branch information
cristovaoth committed Jan 8, 2025
1 parent 1a93b86 commit e1e8865
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
9 changes: 4 additions & 5 deletions src/execute/normalizeRoute.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@ import assert from 'assert'

import { expect, describe, it } from 'bun:test'

import { encodeMultiSend } from './multisend'
import { privateKeyToAccount } from 'viem/accounts'
import { randomHash, testClient } from '../../test/client'
import { deploySafe } from '../../test/avatar'
import { eoaSafe } from '../../test/routes'
import { AccountType, Safe } from '../types'
import { AccountType, ChainId } from '../types'
import { normalizeRoute } from './normalizeRoute'
import { Eip1193Provider } from '@safe-global/protocol-kit'

describe('normalizeRoute', () => {
it('queries and patches missing threshold in a SAFE account', async () => {
Expand Down Expand Up @@ -38,9 +36,10 @@ describe('normalizeRoute', () => {
;(route.waypoints[1].account as any).threshold = undefined
assert(route.waypoints[1].account.threshold == undefined)

route = await normalizeRoute(route, testClient as Eip1193Provider)
route = await normalizeRoute(route, {
providers: { [testClient.chain.id as ChainId]: testClient },
})
assert(route.waypoints[1].account.type == AccountType.SAFE)

expect(route.waypoints[1].account.threshold).toEqual(3)
})
})
31 changes: 17 additions & 14 deletions src/execute/normalizeRoute.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
import { Address, encodeFunctionData, getAddress, parseAbi } from 'viem'
import { type Eip1193Provider } from '@safe-global/protocol-kit'
import { Address, encodeFunctionData, parseAbi } from 'viem'

import { validatePrefixedAddress } from '../addresses'
import { splitPrefixedAddress, validatePrefixedAddress } from '../addresses'

import {
Account,
AccountType,
ChainId,
Connection,
PrefixedAddress,
Route,
StartingPoint,
Waypoint,
} from '../types'
import { getEip1193Provider, Options } from './options'

export async function normalizeRoute(
route: Route,
provider: Eip1193Provider
options?: Options
): Promise<Route> {
const waypoints = await Promise.all(
route.waypoints.map((w) => normalizeWaypoint(w, provider))
route.waypoints.map((w) => normalizeWaypoint(w, options))
)

return {
Expand All @@ -31,11 +32,11 @@ export async function normalizeRoute(

export async function normalizeWaypoint(
waypoint: StartingPoint | Waypoint,
provider: Eip1193Provider
options?: Options
): Promise<StartingPoint | Waypoint> {
waypoint = {
...waypoint,
account: await normalizeAccount(waypoint.account, provider),
account: await normalizeAccount(waypoint.account, options),
}

if ('connection' in waypoint) {
Expand All @@ -50,7 +51,7 @@ export async function normalizeWaypoint(

async function normalizeAccount(
account: Account,
provider: Eip1193Provider
options?: Options
): Promise<Account> {
account = {
...account,
Expand All @@ -62,7 +63,7 @@ async function normalizeAccount(
account.type == AccountType.SAFE &&
typeof account.threshold != 'number'
) {
account.threshold = await fetchThreshold(account.address, provider)
account.threshold = await fetchThreshold(account, options)
}

return account
Expand All @@ -86,18 +87,20 @@ function normalizePrefixedAddress(address: PrefixedAddress): PrefixedAddress {
}

async function fetchThreshold(
safe: Address,
provider: Eip1193Provider
account: Account,
options?: Options
): Promise<number> {
const abi = parseAbi(['function getThreshold() view returns (uint256)'])
const [chainId, safe] = splitPrefixedAddress(account.prefixedAddress)
const provider = getEip1193Provider({ chainId: chainId as ChainId, options })

return Number(
await provider.request({
method: 'eth_call',
params: [
{
to: getAddress(safe),
to: safe,
data: encodeFunctionData({
abi,
abi: parseAbi(['function getThreshold() view returns (uint256)']),
functionName: 'getThreshold',
args: [],
}),
Expand Down

0 comments on commit e1e8865

Please sign in to comment.