Skip to content

Commit

Permalink
feat:add committee calculator
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Kai <[email protected]>
  • Loading branch information
GrapeBaBa committed Oct 11, 2024
1 parent e1b7d10 commit bf7050a
Show file tree
Hide file tree
Showing 5 changed files with 382 additions and 178 deletions.
168 changes: 168 additions & 0 deletions src/consensus/helpers/committee.zig
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ const phase0 = @import("../../consensus/phase0/types.zig");
const altair = @import("../../consensus/altair/types.zig");
const epoch_helper = @import("../../consensus/helpers/epoch.zig");
const validator_helper = @import("../../consensus/helpers/validator.zig");
const shuffle_helper = @import("../../consensus/helpers/shuffle.zig");
const seed_helper = @import("../../consensus/helpers/seed.zig");

/// Calculates the committee count per slot for a given epoch
/// Returns: The number of committees per slot
Expand All @@ -34,6 +36,72 @@ pub fn getCommitteeCountPerSlot(state: *const consensus.BeaconState, epoch: prim
return @max(@as(u64, 1), @min(preset.ActivePreset.get().MAX_COMMITTEES_PER_SLOT, committees_per_slot));
}

/// computeCommittee returns the committee for the current epoch.
/// @param indices - The validator indices.
/// @param seed - The seed.
/// @param index - The index of the committee.
/// @param count - The number of committees.
/// @param allocator - The allocator.
/// @returns The committee for the current epoch.
/// Spec pseudocode definition:
/// def compute_committee(indices: Sequence[ValidatorIndex],
/// seed: Bytes32,
/// index: uint64,
/// count: uint64) -> Sequence[ValidatorIndex]:
/// """
/// Return the committee corresponding to ``indices``, ``seed``, ``index``, and committee ``count``.
/// """
/// start = (len(indices) * index) // count
/// end = (len(indices) * uint64(index + 1)) // count
/// return [indices[compute_shuffled_index(uint64(i), uint64(len(indices)), seed)] for i in range(start, end)]
/// Note: Caller is responsible for freeing the returned slice.
pub fn computeCommittee(indices: []const primitives.ValidatorIndex, seed: primitives.Bytes32, index: u64, count: u64, allocator: std.mem.Allocator) ![]primitives.ValidatorIndex {
const len = indices.len;
const start = @divFloor(len * index, count);
const end = @divFloor(len * (index + 1), count);
var result = std.ArrayList(primitives.ValidatorIndex).init(allocator);
defer result.deinit();

var i: u64 = start;
while (i < end) : (i += 1) {
const shuffled_index = try shuffle_helper.computeShuffledIndex(@as(u64, i), @as(u64, len), seed);
try result.append(indices[shuffled_index]);
}

return result.toOwnedSlice();
}

/// getBeaconCommittee returns the beacon committee for the current epoch.
/// @param state - The beacon state.
/// @param slot - The slot.
/// @param index - The index of the committee.
/// @param allocator - The allocator.
/// @returns The beacon committee for the current epoch.
/// Spec pseudocode definition:
/// def get_beacon_committee(state: BeaconState, slot: Slot, index: CommitteeIndex) -> Sequence[ValidatorIndex]:
/// """
/// Return the beacon committee at ``slot`` for ``index``.
/// """
/// epoch = compute_epoch_at_slot(slot)
/// committees_per_slot = get_committee_count_per_slot(state, epoch)
/// return compute_committee(
/// indices=get_active_validator_indices(state, epoch),
/// seed=get_seed(state, epoch, DOMAIN_BEACON_ATTESTER),
/// index=(slot % SLOTS_PER_EPOCH) * committees_per_slot + index,
/// count=committees_per_slot * SLOTS_PER_EPOCH,
/// )
/// Note: Caller is responsible for freeing the returned slice.
pub fn getBeaconCommittee(state: *const consensus.BeaconState, slot: primitives.Slot, index: primitives.CommitteeIndex, allocator: std.mem.Allocator) ![]primitives.ValidatorIndex {
const epoch = epoch_helper.computeEpochAtSlot(slot);
const committeesPerSlot = try getCommitteeCountPerSlot(state, epoch, allocator);
const indices = try validator_helper.getActiveValidatorIndices(state, epoch, allocator);
defer allocator.free(indices);
const seed = seed_helper.getSeed(state, epoch, constants.DOMAIN_BEACON_ATTESTER);
const i = @mod(slot, preset.ActivePreset.get().SLOTS_PER_EPOCH) * committeesPerSlot + index;
const count = committeesPerSlot * preset.ActivePreset.get().SLOTS_PER_EPOCH;
return computeCommittee(indices, seed, i, count, allocator);
}

test "test getCommitteeCountPerSlot" {
preset.ActivePreset.set(preset.Presets.minimal);
defer preset.ActivePreset.reset();
Expand Down Expand Up @@ -107,3 +175,103 @@ test "test getCommitteeCountPerSlot" {
count,
);
}

test "test computeCommittee" {
preset.ActivePreset.set(preset.Presets.minimal);
defer preset.ActivePreset.reset();
const indices = [_]primitives.ValidatorIndex{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
const seed = .{1} ** 32;
const index = 1;
const count = 3;
const committee = try computeCommittee(&indices, seed, index, count, std.testing.allocator);
defer std.testing.allocator.free(committee);
try std.testing.expectEqual(3, committee.len);
try std.testing.expectEqual(9, committee[0]);
try std.testing.expectEqual(0, committee[1]);
try std.testing.expectEqual(8, committee[2]);
}

test "test getBeaconCommittee" {
preset.ActivePreset.set(preset.Presets.minimal);
defer preset.ActivePreset.reset();
var finalized_checkpoint = consensus.Checkpoint{
.epoch = 5,
.root = .{0} ** 32,
};
var validators = std.ArrayList(consensus.Validator).init(std.testing.allocator);
defer validators.deinit();
const validator1 = consensus.Validator{
.pubkey = undefined,
.withdrawal_credentials = undefined,
.effective_balance = 0,
.slashed = false,
.activation_eligibility_epoch = 0,
.activation_epoch = 0,
.exit_epoch = 10,
.withdrawable_epoch = 10,
};
const validator2 = consensus.Validator{
.pubkey = undefined,
.withdrawal_credentials = undefined,
.effective_balance = 0,
.slashed = false,
.activation_eligibility_epoch = 0,
.activation_epoch = 0,
.exit_epoch = 20,
.withdrawable_epoch = 20,
};
for (0..500000) |_| {
try validators.append(validator1);
try validators.append(validator2);
}

var block_roots = std.ArrayList(primitives.Root).init(std.testing.allocator);
defer block_roots.deinit();
const block_root1 = .{0} ** 32;
const block_root2 = .{1} ** 32;
const block_root3 = .{2} ** 32;
try block_roots.append(block_root1);
try block_roots.append(block_root2);
try block_roots.append(block_root3);

var randao_mixes = try std.ArrayList(primitives.Bytes32).initCapacity(std.testing.allocator, preset.ActivePreset.get().EPOCHS_PER_HISTORICAL_VECTOR);
defer randao_mixes.deinit();
for (0..preset.ActivePreset.get().EPOCHS_PER_HISTORICAL_VECTOR) |slot_index| {
try randao_mixes.append(.{@as(u8, @intCast(slot_index))} ** 32);
}

const state = consensus.BeaconState{
.altair = altair.BeaconState{
.genesis_time = 0,
.genesis_validators_root = .{0} ** 32,
.slot = 100,
.fork = undefined,
.block_roots = block_roots.items,
.state_roots = undefined,
.historical_roots = undefined,
.eth1_data = undefined,
.eth1_data_votes = undefined,
.eth1_deposit_index = 0,
.validators = validators.items,
.balances = undefined,
.randao_mixes = randao_mixes.items,
.slashings = undefined,
.previous_epoch_attestations = undefined,
.current_epoch_attestations = undefined,
.justification_bits = undefined,
.previous_justified_checkpoint = undefined,
.current_justified_checkpoint = undefined,
.finalized_checkpoint = &finalized_checkpoint,
.latest_block_header = undefined,
.inactivity_scores = undefined,
.current_sync_committee = undefined,
.next_sync_committee = undefined,
},
};

const committee = try getBeaconCommittee(&state, 100, 1, std.testing.allocator);
defer std.testing.allocator.free(committee);
try std.testing.expectEqual(15625, committee.len);
try std.testing.expectEqual(341591, committee[0]);
try std.testing.expectEqual(554849, committee[15624]);
}
179 changes: 1 addition & 178 deletions src/consensus/helpers/seed.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const preset = @import("../../presets/preset.zig");
const phase0 = @import("../../consensus/phase0/types.zig");
const altair = @import("../../consensus/altair/types.zig");
const epoch_helper = @import("../../consensus/helpers/epoch.zig");
const shuffle_helper = @import("../../consensus/helpers/shuffle.zig");
const sha256 = std.crypto.hash.sha2.Sha256;

/// getRandaoMix returns the randao mix at the given epoch.
Expand Down Expand Up @@ -46,96 +47,6 @@ pub fn getSeed(state: *const consensus.BeaconState, epoch: primitives.Epoch, dom
return h;
}

/// computeShuffledIndex returns the shuffled index.
/// @param index - The index.
/// @param index_count - The index count.
/// @param seed - The seed.
/// @returns The shuffled index.
/// Spec pseudocode definition:
/// def compute_shuffled_index(index: uint64, index_count: uint64, seed: Bytes32) -> uint64:
/// """
/// Return the shuffled index corresponding to ``seed`` (and ``index_count``).
/// """
/// assert index < index_count
///
/// # Swap or not (https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf)
/// # See the 'generalized domain' algorithm on page 3
/// for current_round in range(SHUFFLE_ROUND_COUNT):
/// pivot = bytes_to_uint64(hash(seed + uint_to_bytes(uint8(current_round)))[0:8]) % index_count
/// flip = (pivot + index_count - index) % index_count
/// position = max(index, flip)
/// source = hash(
/// seed
/// + uint_to_bytes(uint8(current_round))
/// + uint_to_bytes(uint32(position // 256))
/// )
/// byte = uint8(source[(position % 256) // 8])
/// bit = (byte >> (position % 8)) % 2
/// index = flip if bit else index
///
/// return index
pub fn computeShuffledIndex(index: u64, index_count: u64, seed: primitives.Bytes32) !u64 {
if (index >= index_count) return error.IndexOutOfBounds;

var current_index = index;

// Perform the shuffling algorithm
for (@as(u64, 0)..preset.ActivePreset.get().SHUFFLE_ROUND_COUNT) |current_round| {
// Generate round seed
var round_seed: primitives.Bytes32 = undefined;
sha256.hash(seed ++ &[_]u8{@as(u8, @intCast(current_round))}, &round_seed, .{});

// Calculate pivot and flip
const pivot = @mod(std.mem.readInt(u64, round_seed[0..8], .little), index_count);
const flip = @mod((pivot + index_count - current_index), index_count);
const position = @max(current_index, flip);

// Generate source seed
var source_seed: primitives.Bytes32 = undefined;
const position_div_256 = @as(u32, @intCast(@divFloor(position, 256)));
sha256.hash(seed ++ &[_]u8{@as(u8, @intCast(current_round))} ++ std.mem.toBytes(position_div_256), &source_seed, .{});

// Determine bit value and update current_index
const byte_index = @divFloor(@mod(position, 256), 8);
const bit_index = @as(u3, @intCast(@mod(position, 8)));
const selected_byte = source_seed[byte_index];
const selected_bit = @mod(selected_byte >> bit_index, 2);

current_index = if (selected_bit == 1) flip else current_index;
}

return current_index;
}

pub fn computeProposerIndex(state: *const consensus.BeaconState, indices: []const primitives.ValidatorIndex, seed: primitives.Bytes32) !primitives.ValidatorIndex {
if (indices.len == 0) return error.EmptyValidatorIndices;
const MAX_RANDOM_BYTE: u8 = std.math.maxInt(u8);
var i: u64 = 0;
const total: u64 = indices.len;

while (true) {
const shuffled_index = try computeShuffledIndex(@mod(i, total), total, seed);
const candidate_index = indices[@intCast(shuffled_index)];
var hash_result: [32]u8 = undefined;
var seed_plus: [40]u8 = undefined;
@memcpy(seed_plus[0..32], &seed);
std.mem.writeInt(u64, seed_plus[32..40], @divFloor(i, 32), .little);
std.debug.print("seed_plus: {any}, i: {}\n", .{ seed_plus, i });
std.crypto.hash.sha2.Sha256.hash(&seed_plus, &hash_result, .{});
const randomByte = hash_result[@mod(i, 32)];
const effectiveBalance = state.validators()[candidate_index].effective_balance;

const max_effective_balance = switch (state.*) {
.electra => preset.ActivePreset.get().MAX_EFFECTIVE_BALANCE_ELECTRA,
else => preset.ActivePreset.get().MAX_EFFECTIVE_BALANCE,
};
if (effectiveBalance * MAX_RANDOM_BYTE >= max_effective_balance * randomByte) {
return candidate_index;
}
i += 1;
}
}

test "test get_randao_mix" {
preset.ActivePreset.set(preset.Presets.minimal);
defer preset.ActivePreset.reset();
Expand Down Expand Up @@ -275,91 +186,3 @@ test "test get_seed" {
std.crypto.hash.sha2.Sha256.hash(&expected_value, &expectedSeed, .{});
try std.testing.expectEqual(expectedSeed, seed);
}

test "test computeShuffledIndex" {
preset.ActivePreset.set(preset.Presets.minimal);
defer preset.ActivePreset.reset();
const index_count = 10;
const seed = .{3} ** 32;
const index = 5;
const shuffledIndex = try computeShuffledIndex(index, index_count, seed);
try std.testing.expectEqual(7, shuffledIndex);

const index_count1 = 10000000;
const seed1 = .{4} ** 32;
const index1 = 5776655;
const shuffledIndex1 = try computeShuffledIndex(index1, index_count1, seed1);
try std.testing.expectEqual(3446028, shuffledIndex1);
}

test "test computeProposerIndex" {
preset.ActivePreset.set(preset.Presets.minimal);
defer preset.ActivePreset.reset();
var finalized_checkpoint = consensus.Checkpoint{
.epoch = 5,
.root = .{0} ** 32,
};
var validators = std.ArrayList(consensus.Validator).init(std.testing.allocator);
defer validators.deinit();
const validator1 = consensus.Validator{
.pubkey = undefined,
.withdrawal_credentials = undefined,
.effective_balance = 12312312312,
.slashed = false,
.activation_eligibility_epoch = 0,
.activation_epoch = 0,
.exit_epoch = 0,
.withdrawable_epoch = 0,
};
try validators.append(validator1);
const validator2 = consensus.Validator{
.pubkey = undefined,
.withdrawal_credentials = undefined,
.effective_balance = 232323232332,
.slashed = false,
.activation_eligibility_epoch = 0,
.activation_epoch = 0,
.exit_epoch = 0,
.withdrawable_epoch = 0,
};
try validators.append(validator2);

var randao_mixes = try std.ArrayList(primitives.Bytes32).initCapacity(std.testing.allocator, preset.ActivePreset.get().EPOCHS_PER_HISTORICAL_VECTOR);
defer randao_mixes.deinit();
for (0..preset.ActivePreset.get().EPOCHS_PER_HISTORICAL_VECTOR) |slot_index| {
try randao_mixes.append(.{@as(u8, @intCast(slot_index))} ** 32);
}

const state = consensus.BeaconState{
.altair = altair.BeaconState{
.genesis_time = 0,
.genesis_validators_root = .{0} ** 32,
.slot = 100,
.fork = undefined,
.block_roots = undefined,
.state_roots = undefined,
.historical_roots = undefined,
.eth1_data = undefined,
.eth1_data_votes = undefined,
.eth1_deposit_index = 0,
.validators = validators.items,
.balances = undefined,
.randao_mixes = randao_mixes.items,
.slashings = undefined,
.previous_epoch_attestations = undefined,
.current_epoch_attestations = undefined,
.justification_bits = undefined,
.previous_justified_checkpoint = undefined,
.current_justified_checkpoint = undefined,
.finalized_checkpoint = &finalized_checkpoint,
.latest_block_header = undefined,
.inactivity_scores = undefined,
.current_sync_committee = undefined,
.next_sync_committee = undefined,
},
};

const validator_index = [_]primitives.ValidatorIndex{ 0, 1 };
const proposer_index = try computeProposerIndex(&state, &validator_index, .{1} ** 32);
try std.testing.expectEqual(0, proposer_index);
}
Loading

0 comments on commit bf7050a

Please sign in to comment.