Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented safety mechanisms around vault hooks #21

Merged
merged 2 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 52 additions & 15 deletions src/Vault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ error LPZeroAddress();
/// @param maxYieldFeePercentage The max yield fee percentage in integer format (this value is equal to 1 in decimal format)
error YieldFeePercentageGTPrecision(uint256 yieldFeePercentage, uint256 maxYieldFeePercentage);

error BeforeClaimPrizeFailed(bytes reason);

error AfterClaimPrizeFailed(bytes reason);

uint256 constant HOOK_GAS = 150_000;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Natspec docs.


/**
* @title PoolTogether V5 Vault
* @author PoolTogether Inc Team, Generation Software Team
Expand Down Expand Up @@ -199,6 +205,13 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
*/
event RecordedExchangeRate(uint256 exchangeRate);

event ClaimFailed(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Natspec doc.

address indexed winner,
uint8 indexed tier,
uint32 indexed prizeIndex,
bytes reason
);

/* ============ Variables ============ */

/// @notice Address of the TwabController used to keep track of balances.
Expand Down Expand Up @@ -616,21 +629,28 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
uint32[][] calldata _prizeIndices,
uint96 _feePerClaim,
address _feeRecipient
) external returns (uint256) {
if (msg.sender != _claimer) revert CallerNotClaimer(msg.sender, _claimer);

) external onlyClaimer returns (uint256) {
uint totalPrizes;

for (uint w = 0; w < _winners.length; w++) {
uint prizeIndicesLength = _prizeIndices[w].length;
for (uint p = 0; p < prizeIndicesLength; p++) {
totalPrizes += _claimPrize(
try this.claimPrize_INTERNAL_USE_ONLY(
_winners[w],
_tier,
_prizeIndices[w][p],
_feePerClaim,
_feeRecipient
);
) returns (uint256 prize) {
totalPrizes += prize;
} catch (bytes memory reason) {
emit ClaimFailed(
_winners[w],
_tier,
_prizeIndices[w][p],
reason
);
}
}
}

Expand Down Expand Up @@ -1038,6 +1058,15 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
}

/* ============ Claim Functions ============ */
function claimPrize(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Natspec doc.

address _winner,
uint8 _tier,
uint32 _prizeIndex,
uint96 _fee,
address _feeRecipient
) external onlyClaimer returns (uint256) {
return this.claimPrize_INTERNAL_USE_ONLY(_winner, _tier, _prizeIndex, _fee, _feeRecipient);
}

/**
* @notice Claim prize for `_winner`.
Expand All @@ -1048,18 +1077,24 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
* @param _feeRecipient Address that will receive the fee
* @return uint256 The total prize amount claimed
*/
function _claimPrize(
function claimPrize_INTERNAL_USE_ONLY(
address _winner,
uint8 _tier,
uint32 _prizeIndex,
uint96 _fee,
address _feeRecipient
) internal returns (uint256) {
) external returns (uint256) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be internal no? Or even private.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has to be external because of the try catch. Kind of a weird trick

assert(msg.sender == address(this));

VaultHooks memory hooks = _hooks[_winner];
address recipient;

if (hooks.useBeforeClaimPrize) {
recipient = hooks.implementation.beforeClaimPrize(_winner, _tier, _prizeIndex);
try hooks.implementation.beforeClaimPrize{gas: HOOK_GAS}(_winner, _tier, _prizeIndex, _fee, _feeRecipient) returns (address result) {
recipient = result;
} catch (bytes memory reason) {
revert BeforeClaimPrizeFailed(reason);
}
} else {
recipient = _winner;
}
Expand All @@ -1074,13 +1109,10 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
);

if (hooks.useAfterClaimPrize) {
hooks.implementation.afterClaimPrize(
_winner,
_tier,
_prizeIndex,
prizeTotal - _fee,
recipient
);
try hooks.implementation.afterClaimPrize{gas: HOOK_GAS}(_winner, _tier, _prizeIndex, prizeTotal - _fee, recipient) {
} catch (bytes memory reason) {
revert AfterClaimPrizeFailed(reason);
}
}

return prizeTotal;
Expand Down Expand Up @@ -1239,4 +1271,9 @@ contract Vault is ERC4626, ERC20Permit, ILiquidationSource, Ownable {
function _setYieldFeeRecipient(address yieldFeeRecipient_) internal {
_yieldFeeRecipient = yieldFeeRecipient_;
}

modifier onlyClaimer() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Natspec doc.

if (msg.sender != _claimer) revert CallerNotClaimer(msg.sender, _claimer);
_;
}
}
4 changes: 3 additions & 1 deletion src/interfaces/IVaultHooks.sol
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ interface IVaultHooks {
function beforeClaimPrize(
address winner,
uint8 tier,
uint32 prizeIndex
uint32 prizeIndex,
uint96 fee,
address feeRecipient
) external returns (address);

/// @notice Triggered after the prize pool claim prize function is called.
Expand Down
23 changes: 21 additions & 2 deletions test/unit/Vault/Vault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ contract VaultTest is UnitBaseSetup {
vm.stopPrank();
}

function testClaimPrizeClaimerNotSet() public {
function testClaimPrizesClaimerNotSet() public {
vault.setClaimer(address(0));

address _randomUser = address(0xFf107770b6a31261836307218997C66c34681B5A);
Expand All @@ -232,7 +232,7 @@ contract VaultTest is UnitBaseSetup {
vm.stopPrank();
}

function testClaimPrizeCallerNotClaimer() public {
function testClaimPrizesCallerNotClaimer() public {
vm.startPrank(alice);

vm.expectRevert(abi.encodeWithSelector(CallerNotClaimer.selector, alice, claimer));
Expand All @@ -241,6 +241,25 @@ contract VaultTest is UnitBaseSetup {
vm.stopPrank();
}

function testClaimPrizeCallerNotClaimer() public {
vm.startPrank(alice);

vm.expectRevert(abi.encodeWithSelector(CallerNotClaimer.selector, alice, claimer));
vault.claimPrize(alice, uint8(1), uint32(0), uint96(0), address(0));

vm.stopPrank();
}

function testClaimPrize_INTERNAL_USE_ONLY_notCallable() public {
vm.startPrank(address(claimer));

mockPrizePoolClaimPrize(uint8(1), alice, 0, 0, address(claimer));
vm.expectRevert();
vault.claimPrize_INTERNAL_USE_ONLY(alice, uint8(1), 0, 0, address(claimer));

vm.stopPrank();
}

/* ============ Getters ============ */
function testGetTwabController() external {
assertEq(vault.twabController(), address(twabController));
Expand Down
Loading