Skip to content

Commit

Permalink
Merge pull request #21 from GenerationSoftware/gen-322-c4-issue-465-h…
Browse files Browse the repository at this point in the history
…ooks

Implemented safety mechanisms around vault hooks
  • Loading branch information
asselstine authored Aug 15, 2023
2 parents 5195883 + 657230e commit 8e3f87e
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 18 deletions.
92 changes: 77 additions & 15 deletions src/Vault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ error LPZeroAddress();
/// @param maxYieldFeePercentage The max yield fee percentage in integer format (this value is equal to 1 in decimal format)
error YieldFeePercentageGTPrecision(uint256 yieldFeePercentage, uint256 maxYieldFeePercentage);

/// @notice Emitted when the BeforeClaim prize hook fails
/// @param reason The revert reason that was thrown
error BeforeClaimPrizeFailed(bytes reason);

/// @notice Emitted when the AfterClaim prize hook fails
/// @param reason The revert reason that was thrown
error AfterClaimPrizeFailed(bytes reason);

// The gas to give to each of the before and after prize claim hooks.
// This should be enough gas to mint an NFT if needed.
uint256 constant HOOK_GAS = 150_000;

/**
* @title PoolTogether V5 Vault
* @author PoolTogether Inc Team, Generation Software Team
Expand Down Expand Up @@ -199,6 +211,20 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
*/
event RecordedExchangeRate(uint256 exchangeRate);

/**
* @notice Emitted when a prize claim fails
* @param winner The winner of the prize
* @param tier The prize tier
* @param prizeIndex The prize index
* @param reason The revert reason that was thrown
*/
event ClaimFailed(
address indexed winner,
uint8 indexed tier,
uint32 indexed prizeIndex,
bytes reason
);

/* ============ Variables ============ */

/// @notice Address of the TwabController used to keep track of balances.
Expand Down Expand Up @@ -616,21 +642,28 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
uint32[][] calldata _prizeIndices,
uint96 _feePerClaim,
address _feeRecipient
) external returns (uint256) {
if (msg.sender != _claimer) revert CallerNotClaimer(msg.sender, _claimer);

) external onlyClaimer returns (uint256) {
uint totalPrizes;

for (uint w = 0; w < _winners.length; w++) {
uint prizeIndicesLength = _prizeIndices[w].length;
for (uint p = 0; p < prizeIndicesLength; p++) {
totalPrizes += _claimPrize(
try this.claimPrize_INTERNAL_USE_ONLY(
_winners[w],
_tier,
_prizeIndices[w][p],
_feePerClaim,
_feeRecipient
);
) returns (uint256 prize) {
totalPrizes += prize;
} catch (bytes memory reason) {
emit ClaimFailed(
_winners[w],
_tier,
_prizeIndices[w][p],
reason
);
}
}
}

Expand Down Expand Up @@ -1038,6 +1071,24 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
}

/* ============ Claim Functions ============ */
/**
* @notice Claim prize for a winner
* @param _winner The winner of the prize
* @param _tier The prize tier
* @param _prizeIndex The prize index
* @param _fee The fee to charge
* @param _feeRecipient The recipient of the fee
* @return The total prize amount claimed. Zero if already claimed.
*/
function claimPrize(
address _winner,
uint8 _tier,
uint32 _prizeIndex,
uint96 _fee,
address _feeRecipient
) external onlyClaimer returns (uint256) {
return this.claimPrize_INTERNAL_USE_ONLY(_winner, _tier, _prizeIndex, _fee, _feeRecipient);
}

/**
* @notice Claim prize for `_winner`.
Expand All @@ -1048,18 +1099,24 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
* @param _feeRecipient Address that will receive the fee
* @return uint256 The total prize amount claimed
*/
function _claimPrize(
function claimPrize_INTERNAL_USE_ONLY(
address _winner,
uint8 _tier,
uint32 _prizeIndex,
uint96 _fee,
address _feeRecipient
) internal returns (uint256) {
) external returns (uint256) {
assert(msg.sender == address(this));

VaultHooks memory hooks = _hooks[_winner];
address recipient;

if (hooks.useBeforeClaimPrize) {
recipient = hooks.implementation.beforeClaimPrize(_winner, _tier, _prizeIndex);
try hooks.implementation.beforeClaimPrize{gas: HOOK_GAS}(_winner, _tier, _prizeIndex, _fee, _feeRecipient) returns (address result) {
recipient = result;
} catch (bytes memory reason) {
revert BeforeClaimPrizeFailed(reason);
}
} else {
recipient = _winner;
}
Expand All @@ -1074,13 +1131,10 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
);

if (hooks.useAfterClaimPrize) {
hooks.implementation.afterClaimPrize(
_winner,
_tier,
_prizeIndex,
prizeTotal - _fee,
recipient
);
try hooks.implementation.afterClaimPrize{gas: HOOK_GAS}(_winner, _tier, _prizeIndex, prizeTotal - _fee, recipient) {
} catch (bytes memory reason) {
revert AfterClaimPrizeFailed(reason);
}
}

return prizeTotal;
Expand Down Expand Up @@ -1239,4 +1293,12 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
function _setYieldFeeRecipient(address yieldFeeRecipient_) internal {
_yieldFeeRecipient = yieldFeeRecipient_;
}

/**
* @notice Requires the caller to be the claimer
*/
modifier onlyClaimer() {
if (msg.sender != _claimer) revert CallerNotClaimer(msg.sender, _claimer);
_;
}
}
4 changes: 3 additions & 1 deletion src/interfaces/IVaultHooks.sol
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ interface IVaultHooks {
function beforeClaimPrize(
address winner,
uint8 tier,
uint32 prizeIndex
uint32 prizeIndex,
uint96 fee,
address feeRecipient
) external returns (address);

/// @notice Triggered after the prize pool claim prize function is called.
Expand Down
23 changes: 21 additions & 2 deletions test/unit/Vault/Vault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ contract VaultTest is UnitBaseSetup {
vm.stopPrank();
}

function testClaimPrizeClaimerNotSet() public {
function testClaimPrizesClaimerNotSet() public {
vault.setClaimer(address(0));

address _randomUser = address(0xFf107770b6a31261836307218997C66c34681B5A);
Expand All @@ -232,7 +232,7 @@ contract VaultTest is UnitBaseSetup {
vm.stopPrank();
}

function testClaimPrizeCallerNotClaimer() public {
function testClaimPrizesCallerNotClaimer() public {
vm.startPrank(alice);

vm.expectRevert(abi.encodeWithSelector(CallerNotClaimer.selector, alice, claimer));
Expand All @@ -241,6 +241,25 @@ contract VaultTest is UnitBaseSetup {
vm.stopPrank();
}

function testClaimPrizeCallerNotClaimer() public {
vm.startPrank(alice);

vm.expectRevert(abi.encodeWithSelector(CallerNotClaimer.selector, alice, claimer));
vault.claimPrize(alice, uint8(1), uint32(0), uint96(0), address(0));

vm.stopPrank();
}

function testClaimPrize_INTERNAL_USE_ONLY_notCallable() public {
vm.startPrank(address(claimer));

mockPrizePoolClaimPrize(uint8(1), alice, 0, 0, address(claimer));
vm.expectRevert();
vault.claimPrize_INTERNAL_USE_ONLY(alice, uint8(1), 0, 0, address(claimer));

vm.stopPrank();
}

/* ============ Getters ============ */
function testGetTwabController() external {
assertEq(vault.twabController(), address(twabController));
Expand Down

0 comments on commit 8e3f87e

Please sign in to comment.