From d89149fd18ab79f43e4074c75ff64e88a6f98120 Mon Sep 17 00:00:00 2001 From: Bob Caddy Date: Mon, 13 Nov 2023 11:50:35 -0700 Subject: [PATCH] Move all concat common tools into their own file --- python_scripts/concat_2d_data.py | 19 +-- python_scripts/concat_3d_data.py | 184 +---------------------------- python_scripts/concat_internals.py | 178 ++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 187 deletions(-) create mode 100644 python_scripts/concat_internals.py diff --git a/python_scripts/concat_2d_data.py b/python_scripts/concat_2d_data.py index 16f1668a0..5cf6fde55 100644 --- a/python_scripts/concat_2d_data.py +++ b/python_scripts/concat_2d_data.py @@ -15,11 +15,10 @@ """ import h5py -import argparse import pathlib import numpy as np -from concat_3d_data import copy_header, common_cli, destination_safe_open +import concat_internals # ============================================================================== def concat_2d_dataset(source_directory: pathlib.Path, @@ -34,7 +33,7 @@ def concat_2d_dataset(source_directory: pathlib.Path, destination_dtype: np.dtype = None, compression_type: str = None, compression_options: str = None, - chunking = None): + chunking = None) -> None: """Concatenate 2D HDF5 Cholla datasets. i.e. take the single files generated per process and concatenate them into a single, large file. This function concatenates a single output time and can be called multiple times, @@ -104,12 +103,12 @@ def concat_2d_dataset(source_directory: pathlib.Path, assert dataset_kind in ['slice', 'proj', 'rot_proj'], '`dataset_kind` can only be one of "slice", "proj", "rot_proj".' # Open destination file - destination_file = destination_safe_open(output_directory / f'{output_number}_{dataset_kind}.h5') + destination_file = concat_internals.destination_safe_open(output_directory / f'{output_number}_{dataset_kind}.h5') # Setup the destination file with h5py.File(source_directory / f'{output_number}_{dataset_kind}.h5.0', 'r') as source_file: # Copy over header - destination_file = copy_header(source_file, destination_file) + destination_file = concat_internals.copy_header(source_file, destination_file) # Get a list of all datasets in the source file datasets_to_copy = list(source_file.keys()) @@ -169,7 +168,7 @@ def concat_2d_dataset(source_directory: pathlib.Path, # ============================================================================== # ============================================================================== -def __get_2d_dataset_shape(source_file: h5py.File, dataset: str): +def __get_2d_dataset_shape(source_file: h5py.File, dataset: str) -> tuple: """Determine the shape of the full 2D dataset Args: @@ -200,7 +199,7 @@ def __get_2d_dataset_shape(source_file: h5py.File, dataset: str): # ============================================================================== # ============================================================================== -def __write_bounds_2d_dataset(source_file: h5py.File, dataset: str): +def __write_bounds_2d_dataset(source_file: h5py.File, dataset: str) -> tuple: """Determine the bounds of the concatenated file to write to Args: @@ -211,7 +210,9 @@ def __write_bounds_2d_dataset(source_file: h5py.File, dataset: str): ValueError: If the dataset name isn't a 2D dataset name Returns: - tuple: The write bounds for the concatenated file to be used like `output_file[dataset][return[0]:return[1], return[2]:return[3]] + tuple: The write bounds for the concatenated file to be used like + `output_file[dataset][return[0]:return[1], return[2]:return[3]]` followed by a bool to indicate if the file is + in the slice if concatenating a slice """ if 'xzr' in dataset: @@ -241,7 +242,7 @@ def __write_bounds_2d_dataset(source_file: h5py.File, dataset: str): from timeit import default_timer start = default_timer() - cli = common_cli() + cli = concat_internals.common_cli() cli.add_argument('-d', '--dataset-kind', type=str, required=True, help='What kind of 2D dataset to concatnate. Options are "slice", "proj", and "rot_proj"') cli.add_argument('--disable-xy', default=True, action='store_false', help='Disables concating the XY datasets.') cli.add_argument('--disable-yz', default=True, action='store_false', help='Disables concating the YZ datasets.') diff --git a/python_scripts/concat_3d_data.py b/python_scripts/concat_3d_data.py index 08cc1a50b..930c108e2 100644 --- a/python_scripts/concat_3d_data.py +++ b/python_scripts/concat_3d_data.py @@ -14,9 +14,10 @@ import h5py import numpy as np -import argparse import pathlib +import concat_internals + # ====================================================================================================================== def concat_3d_output(source_directory: pathlib.Path, output_directory: pathlib.Path, @@ -26,7 +27,7 @@ def concat_3d_output(source_directory: pathlib.Path, destination_dtype: np.dtype = None, compression_type: str = None, compression_options: str = None, - chunking = None): + chunking = None) -> None: """Concatenate a single 3D HDF5 Cholla dataset. i.e. take the single files generated per process and concatenate them into a single, large file. @@ -77,12 +78,12 @@ def concat_3d_output(source_directory: pathlib.Path, assert output_number >= 0, 'output_number must be greater than or equal to 0' # Open the output file for writing - destination_file = destination_safe_open(output_directory / f'{output_number}.h5') + destination_file = concat_internals.destination_safe_open(output_directory / f'{output_number}.h5') # Setup the output file with h5py.File(source_directory / f'{output_number}.h5.0', 'r') as source_file: # Copy header data - destination_file = copy_header(source_file, destination_file) + destination_file = concat_internals.copy_header(source_file, destination_file) # Create the datasets in the output file datasets_to_copy = list(source_file.keys()) @@ -132,184 +133,11 @@ def concat_3d_output(source_directory: pathlib.Path, destination_file.close() # ============================================================================== -# ============================================================================== -def destination_safe_open(filename: pathlib.Path) -> h5py.File: - """Opens a HDF5 file safely and provides useful error messages for some common failure modes - - Parameters - ---------- - filename : pathlib.Path - - The full path and name of the file to open : - - filename: pathlib.Path : - - - Returns - ------- - h5py.File - - The opened HDF5 file object - - - - """ - - try: - destination_file = h5py.File(filename, 'w-') - except FileExistsError: - # It might be better for this to simply print the error message and return - # rather than exiting. That way if a single call fails in a parallel - # environment it doesn't take down the entire job - raise FileExistsError(f'File "{filename}" already exists and will not be overwritten, skipping.') - - return destination_file -# ============================================================================== - -# ============================================================================== -def copy_header(source_file: h5py.File, destination_file: h5py.File): - """Copy the attributes of one HDF5 file to another, skipping all fields that are specific to an individual rank - - Parameters - ---------- - source_file : h5py.File - The source file - destination_file : h5py.File - The destination file - source_file: h5py.File : - - destination_file: h5py.File : - - - Returns - ------- - h5py.File - The destination file with the new header attributes - - """ - fields_to_skip = ['dims_local', 'offset'] - - for attr_key in source_file.attrs.keys(): - if attr_key not in fields_to_skip: - destination_file.attrs[attr_key] = source_file.attrs[attr_key] - - return destination_file -# ============================================================================== - -# ============================================================================== -def common_cli() -> argparse.ArgumentParser: - """This function provides the basis for the common CLI amongst the various concatenation scripts. It returns an - `argparse.ArgumentParser` object to which additional arguments can be passed before the final `.parse_args()` method - is used. - - Parameters - ---------- - - Returns - ------- - - """ - - # ============================================================================ - def concat_output(raw_argument: str) -> list: - """Function used to parse the `--concat-output` argument - """ - # Check if the string is empty - if len(raw_argument) < 1: - raise ValueError('The --concat-output argument must not be of length zero.') - - # Strip unneeded characters - cleaned_argument = raw_argument.replace(' ', '') - cleaned_argument = cleaned_argument.replace('[', '') - cleaned_argument = cleaned_argument.replace(']', '') - - # Check that it only has the allowed characters - allowed_charaters = set('0123456789,-') - if not set(cleaned_argument).issubset(allowed_charaters): - raise ValueError("Argument contains incorrect characters. Should only contain '0-9', ',', and '-'.") - - # Split on commas - cleaned_argument = cleaned_argument.split(',') - - # Generate the final list - iterable_argument = set() - for arg in cleaned_argument: - if '-' not in arg: - if int(arg) < 0: - raise ValueError() - iterable_argument.add(int(arg)) - else: - start, end = arg.split('-') - start, end = int(start), int(end) - if end < start: - raise ValueError('The end of a range must be larger than the start of the range.') - if start < 0: - raise ValueError() - iterable_argument = iterable_argument.union(set(range(start, end+1))) - - return iterable_argument - # ============================================================================ - - # ============================================================================ - def positive_int(raw_argument: str) -> int: - arg = int(raw_argument) - if arg < 0: - raise ValueError('Argument must be 0 or greater.') - - return arg - # ============================================================================ - - # ============================================================================ - def skip_fields(raw_argument: str) -> list: - # Strip unneeded characters - cleaned_argument = raw_argument.replace(' ', '') - cleaned_argument = cleaned_argument.replace('[', '') - cleaned_argument = cleaned_argument.replace(']', '') - cleaned_argument = cleaned_argument.split(',') - - return cleaned_argument - # ============================================================================ - - # ============================================================================ - def chunk_arg(raw_argument: str): - # Strip unneeded characters - cleaned_argument = raw_argument.replace(' ', '') - cleaned_argument = cleaned_argument.replace('(', '') - cleaned_argument = cleaned_argument.replace(')', '') - - # Check that it only has the allowed characters - allowed_charaters = set('0123456789,') - if not set(cleaned_argument).issubset(allowed_charaters): - raise ValueError("Argument contains incorrect characters. Should only contain '0-9', ',', and '-'.") - - # Convert to a tuple and return - return tuple([int(i) for i in cleaned_argument.split(',')]) - # ============================================================================ - - # Initialize the CLI - cli = argparse.ArgumentParser() - - # Required Arguments - cli.add_argument('-s', '--source-directory', type=pathlib.Path, required=True, help='The path to the directory for the source HDF5 files.') - cli.add_argument('-o', '--output-directory', type=pathlib.Path, required=True, help='The path to the directory to write out the concatenated HDF5 files.') - cli.add_argument('-n', '--num-processes', type=positive_int, required=True, help='The number of processes that were used') - cli.add_argument('-c', '--concat-outputs', type=concat_output, required=True, help='Which outputs to concatenate. Can be a single number (e.g. 8), a range (e.g. 2-9), or a list (e.g. [1,2,3]). Ranges are inclusive') - - # Optional Arguments - cli.add_argument('--skip-fields', type=skip_fields, default=[], help='List of fields to skip concatenating. Defaults to empty.') - cli.add_argument('--dtype', type=str, default=None, help='The data type of the output datasets. Accepts most numpy types. Defaults to the same as the input datasets.') - cli.add_argument('--compression-type', type=str, default=None, help='What kind of compression to use on the output data. Defaults to None.') - cli.add_argument('--compression-opts', type=str, default=None, help='What compression settings to use if compressing. Defaults to None.') - cli.add_argument('--chunking', type=chunk_arg, default=None, nargs='?', const=True, help='Enable chunking of the output file. Default is `False`. If set without an argument then the chunk size will be automatically chosen or a tuple can be passed to indicate the chunk size desired.') - - return cli -# ============================================================================== - if __name__ == '__main__': from timeit import default_timer start = default_timer() - cli = common_cli() + cli = concat_internals.common_cli() args = cli.parse_args() # Perform the concatenation diff --git a/python_scripts/concat_internals.py b/python_scripts/concat_internals.py new file mode 100644 index 000000000..29bf49829 --- /dev/null +++ b/python_scripts/concat_internals.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +""" +Contains all the common tools for the various concatnation functions/scipts +""" + +import h5py +import argparse +import pathlib + +# ============================================================================== +def destination_safe_open(filename: pathlib.Path) -> h5py.File: + """Opens a HDF5 file safely and provides useful error messages for some common failure modes + + Parameters + ---------- + filename : pathlib.Path + + The full path and name of the file to open : + + filename: pathlib.Path : + + + Returns + ------- + h5py.File + + The opened HDF5 file object + """ + + try: + destination_file = h5py.File(filename, 'w-') + except FileExistsError: + # It might be better for this to simply print the error message and return + # rather than exiting. That way if a single call fails in a parallel + # environment it doesn't take down the entire job + raise FileExistsError(f'File "{filename}" already exists and will not be overwritten, skipping.') + + return destination_file +# ============================================================================== + +# ============================================================================== +def copy_header(source_file: h5py.File, destination_file: h5py.File) -> h5py.File: + """Copy the attributes of one HDF5 file to another, skipping all fields that are specific to an individual rank + + Parameters + ---------- + source_file : h5py.File + The source file + destination_file : h5py.File + The destination file + source_file: h5py.File : + + destination_file: h5py.File : + + + Returns + ------- + h5py.File + The destination file with the new header attributes + """ + fields_to_skip = ['dims_local', 'offset'] + + for attr_key in source_file.attrs.keys(): + if attr_key not in fields_to_skip: + destination_file.attrs[attr_key] = source_file.attrs[attr_key] + + return destination_file +# ============================================================================== + +# ============================================================================== +def common_cli() -> argparse.ArgumentParser: + """This function provides the basis for the common CLI amongst the various concatenation scripts. It returns an + `argparse.ArgumentParser` object to which additional arguments can be passed before the final `.parse_args()` method + is used. + + Parameters + ---------- + + Returns + ------- + argparse.ArgumentParser + The common components of the CLI for the concatenation scripts + """ + + # ============================================================================ + def concat_output(raw_argument: str) -> list: + """Function used to parse the `--concat-output` argument + """ + # Check if the string is empty + if len(raw_argument) < 1: + raise ValueError('The --concat-output argument must not be of length zero.') + + # Strip unneeded characters + cleaned_argument = raw_argument.replace(' ', '') + cleaned_argument = cleaned_argument.replace('[', '') + cleaned_argument = cleaned_argument.replace(']', '') + + # Check that it only has the allowed characters + allowed_charaters = set('0123456789,-') + if not set(cleaned_argument).issubset(allowed_charaters): + raise ValueError("Argument contains incorrect characters. Should only contain '0-9', ',', and '-'.") + + # Split on commas + cleaned_argument = cleaned_argument.split(',') + + # Generate the final list + iterable_argument = set() + for arg in cleaned_argument: + if '-' not in arg: + if int(arg) < 0: + raise ValueError() + iterable_argument.add(int(arg)) + else: + start, end = arg.split('-') + start, end = int(start), int(end) + if end < start: + raise ValueError('The end of a range must be larger than the start of the range.') + if start < 0: + raise ValueError() + iterable_argument = iterable_argument.union(set(range(start, end+1))) + + return iterable_argument + # ============================================================================ + + # ============================================================================ + def positive_int(raw_argument: str) -> int: + arg = int(raw_argument) + if arg < 0: + raise ValueError('Argument must be 0 or greater.') + + return arg + # ============================================================================ + + # ============================================================================ + def skip_fields(raw_argument: str) -> list: + # Strip unneeded characters + cleaned_argument = raw_argument.replace(' ', '') + cleaned_argument = cleaned_argument.replace('[', '') + cleaned_argument = cleaned_argument.replace(']', '') + cleaned_argument = cleaned_argument.split(',') + + return cleaned_argument + # ============================================================================ + + # ============================================================================ + def chunk_arg(raw_argument: str) -> tuple: + # Strip unneeded characters + cleaned_argument = raw_argument.replace(' ', '') + cleaned_argument = cleaned_argument.replace('(', '') + cleaned_argument = cleaned_argument.replace(')', '') + + # Check that it only has the allowed characters + allowed_charaters = set('0123456789,') + if not set(cleaned_argument).issubset(allowed_charaters): + raise ValueError("Argument contains incorrect characters. Should only contain '0-9', ',', and '-'.") + + # Convert to a tuple and return + return tuple([int(i) for i in cleaned_argument.split(',')]) + # ============================================================================ + + # Initialize the CLI + cli = argparse.ArgumentParser() + + # Required Arguments + cli.add_argument('-s', '--source-directory', type=pathlib.Path, required=True, help='The path to the directory for the source HDF5 files.') + cli.add_argument('-o', '--output-directory', type=pathlib.Path, required=True, help='The path to the directory to write out the concatenated HDF5 files.') + cli.add_argument('-n', '--num-processes', type=positive_int, required=True, help='The number of processes that were used') + cli.add_argument('-c', '--concat-outputs', type=concat_output, required=True, help='Which outputs to concatenate. Can be a single number (e.g. 8), a range (e.g. 2-9), or a list (e.g. [1,2,3]). Ranges are inclusive') + + # Optional Arguments + cli.add_argument('--skip-fields', type=skip_fields, default=[], help='List of fields to skip concatenating. Defaults to empty.') + cli.add_argument('--dtype', type=str, default=None, help='The data type of the output datasets. Accepts most numpy types. Defaults to the same as the input datasets.') + cli.add_argument('--compression-type', type=str, default=None, help='What kind of compression to use on the output data. Defaults to None.') + cli.add_argument('--compression-opts', type=str, default=None, help='What compression settings to use if compressing. Defaults to None.') + cli.add_argument('--chunking', type=chunk_arg, default=None, nargs='?', const=True, help='Enable chunking of the output file. Default is `False`. If set without an argument then the chunk size will be automatically chosen or a tuple can be passed to indicate the chunk size desired.') + + return cli +# ==============================================================================