Commit

Move all concat common tools into their own file
bcaddy committed Nov 13, 2023
1 parent 8b85da6 commit d89149f
Showing 3 changed files with 194 additions and 187 deletions.
19 changes: 10 additions & 9 deletions python_scripts/concat_2d_data.py
@@ -15,11 +15,10 @@
"""

import h5py
import argparse
import pathlib
import numpy as np

from concat_3d_data import copy_header, common_cli, destination_safe_open
import concat_internals

# ==============================================================================
def concat_2d_dataset(source_directory: pathlib.Path,
@@ -34,7 +33,7 @@ def concat_2d_dataset(source_directory: pathlib.Path,
destination_dtype: np.dtype = None,
compression_type: str = None,
compression_options: str = None,
chunking = None):
chunking = None) -> None:
"""Concatenate 2D HDF5 Cholla datasets. i.e. take the single files
generated per process and concatenate them into a single, large file. This
function concatenates a single output time and can be called multiple times,
@@ -104,12 +103,12 @@ def concat_2d_dataset(source_directory: pathlib.Path,
assert dataset_kind in ['slice', 'proj', 'rot_proj'], '`dataset_kind` can only be one of "slice", "proj", "rot_proj".'

# Open destination file
destination_file = destination_safe_open(output_directory / f'{output_number}_{dataset_kind}.h5')
destination_file = concat_internals.destination_safe_open(output_directory / f'{output_number}_{dataset_kind}.h5')

# Setup the destination file
with h5py.File(source_directory / f'{output_number}_{dataset_kind}.h5.0', 'r') as source_file:
# Copy over header
destination_file = copy_header(source_file, destination_file)
destination_file = concat_internals.copy_header(source_file, destination_file)

# Get a list of all datasets in the source file
datasets_to_copy = list(source_file.keys())
@@ -169,7 +168,7 @@ def concat_2d_dataset(source_directory: pathlib.Path,
# ==============================================================================

# ==============================================================================
def __get_2d_dataset_shape(source_file: h5py.File, dataset: str):
def __get_2d_dataset_shape(source_file: h5py.File, dataset: str) -> tuple:
"""Determine the shape of the full 2D dataset
Args:
Expand Down Expand Up @@ -200,7 +199,7 @@ def __get_2d_dataset_shape(source_file: h5py.File, dataset: str):
# ==============================================================================

# ==============================================================================
def __write_bounds_2d_dataset(source_file: h5py.File, dataset: str):
def __write_bounds_2d_dataset(source_file: h5py.File, dataset: str) -> tuple:
"""Determine the bounds of the concatenated file to write to
Args:
@@ -211,7 +210,9 @@ def __write_bounds_2d_dataset(source_file: h5py.File, dataset: str):
ValueError: If the dataset name isn't a 2D dataset name
Returns:
tuple: The write bounds for the concatenated file to be used like `output_file[dataset][return[0]:return[1], return[2]:return[3]]
tuple: The write bounds for the concatenated file, to be used like
`output_file[dataset][return[0]:return[1], return[2]:return[3]]`, followed by a bool indicating
whether this file lies in the slice when concatenating a slice
"""

if 'xzr' in dataset:
Expand Down Expand Up @@ -241,7 +242,7 @@ def __write_bounds_2d_dataset(source_file: h5py.File, dataset: str):
from timeit import default_timer
start = default_timer()

cli = common_cli()
cli = concat_internals.common_cli()
cli.add_argument('-d', '--dataset-kind', type=str, required=True, help='What kind of 2D dataset to concatenate. Options are "slice", "proj", and "rot_proj"')
cli.add_argument('--disable-xy', default=True, action='store_false', help='Disables concatenating the XY datasets.')
cli.add_argument('--disable-yz', default=True, action='store_false', help='Disables concatenating the YZ datasets.')
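# Illustrative only: the parser assembled above accepts the common flags from
# concat_internals.common_cli plus the 2D-specific ones, e.g. (made-up paths
# and values, not part of this commit):
#
#   args = cli.parse_args(['-s', '/data/raw', '-o', '/data/cat',
#                          '-n', '16', '-c', '0-10', '-d', 'proj'])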
184 changes: 6 additions & 178 deletions python_scripts/concat_3d_data.py
@@ -14,9 +14,10 @@

import h5py
import numpy as np
import argparse
import pathlib

import concat_internals

# ======================================================================================================================
def concat_3d_output(source_directory: pathlib.Path,
output_directory: pathlib.Path,
@@ -26,7 +27,7 @@ def concat_3d_output(source_directory: pathlib.Path,
destination_dtype: np.dtype = None,
compression_type: str = None,
compression_options: str = None,
chunking = None):
chunking = None) -> None:
"""Concatenate a single 3D HDF5 Cholla dataset. i.e. take the single files generated per process and concatenate them into a
single, large file.
@@ -77,12 +78,12 @@ def concat_3d_output(source_directory: pathlib.Path,
assert output_number >= 0, 'output_number must be greater than or equal to 0'

# Open the output file for writing
destination_file = destination_safe_open(output_directory / f'{output_number}.h5')
destination_file = concat_internals.destination_safe_open(output_directory / f'{output_number}.h5')

# Setup the output file
with h5py.File(source_directory / f'{output_number}.h5.0', 'r') as source_file:
# Copy header data
destination_file = copy_header(source_file, destination_file)
destination_file = concat_internals.copy_header(source_file, destination_file)

# Create the datasets in the output file
datasets_to_copy = list(source_file.keys())
@@ -132,184 +133,11 @@ def concat_3d_output(source_directory: pathlib.Path,
destination_file.close()
# ==============================================================================
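# A minimal sketch of driving concat_3d_output directly (not part of this
# commit); num_processes is assumed to match the '-n' CLI flag and the folded
# part of the signature, and the paths are illustrative:
#
#   import pathlib
#   from concat_3d_data import concat_3d_output
#
#   concat_3d_output(source_directory=pathlib.Path('raw'),
#                    output_directory=pathlib.Path('cat'),
#                    num_processes=8,
#                    output_number=0)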

# ==============================================================================
def destination_safe_open(filename: pathlib.Path) -> h5py.File:
"""Opens a HDF5 file safely and provides useful error messages for some common failure modes
Parameters
----------
filename : pathlib.Path
The full path and name of the file to open
Returns
-------
h5py.File
The opened HDF5 file object
"""

try:
destination_file = h5py.File(filename, 'w-')
except FileExistsError:
# It might be better for this to simply print the error message and return
# rather than exiting. That way if a single call fails in a parallel
# environment it doesn't take down the entire job
raise FileExistsError(f'File "{filename}" already exists and will not be overwritten, skipping.')

return destination_file
# ==============================================================================
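# The comment in the except block above suggests a gentler failure mode; a
# sketch of that alternative (not part of this commit) could look like:
#
#   def destination_safe_open_or_skip(filename: pathlib.Path):
#       try:
#           return h5py.File(filename, 'w-')
#       except FileExistsError:
#           print(f'File "{filename}" already exists and will not be overwritten, skipping.')
#           return None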

# ==============================================================================
def copy_header(source_file: h5py.File, destination_file: h5py.File) -> h5py.File:
"""Copy the attributes of one HDF5 file to another, skipping all fields that are specific to an individual rank
Parameters
----------
source_file : h5py.File
The source file
destination_file : h5py.File
The destination file
Returns
-------
h5py.File
The destination file with the new header attributes
"""
fields_to_skip = ['dims_local', 'offset']

for attr_key in source_file.attrs.keys():
if attr_key not in fields_to_skip:
destination_file.attrs[attr_key] = source_file.attrs[attr_key]

return destination_file
# ==============================================================================
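# Illustrative use of copy_header (file names are hypothetical): every
# attribute except the per-rank 'dims_local' and 'offset' fields is copied.
#
#   with h5py.File('0.h5.0', 'r') as source, h5py.File('0.h5', 'w-') as dest:
#       dest = copy_header(source, dest)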

# ==============================================================================
def common_cli() -> argparse.ArgumentParser:
"""This function provides the basis for the common CLI amongst the various concatenation scripts. It returns an
`argparse.ArgumentParser` object to which additional arguments can be passed before the final `.parse_args()` method
is used.
Returns
-------
argparse.ArgumentParser
The common argument parser, ready for script-specific arguments to be added
"""

# ============================================================================
def concat_output(raw_argument: str) -> set:
"""Function used to parse the `--concat-outputs` argument
"""
# Check if the string is empty
if len(raw_argument) < 1:
raise ValueError('The --concat-outputs argument must not be of length zero.')

# Strip unneeded characters
cleaned_argument = raw_argument.replace(' ', '')
cleaned_argument = cleaned_argument.replace('[', '')
cleaned_argument = cleaned_argument.replace(']', '')

# Check that it only has the allowed characters
allowed_characters = set('0123456789,-')
if not set(cleaned_argument).issubset(allowed_characters):
raise ValueError("Argument contains incorrect characters. Should only contain '0-9', ',', and '-'.")

# Split on commas
cleaned_argument = cleaned_argument.split(',')

# Generate the final list
iterable_argument = set()
for arg in cleaned_argument:
if '-' not in arg:
if int(arg) < 0:
raise ValueError('Output numbers must be greater than or equal to 0.')
iterable_argument.add(int(arg))
else:
start, end = arg.split('-')
start, end = int(start), int(end)
if end < start:
raise ValueError('The end of a range must be larger than the start of the range.')
if start < 0:
raise ValueError('The start of a range must be greater than or equal to 0.')
iterable_argument = iterable_argument.union(set(range(start, end+1)))

return iterable_argument
# ============================================================================
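# Worked examples of the parsing above (illustrative): concat_output('8')
# returns {8}, and concat_output('[1,3-5]') returns {1, 3, 4, 5}; ranges are
# inclusive and the result is a set, so duplicate outputs collapse.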

# ============================================================================
def positive_int(raw_argument: str) -> int:
arg = int(raw_argument)
if arg < 0:
raise ValueError('Argument must be 0 or greater.')

return arg
# ============================================================================

# ============================================================================
def skip_fields(raw_argument: str) -> list:
# Strip unneeded characters
cleaned_argument = raw_argument.replace(' ', '')
cleaned_argument = cleaned_argument.replace('[', '')
cleaned_argument = cleaned_argument.replace(']', '')
cleaned_argument = cleaned_argument.split(',')

return cleaned_argument
# ============================================================================

# ============================================================================
def chunk_arg(raw_argument: str) -> tuple:
# Strip unneeded characters
cleaned_argument = raw_argument.replace(' ', '')
cleaned_argument = cleaned_argument.replace('(', '')
cleaned_argument = cleaned_argument.replace(')', '')

# Check that it only has the allowed characters
allowed_characters = set('0123456789,')
if not set(cleaned_argument).issubset(allowed_characters):
raise ValueError("Argument contains incorrect characters. Should only contain '0-9' and ','.")

# Convert to a tuple and return
return tuple([int(i) for i in cleaned_argument.split(',')])
# ============================================================================

# Initialize the CLI
cli = argparse.ArgumentParser()

# Required Arguments
cli.add_argument('-s', '--source-directory', type=pathlib.Path, required=True, help='The path to the directory for the source HDF5 files.')
cli.add_argument('-o', '--output-directory', type=pathlib.Path, required=True, help='The path to the directory to write out the concatenated HDF5 files.')
cli.add_argument('-n', '--num-processes', type=positive_int, required=True, help='The number of processes that were used')
cli.add_argument('-c', '--concat-outputs', type=concat_output, required=True, help='Which outputs to concatenate. Can be a single number (e.g. 8), a range (e.g. 2-9), or a list (e.g. [1,2,3]). Ranges are inclusive')

# Optional Arguments
cli.add_argument('--skip-fields', type=skip_fields, default=[], help='List of fields to skip concatenating. Defaults to empty.')
cli.add_argument('--dtype', type=str, default=None, help='The data type of the output datasets. Accepts most numpy types. Defaults to the same as the input datasets.')
cli.add_argument('--compression-type', type=str, default=None, help='What kind of compression to use on the output data. Defaults to None.')
cli.add_argument('--compression-opts', type=str, default=None, help='What compression settings to use if compressing. Defaults to None.')
cli.add_argument('--chunking', type=chunk_arg, default=None, nargs='?', const=True, help='Enable chunking of the output file. Defaults to off. If set without an argument the chunk size will be chosen automatically; otherwise a tuple can be passed to indicate the desired chunk size.')

return cli
# ==============================================================================
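# Sketch of the intended extension pattern after this commit (mirroring
# concat_2d_data.py above):
#
#   cli = concat_internals.common_cli()
#   cli.add_argument('-d', '--dataset-kind', type=str, required=True)
#   args = cli.parse_args()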

if __name__ == '__main__':
from timeit import default_timer
start = default_timer()

cli = common_cli()
cli = concat_internals.common_cli()
args = cli.parse_args()

# Perform the concatenation
178 changes: 178 additions & 0 deletions python_scripts/concat_internals.py
Large diffs are not rendered by default.
