diff --git a/contracts/ReferralRegistry.sol b/contracts/ReferralRegistry.sol index ab53c86..32a8666 100644 --- a/contracts/ReferralRegistry.sol +++ b/contracts/ReferralRegistry.sol @@ -299,7 +299,7 @@ contract ReferralRegistry is UUPSHelper { /// @notice Gets the list of referred users /// @param key The referral key to check /// @param user The referrer - function getReferred(string calldata key, address user) external view returns (address[] memory) { + function getReferredUsers(string calldata key, address user) external view returns (address[] memory) { return keyToReferred[key][user]; } diff --git a/test/unit/ReferralRegistry.t.sol b/test/unit/ReferralRegistry.t.sol index 0391201..dc30b74 100644 --- a/test/unit/ReferralRegistry.t.sol +++ b/test/unit/ReferralRegistry.t.sol @@ -143,6 +143,50 @@ contract ReferralRegistryTest is Test { assertEq(referrer, referrerOnChain); } + function testUserSwitchesReferrer() public { + vm.prank(owner); + uint256 fee = referralRegistry.costReferralProgram(); + + referralRegistry.addReferralKey{value: fee}(referralKey, cost, requiresRefererToBeSet, owner, requiresAuthorization, paymentToken); + + string memory referrerCode1 = "referrerCode1"; + string memory referrerCode2 = "referrerCode2"; + + // Referrer 1 becomes a referrer + vm.startPrank(referrer); + IERC20(paymentToken).approve(address(referralRegistry), cost); + referralRegistry.becomeReferrer(referralKey, referrerCode1); + vm.stopPrank(); + + // Referrer 2 becomes a referrer + address referrer2 = vm.addr(5); + vm.startPrank(referrer2); + IERC20(paymentToken).approve(address(referralRegistry), cost); + referralRegistry.becomeReferrer(referralKey, referrerCode2); + vm.stopPrank(); + + // User acknowledges referrer 1 + vm.prank(user); + referralRegistry.acknowledgeReferrerByKey(referralKey, referrerCode1); + address referrerOnChain = referralRegistry.getReferrer(referralKey, user); + assertEq(referrer, referrerOnChain); + + // User switches to referrer 2 + vm.prank(user); + referralRegistry.acknowledgeReferrerByKey(referralKey, referrerCode2); + referrerOnChain = referralRegistry.getReferrer(referralKey, user); + assertEq(referrer2, referrerOnChain); + + // Assert the list of referred users for both referrers + address[] memory referredUsers1 = referralRegistry.getReferredUsers(referralKey, referrer); + address[] memory referredUsers2 = referralRegistry.getReferredUsers(referralKey, referrer2); + + assertEq(referredUsers1.length, 0); + assertEq(referredUsers2.length, 1); + assertEq(referredUsers2[0], user); + } + + function testAcknowledgeReferrerByKeyWithoutCost() public { vm.prank(owner); uint256 fee = referralRegistry.costReferralProgram();