Skip to content

Commit

Permalink
Add test cases for solution downsampling
Browse files Browse the repository at this point in the history
  • Loading branch information
aymkhalil committed Feb 23, 2016
1 parent 00ae4a3 commit 3aba659
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/petclaw/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _init_ds_solution(self):
for i in range(self.domain.num_dim)], proc_sizes=self.state.q_da.getProcSizes())
ds_state = self._init_ds_state(self.state)
self._ds_solution = pyclaw.Solution(ds_state, ds_domain)
self._ds_solution.t = self.t

def _init_ds_state(self, state):
"""
Expand Down
1 change: 1 addition & 0 deletions src/pyclaw/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(self):
self.F_file_name = 'F'
r"""(string) - Name of text file containing functionals"""
self.downsampling_factors = None
r"""(tuple) - A tuple of factors in each grid dimension that will be used in downsampling the solution by local averaging"""

# ========== Access methods ===============================================
def __str__(self):
Expand Down
1 change: 1 addition & 0 deletions src/pyclaw/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ def _init_ds_solution(self):
domain = Domain(self.domain.grid.lower,self.domain.grid.upper,self.domain.grid.num_cells/np.array(self.downsampling_factors))
state = State(domain,self.state.num_eqn,self.state.num_aux)
self._ds_solution = Solution(state, domain)
self._ds_solution.t = self.t

def _init_ds_state(self, state):
"""
Expand Down
40 changes: 38 additions & 2 deletions src/pyclaw/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import numpy as np
from clawpack.pyclaw import Solution
from clawpack.pyclaw.util import check_solutions_are_same
from clawpack.pyclaw.util import check_solutions_are_same, check_solution_ds_is_downsampled_from_a

class IOTest():
@property
Expand Down Expand Up @@ -33,6 +33,14 @@ def test_io_from_hdf5_with_aux(self):
regression_dir = os.path.join(self.test_data_dir,'./advection_2d_with_aux')
self.read_write_and_compare(self.file_formats,regression_dir,'hdf5',0,aux=True)

def test_ds_from_hdf5(self):
regression_dir = os.path.join(self.test_data_dir,'./Sedov_regression_hdf')
self.read_downsample_and_compare(self.file_formats,regression_dir,'hdf5',1,(2,2,2))

def test_ds_from_hdf5_with_aux(self):
regression_dir = os.path.join(self.test_data_dir,'./advection_2d_with_aux')
self.read_downsample_and_compare(self.file_formats,regression_dir,'hdf5',0,(2,2),aux=True)

def read_write_and_compare(self, file_formats,regression_dir,regression_format,frame_num,aux=False):
r"""Test IO file formats:
- Reading in an HDF file
Expand All @@ -59,4 +67,32 @@ def read_write_and_compare(self, file_formats,regression_dir,regression_format,f
# Compare solutions
# Probably better to do this by defining __eq__ for each class
for fmt, sol in s.iteritems():
check_solutions_are_same(sol,ref_sol)
check_solutions_are_same(sol,ref_sol)

def read_downsample_and_compare(self, file_formats,regression_dir,regression_format,frame_num,downsampling_factors,aux=False):
r"""Test downsampling:
- Read in a solution from file
- Downsample the solution
- Checking that q & aux arrays are correctly downsampled
"""
a_sol = self.solution
a_sol.read(frame_num,path=regression_dir,file_format=regression_format,read_aux=aux)
if aux:
assert (a_sol.state.aux is not None)

# Write solution file in each format
io_test_dir = os.path.join(self.this_dir,'./io_test')
for fmt in file_formats:
a_sol.downsampling_factors = downsampling_factors
a_sol.downsample(aux, False).write(frame_num,path=io_test_dir,file_format=fmt,write_aux=aux)

# Read solutions back in
s = {}
for fmt in file_formats:
s[fmt] = self.solution
s[fmt].read(frame_num,path=io_test_dir,file_format=fmt,write_aux=aux)

# Compare solutions
# Probably better to do this by defining __eq__ for each class
for fmt, ds_sol in s.iteritems():
check_solution_ds_is_downsampled_from_a(a_sol, ds_sol)
22 changes: 22 additions & 0 deletions src/pyclaw/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,28 @@ def check_solutions_are_same(sol_a,sol_b):
for attr in ['units','on_lower_boundary','on_upper_boundary']:
if hasattr(ref_dim,attr):
assert getattr(dim,attr) == getattr(ref_dim,attr)

def check_ds_sol_is_downsampled_from_a_sol(sol_ds, sol_a):
from skimage.transform import downscale_local_mean

assert len(sol_a.states) == len(sol_ds.states)
assert sol_a.t == sol_ds.t
for state in sol_a.states:
for ds_state in sol_ds.states:
if ds_state.patch.patch_index == state.patch.patch_index:
break

# Required state attributes
assert np.linalg.norm(downscale_local_mean(state.q, (1,) + sol_a.downsampling_factors) - ds_state.q) < 1.e-6 # Not sure why this can be so large
if state.aux is not None:
assert np.linalg.norm(downscale_local_mean(state.aux, (1,) + sol_a.downsampling_factors) - ds_state.aux) < 1.e-16
for attr in ['t', 'num_eqn', 'num_aux']:
assert getattr(state,attr) == getattr(ds_state,attr)
# Optional state attributes
for attr in ['patch_index', 'level']:
if hasattr(ds_state,attr):
assert getattr(state,attr) == getattr(ds_state,attr)

# ============================================================================
# F2PY Utility Functions
# ============================================================================
Expand Down

0 comments on commit 3aba659

Please sign in to comment.