Skip to content

Commit

Permalink
Refactor slice & dset_3d scripts with common structure
Browse files Browse the repository at this point in the history
The two scripts now have nearly identical CLI and structure
  • Loading branch information
bcaddy committed Oct 30, 2023
1 parent 266a749 commit 1fa2342
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 169 deletions.
216 changes: 122 additions & 94 deletions python_scripts/cat_dset_3D.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
Python script for concatenating 3D hdf5 datasets. Includes a CLI for concatenating Cholla HDF5 datasets and can be
imported into other scripts where the `concat_3d` function can be used to concatenate the datasets.
imported into other scripts where the `concat_3d_field` function can be used to concatenate the datasets.
Generally the easiest way to import this script is to add the `python_scripts` directory to your python path in your
script like this:
Expand All @@ -18,85 +18,10 @@
import pathlib

# ======================================================================================================================
def main():
"""This function handles the CLI argument parsing and is only intended to be used when this script is invoked from the
command line. If you're importing this file then use the `concat_3d` or `concat_3d_single` functions directly.
"""
# Argument handling
cli = argparse.ArgumentParser()
# Required Arguments
cli.add_argument('-s', '--start_num', type=int, required=True, help='The first output step to concatenate')
cli.add_argument('-e', '--end_num', type=int, required=True, help='The last output step to concatenate')
cli.add_argument('-n', '--num_processes', type=int, required=True, help='The number of processes that were used')
# Optional Arguments
cli.add_argument('-i', '--input_dir', type=pathlib.Path, default=pathlib.Path.cwd(), help='The input directory.')
cli.add_argument('-o', '--output_dir', type=pathlib.Path, default=pathlib.Path.cwd(), help='The output directory.')
cli.add_argument('--skip-fields', type=list, default=[], help='List of fields to skip concatenating. Defaults to empty.')
cli.add_argument('--dtype', type=str, default=None, help='The data type of the output datasets. Accepts most numpy types. Defaults to the same as the input datasets.')
cli.add_argument('--compression-type', type=str, default=None, help='What kind of compression to use on the output data. Defaults to None.')
cli.add_argument('--compression-opts', type=str, default=None, help='What compression settings to use if compressing. Defaults to None.')
args = cli.parse_args()

# Perform the concatenation
concat_3d(start_num=args.start_num,
end_num=args.end_num,
num_processes=args.num_processes,
input_dir=args.input_dir,
output_dir=args.output_dir,
skip_fields=args.skip_fields,
destination_dtype=args.dtype,
compression_type=args.compression_type,
compression_options=args.compression_opts)
# ======================================================================================================================

# ======================================================================================================================
def concat_3d(start_num: int,
end_num: int,
num_processes: int,
input_dir: pathlib.Path = pathlib.Path.cwd(),
output_dir: pathlib.Path = pathlib.Path.cwd(),
skip_fields: list = [],
destination_dtype: np.dtype = None,
compression_type: str = None,
compression_options: str = None):
"""Concatenate 3D HDF5 Cholla datasets. i.e. take the single files generated per process and concatenate them into a
single, large file. All outputs from start_num to end_num will be concatenated.
Args:
start_num (int): The first output step to concatenate
end_num (int): The last output step to concatenate
num_processes (int): The number of processes that were used
input_dir (pathlib.Path, optional): The input directory. Defaults to pathlib.Path.cwd().
output_dir (pathlib.Path, optional): The output directory. Defaults to pathlib.Path.cwd().
skip_fields (list, optional): List of fields to skip concatenating. Defaults to [].
destination_dtype (np.dtype, optional): The data type of the output datasets. Accepts most numpy types. Defaults to the same as the input datasets.
compression_type (str, optional): What kind of compression to use on the output data. Defaults to None.
compression_options (str, optional): What compression settings to use if compressing. Defaults to None.
"""

# Error checking
assert start_num >= 0, 'start_num must be greater than or equal to 0'
assert end_num >= 0, 'end_num must be greater than or equal to 0'
assert start_num <= end_num, 'end_num should be greater than or equal to start_num'
assert num_processes > 1, 'num_processes must be greater than 1'

# loop over outputs
for n in range(start_num, end_num+1):
concat_3d_single(output_number=n,
num_processes=num_processes,
input_dir=input_dir,
output_dir=output_dir,
skip_fields=skip_fields,
destination_dtype=destination_dtype,
compression_type=compression_type,
compression_options=compression_options)
# ======================================================================================================================

# ======================================================================================================================
def concat_3d_single(output_number: int,
def concat_3d_output(source_directory: pathlib.Path,
output_directory: pathlib.Path,
num_processes: int,
input_dir: pathlib.Path = pathlib.Path.cwd(),
output_dir: pathlib.Path = pathlib.Path.cwd(),
output_number: int,
skip_fields: list = [],
destination_dtype: np.dtype = None,
compression_type: str = None,
Expand All @@ -120,13 +45,13 @@ def concat_3d_single(output_number: int,
assert num_processes > 1, 'num_processes must be greater than 1'
assert output_number >= 0, 'output_number must be greater than or equal to 0'

# open the output file for writing (don't overwrite if exists)
fileout = h5py.File(output_dir / f'{output_number}.h5', 'a')
# open the output file for writing (fail if it exists)
destination_file = h5py.File(output_directory / f'{output_number}.h5', 'w-')

# Setup the output file
with h5py.File(input_dir / f'{output_number}.h5.0', 'r') as source_file:
with h5py.File(source_directory / f'{output_number}.h5.0', 'r') as source_file:
# Copy header data
fileout = copy_header(source_file, fileout)
destination_file = copy_header(source_file, destination_file)

# Create the datasets in the output file
datasets_to_copy = list(source_file.keys())
Expand All @@ -137,7 +62,11 @@ def concat_3d_single(output_number: int,

data_shape = source_file.attrs['dims']

fileout.create_dataset(name=dataset,
if dataset == 'magnetic_x': data_shape[0] += 1
if dataset == 'magnetic_y': data_shape[1] += 1
if dataset == 'magnetic_z': data_shape[2] += 1

destination_file.create_dataset(name=dataset,
shape=data_shape,
dtype=dtype,
compression=compression_type,
Expand All @@ -146,20 +75,29 @@ def concat_3d_single(output_number: int,
# loop over files for a given output
for i in range(0, num_processes):
# open the input file for reading
filein = h5py.File(input_dir / f'{output_number}.h5.{i}', 'r')
# read in the header data from the input file
head = filein.attrs
source_file = h5py.File(source_directory / f'{output_number}.h5.{i}', 'r')

# write data from individual processor file to correct location in concatenated file
nx_local, ny_local, nz_local = filein.attrs['dims_local']
x_start, y_start, z_start = filein.attrs['offset']
# Compute the offset slicing
nx_local, ny_local, nz_local = source_file.attrs['dims_local']
x_start, y_start, z_start = source_file.attrs['offset']
x_end, y_end, z_end = x_start+nx_local, y_start+ny_local, z_start+nz_local

# write data from individual processor file to correct location in concatenated file
for dataset in datasets_to_copy:
fileout[dataset][x_start:x_start+nx_local, y_start:y_start+ny_local,z_start:z_start+nz_local] = filein[dataset]
magnetic_offset = [0,0,0]
if dataset == 'magnetic_x': magnetic_offset[0] = 1
if dataset == 'magnetic_y': magnetic_offset[1] = 1
if dataset == 'magnetic_z': magnetic_offset[2] = 1

destination_file[dataset][x_start:x_end+magnetic_offset[0],
y_start:y_end+magnetic_offset[1],
z_start:z_end+magnetic_offset[2]] = source_file[dataset]

filein.close()
# Now that the copy is done we close the source file
source_file.close()

fileout.close()
# Close destination file now that it is fully constructed
destination_file.close()
# ======================================================================================================================

# ==============================================================================
Expand All @@ -182,5 +120,95 @@ def copy_header(source_file: h5py.File, destination_file: h5py.File):
return destination_file
# ==============================================================================

# ==============================================================================
def common_cli() -> argparse.ArgumentParser:
"""This function provides the basis for the common CLI amongst the various concatenation scripts. It returns an
`argparse.ArgumentParser` object to which additional arguments can be passed before the final `.parse_args()` method
is used.
"""

# ============================================================================
# Function used to parse the `--concat-output` argument
def concat_output(raw_argument: str) -> list:
# Check if the string is empty
if len(raw_argument) < 1:
raise ValueError('The --concat-output argument must not be of length zero.')

# Strip unneeded characters
cleaned_argument = raw_argument.replace(' ', '')
cleaned_argument = cleaned_argument.replace('[', '')
cleaned_argument = cleaned_argument.replace(']', '')

# Check that it only has the allowed characters
allowed_charaters = set('0123456789,-')
if not set(cleaned_argument).issubset(allowed_charaters):
raise ValueError("Argument contains incorrect characters. Should only contain '0-9', ',', and '-'.")

# Split on commas
cleaned_argument = cleaned_argument.split(',')

# Generate the final list
iterable_argument = set()
for arg in cleaned_argument:
if '-' not in arg:
if int(arg) < 0:
raise ValueError()
iterable_argument.add(int(arg))
else:
start, end = arg.split('-')
start, end = int(start), int(end)
if end < start:
raise ValueError('The end of a range must be larger than the start of the range.')
if start < 0:
raise ValueError()
iterable_argument = iterable_argument.union(set(range(start, end+1)))

return iterable_argument
# ============================================================================

# ============================================================================
def positive_int(raw_argument: str) -> int:
arg = int(raw_argument)
if arg < 0:
raise ValueError('Argument must be 0 or greater.')

return arg
# ============================================================================

# Initialize the CLI
cli = argparse.ArgumentParser()

# Required Arguments
cli.add_argument('-s', '--source-directory', type=pathlib.Path, required=True, help='The path to the directory for the source HDF5 files.')
cli.add_argument('-o', '--output-directory', type=pathlib.Path, required=True, help='The path to the directory to write out the concatenated HDF5 files.')
cli.add_argument('-n', '--num-processes', type=positive_int, required=True, help='The number of processes that were used')
cli.add_argument('-c', '--concat-outputs', type=concat_output, required=True, help='Which outputs to concatenate. Can be a single number (e.g. 8), a range (e.g. 2-9), or a list (e.g. [1,2,3]). Ranges are inclusive')

# Optional Arguments
cli.add_argument('--skip-fields', type=list, default=[], help='List of fields to skip concatenating. Defaults to empty.')
cli.add_argument('--dtype', type=str, default=None, help='The data type of the output datasets. Accepts most numpy types. Defaults to the same as the input datasets.')
cli.add_argument('--compression-type', type=str, default=None, help='What kind of compression to use on the output data. Defaults to None.')
cli.add_argument('--compression-opts', type=str, default=None, help='What compression settings to use if compressing. Defaults to None.')

return cli
# ==============================================================================

if __name__ == '__main__':
main()
from timeit import default_timer
start = default_timer()

cli = common_cli()
args = cli.parse_args()

# Perform the concatenation
for output in args.concat_outputs:
concat_3d_output(source_directory=args.source_directory,
output_directory=args.output_directory,
num_processes=args.num_processes,
output_number=output,
skip_fields=args.skip_fields,
destination_dtype=args.dtype,
compression_type=args.compression_type,
compression_options=args.compression_opts)

print(f'\nTime to execute: {round(default_timer()-start,2)} seconds')
Loading

0 comments on commit 1fa2342

Please sign in to comment.