diff --git a/src/Vault.sol b/src/Vault.sol index b457403..2a7c415 100644 --- a/src/Vault.sol +++ b/src/Vault.sol @@ -98,6 +98,18 @@ error LPZeroAddress(); /// @param maxYieldFeePercentage The max yield fee percentage in integer format (this value is equal to 1 in decimal format) error YieldFeePercentageGTPrecision(uint256 yieldFeePercentage, uint256 maxYieldFeePercentage); +/// @notice Emitted when the BeforeClaim prize hook fails +/// @param reason The revert reason that was thrown +error BeforeClaimPrizeFailed(bytes reason); + +/// @notice Emitted when the AfterClaim prize hook fails +/// @param reason The revert reason that was thrown +error AfterClaimPrizeFailed(bytes reason); + +// The gas to give to each of the before and after prize claim hooks. +// This should be enough gas to mint an NFT if needed. +uint256 constant HOOK_GAS = 150_000; + /** * @title PoolTogether V5 Vault * @author PoolTogether Inc Team, Generation Software Team @@ -199,6 +211,20 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable { */ event RecordedExchangeRate(uint256 exchangeRate); + /** + * @notice Emitted when a prize claim fails + * @param winner The winner of the prize + * @param tier The prize tier + * @param prizeIndex The prize index + * @param reason The revert reason that was thrown + */ + event ClaimFailed( + address indexed winner, + uint8 indexed tier, + uint32 indexed prizeIndex, + bytes reason + ); + /* ============ Variables ============ */ /// @notice Address of the TwabController used to keep track of balances. @@ -616,21 +642,28 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable { uint32[][] calldata _prizeIndices, uint96 _feePerClaim, address _feeRecipient - ) external returns (uint256) { - if (msg.sender != _claimer) revert CallerNotClaimer(msg.sender, _claimer); - + ) external onlyClaimer returns (uint256) { uint totalPrizes; for (uint w = 0; w < _winners.length; w++) { uint prizeIndicesLength = _prizeIndices[w].length; for (uint p = 0; p < prizeIndicesLength; p++) { - totalPrizes += _claimPrize( + try this.claimPrize_INTERNAL_USE_ONLY( _winners[w], _tier, _prizeIndices[w][p], _feePerClaim, _feeRecipient - ); + ) returns (uint256 prize) { + totalPrizes += prize; + } catch (bytes memory reason) { + emit ClaimFailed( + _winners[w], + _tier, + _prizeIndices[w][p], + reason + ); + } } } @@ -1038,6 +1071,24 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable { } /* ============ Claim Functions ============ */ + /** + * @notice Claim prize for a winner + * @param _winner The winner of the prize + * @param _tier The prize tier + * @param _prizeIndex The prize index + * @param _fee The fee to charge + * @param _feeRecipient The recipient of the fee + * @return The total prize amount claimed. Zero if already claimed. + */ + function claimPrize( + address _winner, + uint8 _tier, + uint32 _prizeIndex, + uint96 _fee, + address _feeRecipient + ) external onlyClaimer returns (uint256) { + return this.claimPrize_INTERNAL_USE_ONLY(_winner, _tier, _prizeIndex, _fee, _feeRecipient); + } /** * @notice Claim prize for `_winner`. @@ -1048,18 +1099,24 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable { * @param _feeRecipient Address that will receive the fee * @return uint256 The total prize amount claimed */ - function _claimPrize( + function claimPrize_INTERNAL_USE_ONLY( address _winner, uint8 _tier, uint32 _prizeIndex, uint96 _fee, address _feeRecipient - ) internal returns (uint256) { + ) external returns (uint256) { + assert(msg.sender == address(this)); + VaultHooks memory hooks = _hooks[_winner]; address recipient; if (hooks.useBeforeClaimPrize) { - recipient = hooks.implementation.beforeClaimPrize(_winner, _tier, _prizeIndex); + try hooks.implementation.beforeClaimPrize{gas: HOOK_GAS}(_winner, _tier, _prizeIndex, _fee, _feeRecipient) returns (address result) { + recipient = result; + } catch (bytes memory reason) { + revert BeforeClaimPrizeFailed(reason); + } } else { recipient = _winner; } @@ -1074,13 +1131,10 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable { ); if (hooks.useAfterClaimPrize) { - hooks.implementation.afterClaimPrize( - _winner, - _tier, - _prizeIndex, - prizeTotal - _fee, - recipient - ); + try hooks.implementation.afterClaimPrize{gas: HOOK_GAS}(_winner, _tier, _prizeIndex, prizeTotal - _fee, recipient) { + } catch (bytes memory reason) { + revert AfterClaimPrizeFailed(reason); + } } return prizeTotal; @@ -1239,4 +1293,12 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable { function _setYieldFeeRecipient(address yieldFeeRecipient_) internal { _yieldFeeRecipient = yieldFeeRecipient_; } + + /** + * @notice Requires the caller to be the claimer + */ + modifier onlyClaimer() { + if (msg.sender != _claimer) revert CallerNotClaimer(msg.sender, _claimer); + _; + } } diff --git a/src/interfaces/IVaultHooks.sol b/src/interfaces/IVaultHooks.sol index b103edb..07fba1d 100644 --- a/src/interfaces/IVaultHooks.sol +++ b/src/interfaces/IVaultHooks.sol @@ -22,7 +22,9 @@ interface IVaultHooks { function beforeClaimPrize( address winner, uint8 tier, - uint32 prizeIndex + uint32 prizeIndex, + uint96 fee, + address feeRecipient ) external returns (address); /// @notice Triggered after the prize pool claim prize function is called. diff --git a/test/unit/Vault/Vault.t.sol b/test/unit/Vault/Vault.t.sol index 7637a50..aed69b0 100644 --- a/test/unit/Vault/Vault.t.sol +++ b/test/unit/Vault/Vault.t.sol @@ -218,7 +218,7 @@ contract VaultTest is UnitBaseSetup { vm.stopPrank(); } - function testClaimPrizeClaimerNotSet() public { + function testClaimPrizesClaimerNotSet() public { vault.setClaimer(address(0)); address _randomUser = address(0xFf107770b6a31261836307218997C66c34681B5A); @@ -232,7 +232,7 @@ contract VaultTest is UnitBaseSetup { vm.stopPrank(); } - function testClaimPrizeCallerNotClaimer() public { + function testClaimPrizesCallerNotClaimer() public { vm.startPrank(alice); vm.expectRevert(abi.encodeWithSelector(CallerNotClaimer.selector, alice, claimer)); @@ -241,6 +241,25 @@ contract VaultTest is UnitBaseSetup { vm.stopPrank(); } + function testClaimPrizeCallerNotClaimer() public { + vm.startPrank(alice); + + vm.expectRevert(abi.encodeWithSelector(CallerNotClaimer.selector, alice, claimer)); + vault.claimPrize(alice, uint8(1), uint32(0), uint96(0), address(0)); + + vm.stopPrank(); + } + + function testClaimPrize_INTERNAL_USE_ONLY_notCallable() public { + vm.startPrank(address(claimer)); + + mockPrizePoolClaimPrize(uint8(1), alice, 0, 0, address(claimer)); + vm.expectRevert(); + vault.claimPrize_INTERNAL_USE_ONLY(alice, uint8(1), 0, 0, address(claimer)); + + vm.stopPrank(); + } + /* ============ Getters ============ */ function testGetTwabController() external { assertEq(vault.twabController(), address(twabController));