Skip to content

Commit

Permalink
Merge pull request #570 from timcallow/flexible_snapshot_number
Browse files Browse the repository at this point in the history
Flexible snapshot number for data shuffling
  • Loading branch information
RandomDefaultUser authored Oct 7, 2024
2 parents f30f085 + 82881e2 commit 96e983d
Showing 1 changed file with 61 additions and 13 deletions.
74 changes: 61 additions & 13 deletions mala/datahandling/data_shuffler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,29 @@ def __shuffle_numpy(
)
)

# if the number of new snapshots is not a divisor of the grid size
# then we have to trim the original snapshots to size
# the indicies to be removed are selected at random
if self.data_points_to_remove is not None:
if self.parameters.shuffling_seed is not None:
np.random.seed(idx * self.parameters.shuffling_seed)
ngrid = descriptor_data[idx].shape[0]
n_descriptor = descriptor_data[idx].shape[-1]
n_target = target_data[idx].shape[-1]

current_target = target_data[idx].reshape(-1, n_target)
current_descriptor = descriptor_data[idx].reshape(
-1, n_descriptor
)

indices = np.random.choice(
ngrid**3,
size=ngrid**3 - self.data_points_to_remove[idx],
)

descriptor_data[idx] = current_descriptor[indices]
target_data[idx] = current_target[indices]

# Do the actual shuffling.
target_name_openpmd = os.path.join(
target_save_path, save_name.replace("*", "%T")
Expand Down Expand Up @@ -165,16 +188,12 @@ def __shuffle_numpy(
)
new_descriptors[
last_start : current_chunk + last_start
] = descriptor_data[j].reshape(
current_grid_size, self.input_dimension
)[
] = descriptor_data[j].reshape(-1, self.input_dimension)[
i * current_chunk : (i + 1) * current_chunk, :
]
new_targets[
last_start : current_chunk + last_start
] = target_data[j].reshape(
current_grid_size, self.output_dimension
)[
] = target_data[j].reshape(-1, self.output_dimension)[
i * current_chunk : (i + 1) * current_chunk, :
]

Expand Down Expand Up @@ -240,7 +259,6 @@ def __shuffle_numpy(
# It will be executed one after another for both of them.
# Use this class to parameterize which of both should be shuffled.
class __DescriptorOrTarget:

def __init__(
self,
save_path,
Expand All @@ -258,7 +276,6 @@ def __init__(
self.dimension = dimension

class __MockedMPIComm:

def __init__(self):
self.rank = 0
self.size = 1
Expand Down Expand Up @@ -521,6 +538,8 @@ def shuffle_snapshots(
]
number_of_data_points = np.sum(snapshot_size_list)

self.data_points_to_remove = None

if number_of_shuffled_snapshots is None:
# If the user does not tell us how many snapshots to use,
# we have to check if the number of snapshots is straightforward.
Expand Down Expand Up @@ -584,10 +603,40 @@ def shuffle_snapshots(
del specified_number_of_new_snapshots

if number_of_data_points % number_of_new_snapshots != 0:
raise Exception(
"Cannot create this number of snapshots "
"from data provided."
)
if snapshot_type == "numpy":
self.data_points_to_remove = []
for i in range(0, self.nr_snapshots):
gridsize = self.parameters.snapshot_directories_list[
i
].grid_size
shuffled_gridsize = int(
gridsize / number_of_new_snapshots
)
self.data_points_to_remove.append(
gridsize
- shuffled_gridsize * number_of_new_snapshots
)
tot_points_missing = sum(self.data_points_to_remove)

printout(
"Warning: number of requested snapshots is not a divisor of",
"the original grid sizes.\n",
f"{tot_points_missing} / {number_of_data_points} data points",
"will be left out of the shuffled snapshots."
)

shuffle_dimensions = [
int(number_of_data_points / number_of_new_snapshots),
1,
1,
]

elif snapshot_type == "openpmd":
# TODO implement arbitrary grid sizes for openpmd
raise Exception(
"Cannot create this number of snapshots "
"from data provided."
)
else:
shuffle_dimensions = [
int(number_of_data_points / number_of_new_snapshots),
Expand All @@ -606,7 +655,6 @@ def shuffle_snapshots(
permutations = []
seeds = []
for i in range(0, number_of_new_snapshots):

# This makes the shuffling deterministic, if specified by the user.
if self.parameters.shuffling_seed is not None:
np.random.seed(i * self.parameters.shuffling_seed)
Expand Down

0 comments on commit 96e983d

Please sign in to comment.