Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented safety mechanisms around vault hooks #21

Merged
merged 2 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 77 additions & 15 deletions src/Vault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ error LPZeroAddress();
/// @param maxYieldFeePercentage The max yield fee percentage in integer format (this value is equal to 1 in decimal format)
error YieldFeePercentageGTPrecision(uint256 yieldFeePercentage, uint256 maxYieldFeePercentage);

/// @notice Emitted when the BeforeClaim prize hook fails
/// @param reason The revert reason that was thrown
error BeforeClaimPrizeFailed(bytes reason);

/// @notice Emitted when the AfterClaim prize hook fails
/// @param reason The revert reason that was thrown
error AfterClaimPrizeFailed(bytes reason);

// The gas to give to each of the before and after prize claim hooks.
// This should be enough gas to mint an NFT if needed.
uint256 constant HOOK_GAS = 150_000;

/**
* @title PoolTogether V5 Vault
* @author PoolTogether Inc Team, Generation Software Team
Expand Down Expand Up @@ -199,6 +211,20 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
*/
event RecordedExchangeRate(uint256 exchangeRate);

/**
* @notice Emitted when a prize claim fails
* @param winner The winner of the prize
* @param tier The prize tier
* @param prizeIndex The prize index
* @param reason The revert reason that was thrown
*/
event ClaimFailed(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Natspec doc.

address indexed winner,
uint8 indexed tier,
uint32 indexed prizeIndex,
bytes reason
);

/* ============ Variables ============ */

/// @notice Address of the TwabController used to keep track of balances.
Expand Down Expand Up @@ -616,21 +642,28 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
uint32[][] calldata _prizeIndices,
uint96 _feePerClaim,
address _feeRecipient
) external returns (uint256) {
if (msg.sender != _claimer) revert CallerNotClaimer(msg.sender, _claimer);

) external onlyClaimer returns (uint256) {
uint totalPrizes;

for (uint w = 0; w < _winners.length; w++) {
uint prizeIndicesLength = _prizeIndices[w].length;
for (uint p = 0; p < prizeIndicesLength; p++) {
totalPrizes += _claimPrize(
try this.claimPrize_INTERNAL_USE_ONLY(
_winners[w],
_tier,
_prizeIndices[w][p],
_feePerClaim,
_feeRecipient
);
) returns (uint256 prize) {
totalPrizes += prize;
} catch (bytes memory reason) {
emit ClaimFailed(
_winners[w],
_tier,
_prizeIndices[w][p],
reason
);
}
}
}

Expand Down Expand Up @@ -1038,6 +1071,24 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
}

/* ============ Claim Functions ============ */
/**
* @notice Claim prize for a winner
* @param _winner The winner of the prize
* @param _tier The prize tier
* @param _prizeIndex The prize index
* @param _fee The fee to charge
* @param _feeRecipient The recipient of the fee
* @return The total prize amount claimed. Zero if already claimed.
*/
function claimPrize(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Natspec doc.

address _winner,
uint8 _tier,
uint32 _prizeIndex,
uint96 _fee,
address _feeRecipient
) external onlyClaimer returns (uint256) {
return this.claimPrize_INTERNAL_USE_ONLY(_winner, _tier, _prizeIndex, _fee, _feeRecipient);
}

/**
* @notice Claim prize for `_winner`.
Expand All @@ -1048,18 +1099,24 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
* @param _feeRecipient Address that will receive the fee
* @return uint256 The total prize amount claimed
*/
function _claimPrize(
function claimPrize_INTERNAL_USE_ONLY(
address _winner,
uint8 _tier,
uint32 _prizeIndex,
uint96 _fee,
address _feeRecipient
) internal returns (uint256) {
) external returns (uint256) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be internal no? Or even private.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has to be external because of the try catch. Kind of a weird trick

assert(msg.sender == address(this));

VaultHooks memory hooks = _hooks[_winner];
address recipient;

if (hooks.useBeforeClaimPrize) {
recipient = hooks.implementation.beforeClaimPrize(_winner, _tier, _prizeIndex);
try hooks.implementation.beforeClaimPrize{gas: HOOK_GAS}(_winner, _tier, _prizeIndex, _fee, _feeRecipient) returns (address result) {
recipient = result;
} catch (bytes memory reason) {
revert BeforeClaimPrizeFailed(reason);
}
} else {
recipient = _winner;
}
Expand All @@ -1074,13 +1131,10 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
);

if (hooks.useAfterClaimPrize) {
hooks.implementation.afterClaimPrize(
_winner,
_tier,
_prizeIndex,
prizeTotal - _fee,
recipient
);
try hooks.implementation.afterClaimPrize{gas: HOOK_GAS}(_winner, _tier, _prizeIndex, prizeTotal - _fee, recipient) {
} catch (bytes memory reason) {
revert AfterClaimPrizeFailed(reason);
}
}

return prizeTotal;
Expand Down Expand Up @@ -1239,4 +1293,12 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
function _setYieldFeeRecipient(address yieldFeeRecipient_) internal {
_yieldFeeRecipient = yieldFeeRecipient_;
}

/**
* @notice Requires the caller to be the claimer
*/
modifier onlyClaimer() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Natspec doc.

if (msg.sender != _claimer) revert CallerNotClaimer(msg.sender, _claimer);
_;
}
}
4 changes: 3 additions & 1 deletion src/interfaces/IVaultHooks.sol
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ interface IVaultHooks {
function beforeClaimPrize(
address winner,
uint8 tier,
uint32 prizeIndex
uint32 prizeIndex,
uint96 fee,
address feeRecipient
) external returns (address);

/// @notice Triggered after the prize pool claim prize function is called.
Expand Down
23 changes: 21 additions & 2 deletions test/unit/Vault/Vault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ contract VaultTest is UnitBaseSetup {
vm.stopPrank();
}

function testClaimPrizeClaimerNotSet() public {
function testClaimPrizesClaimerNotSet() public {
vault.setClaimer(address(0));

address _randomUser = address(0xFf107770b6a31261836307218997C66c34681B5A);
Expand All @@ -232,7 +232,7 @@ contract VaultTest is UnitBaseSetup {
vm.stopPrank();
}

function testClaimPrizeCallerNotClaimer() public {
function testClaimPrizesCallerNotClaimer() public {
vm.startPrank(alice);

vm.expectRevert(abi.encodeWithSelector(CallerNotClaimer.selector, alice, claimer));
Expand All @@ -241,6 +241,25 @@ contract VaultTest is UnitBaseSetup {
vm.stopPrank();
}

function testClaimPrizeCallerNotClaimer() public {
vm.startPrank(alice);

vm.expectRevert(abi.encodeWithSelector(CallerNotClaimer.selector, alice, claimer));
vault.claimPrize(alice, uint8(1), uint32(0), uint96(0), address(0));

vm.stopPrank();
}

function testClaimPrize_INTERNAL_USE_ONLY_notCallable() public {
vm.startPrank(address(claimer));

mockPrizePoolClaimPrize(uint8(1), alice, 0, 0, address(claimer));
vm.expectRevert();
vault.claimPrize_INTERNAL_USE_ONLY(alice, uint8(1), 0, 0, address(claimer));

vm.stopPrank();
}

/* ============ Getters ============ */
function testGetTwabController() external {
assertEq(vault.twabController(), address(twabController));
Expand Down
Loading