diff --git a/contracts/ReferralRegistry.sol b/contracts/ReferralRegistry.sol index 2db7f75..32a8666 100644 --- a/contracts/ReferralRegistry.sol +++ b/contracts/ReferralRegistry.sol @@ -57,6 +57,9 @@ contract ReferralRegistry is UUPSHelper { /// @notice Mapping to store user to referrer relationships mapping(string => mapping(address => address)) public keyToUserToReferrer; + /// @notice Mapping to list referred + mapping(string => mapping(address => address[])) public keyToReferred; + /// @notice Adds a new referral key to the list /// @param key The referral key to add /// @param _cost The cost of the referral program @@ -161,10 +164,22 @@ contract ReferralRegistry is UUPSHelper { /// @param key The referral key for which the user is acknowledging the referrer /// @param referrer The address of the referrer function acknowledgeReferrer(string calldata key, address referrer) public { + if (keyToUserToReferrer[key][msg.sender] != address(0)) { + address previousReferrer = keyToUserToReferrer[key][msg.sender]; + address[] storage previousListOfReferred = keyToReferred[key][previousReferrer]; + for (uint256 i = 0; i < previousListOfReferred.length; i++) { + if (previousListOfReferred[i] == msg.sender) { + previousListOfReferred[i] = previousListOfReferred[previousListOfReferred.length - 1]; + previousListOfReferred.pop(); + break; + } + } + } if (referralPrograms[key].requiresRefererToBeSet) { require(refererStatus[key][referrer] == ReferralStatus.Set, "Referrer has not created a referral link"); } keyToUserToReferrer[key][msg.sender] = referrer; + keyToReferred[key][referrer].push(msg.sender); emit ReferrerAcknowledged(key, msg.sender, referrer); } @@ -281,6 +296,13 @@ contract ReferralRegistry is UUPSHelper { return keyToUserToReferrer[key][user]; } + /// @notice Gets the list of referred users + /// @param key The referral key to check + /// @param user The referrer + function getReferredUsers(string calldata key, address user) external view returns (address[] memory) { + return keyToReferred[key][user]; + } + /// @notice Gets the cost of a referral for a specific key /// @param key The referral key to check /// @return The cost of the referral for the given key diff --git a/test/unit/ReferralRegistry.t.sol b/test/unit/ReferralRegistry.t.sol index 0391201..dc30b74 100644 --- a/test/unit/ReferralRegistry.t.sol +++ b/test/unit/ReferralRegistry.t.sol @@ -143,6 +143,50 @@ contract ReferralRegistryTest is Test { assertEq(referrer, referrerOnChain); } + function testUserSwitchesReferrer() public { + vm.prank(owner); + uint256 fee = referralRegistry.costReferralProgram(); + + referralRegistry.addReferralKey{value: fee}(referralKey, cost, requiresRefererToBeSet, owner, requiresAuthorization, paymentToken); + + string memory referrerCode1 = "referrerCode1"; + string memory referrerCode2 = "referrerCode2"; + + // Referrer 1 becomes a referrer + vm.startPrank(referrer); + IERC20(paymentToken).approve(address(referralRegistry), cost); + referralRegistry.becomeReferrer(referralKey, referrerCode1); + vm.stopPrank(); + + // Referrer 2 becomes a referrer + address referrer2 = vm.addr(5); + vm.startPrank(referrer2); + IERC20(paymentToken).approve(address(referralRegistry), cost); + referralRegistry.becomeReferrer(referralKey, referrerCode2); + vm.stopPrank(); + + // User acknowledges referrer 1 + vm.prank(user); + referralRegistry.acknowledgeReferrerByKey(referralKey, referrerCode1); + address referrerOnChain = referralRegistry.getReferrer(referralKey, user); + assertEq(referrer, referrerOnChain); + + // User switches to referrer 2 + vm.prank(user); + referralRegistry.acknowledgeReferrerByKey(referralKey, referrerCode2); + referrerOnChain = referralRegistry.getReferrer(referralKey, user); + assertEq(referrer2, referrerOnChain); + + // Assert the list of referred users for both referrers + address[] memory referredUsers1 = referralRegistry.getReferredUsers(referralKey, referrer); + address[] memory referredUsers2 = referralRegistry.getReferredUsers(referralKey, referrer2); + + assertEq(referredUsers1.length, 0); + assertEq(referredUsers2.length, 1); + assertEq(referredUsers2[0], user); + } + + function testAcknowledgeReferrerByKeyWithoutCost() public { vm.prank(owner); uint256 fee = referralRegistry.costReferralProgram();